diff --git a/docs/reference/openapi.yaml b/docs/reference/openapi.yaml index 499e8cc52..35146b01a 100644 --- a/docs/reference/openapi.yaml +++ b/docs/reference/openapi.yaml @@ -246,6 +246,26 @@ components: - name title: Task type: object + TaskError: + additionalProperties: false + description: Wrapper around an exception raised by a plan + properties: + message: + title: Message + type: string + outcome: + const: error + default: error + title: Outcome + type: string + type: + title: Type + type: string + required: + - type + - message + title: TaskError + type: object TaskRequest: additionalProperties: false description: Request to run a task with related info @@ -280,6 +300,31 @@ components: - task_id title: TaskResponse type: object + TaskResult: + additionalProperties: false + description: 'Serializable wrapper around the result of a plan + + + If the result is not serializable, the result will be None but the type + + will be the name of the type. If the result is actually None, the type will + + be ''NoneType''.' + properties: + outcome: + const: success + default: success + title: Outcome + type: string + result: + title: Result + type: + title: Type + type: string + required: + - type + title: TaskResult + type: object TasksListResponse: additionalProperties: false description: Diagnostic information on the tasks @@ -311,6 +356,12 @@ components: default: true title: Is Pending type: boolean + outcome: + anyOf: + - $ref: '#/components/schemas/TaskResult' + - $ref: '#/components/schemas/TaskError' + - type: 'null' + title: Outcome request_id: title: Request Id type: string @@ -377,7 +428,7 @@ info: name: Apache 2.0 url: https://www.apache.org/licenses/LICENSE-2.0.html title: BlueAPI Control - version: 1.1.2 + version: 1.2.0 openapi: 3.1.0 paths: /config/oidc: diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 4d974d35e..3036bc054 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -8,7 +8,7 @@ from functools import wraps from pathlib import Path from pprint import pprint -from typing import ParamSpec, TypeVar +from typing import ParamSpec, TypeVar, cast import click from bluesky.callbacks.best_effort import BestEffortCallback @@ -38,6 +38,7 @@ from blueapi.service.authentication import SessionCacheManager, SessionManager from blueapi.service.model import SourceInfo, TaskRequest from blueapi.worker import ProgressEvent, WorkerEvent +from blueapi.worker.event import TaskError, TaskResult from .scratch import setup_scratch from .updates import CliEventRenderer @@ -290,7 +291,7 @@ def run_plan( instrument_session: str, ) -> None: """Run a plan with parameters""" - client: BlueapiClient = obj["client"] + client: BlueapiClient = cast(BlueapiClient, obj["client"]) parameters = parameters or "{}" try: @@ -320,9 +321,15 @@ def on_event(event: AnyEvent) -> None: callback(event.name, event.doc) resp = client.run_task(task, on_event=on_event) - - if resp.task_status is not None and not resp.task_status.task_failed: - print("Plan Succeeded") + match resp.result: + case TaskResult(result=None, type="NoneType"): + print("Plan succeeded") + case TaskResult(result=None, type=t): + print(f"Plan returned unserializable result of type '{t}'") + case TaskResult(result=r): + print(f"Plan succeeded: {r}") + case TaskError(type=exc, message=m): + print(f"Plan failed: {exc}: {m}") else: server_task = client.create_and_start_task(task) click.echo(server_task.task_id) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 0930e240a..24448bc2b 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -28,7 +28,7 @@ from blueapi.worker import TrackableTask, WorkerEvent, WorkerState from blueapi.worker.event import ProgressEvent, TaskStatus -from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent +from .event_bus import AnyEvent, EventBusClient, OnAnyEvent from .rest import BlueapiRestClient, BlueskyRemoteControlError TRACER = get_tracer("client") @@ -201,7 +201,7 @@ def run_task( task: TaskRequest, on_event: OnAnyEvent | None = None, timeout: float | None = None, - ) -> WorkerEvent: + ) -> TaskStatus: """ Synchronously run a task, requires a message bus connection @@ -224,7 +224,7 @@ def run_task( task_response = self.create_task(task) task_id = task_response.task_id - complete: Future[WorkerEvent] = Future() + complete: Future[TaskStatus] = Future() def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: match event: @@ -239,19 +239,19 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: if relates_to_task: if on_event is not None: on_event(event) - if isinstance(event, WorkerEvent) and ( - (event.is_complete()) and (ctx.correlation_id == task_id) + if ( + isinstance(event, WorkerEvent) + and (event.is_complete()) + and (ctx.correlation_id == task_id) ): - if event.task_status is not None and event.task_status.task_failed: + if event.task_status is None: complete.set_exception( - BlueskyStreamingError( - "\n".join(event.errors) - if len(event.errors) > 0 - else "Unknown error" + BlueskyRemoteControlError( + "Server completed without task status" ) ) else: - complete.set_result(event) + complete.set_result(event.task_status) with self._events: self._events.subscribe_to_all_events(inner_on_event) diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 9bf29aec3..56fd163a2 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -120,7 +120,7 @@ class BlueskyContext: configuration: InitVar[ApplicationConfig | None] = None run_engine: RunEngine = field( - default_factory=lambda: RunEngine(context_managers=[]) + default_factory=lambda: RunEngine(context_managers=[], call_returns_result=True) ) tiled_conf: TiledConfig | None = field(default=None, init=False, repr=False) numtracker: NumtrackerClient | None = field(default=None, init=False, repr=False) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 5aa44c533..6f2acbd15 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -58,7 +58,7 @@ from .runner import WorkerDispatcher #: API version to publish in OpenAPI schema -REST_API_VERSION = "1.1.2" +REST_API_VERSION = "1.2.0" LICENSE_INFO: dict[str, str] = { "name": "Apache 2.0", diff --git a/src/blueapi/utils/base_model.py b/src/blueapi/utils/base_model.py index 142052a8c..1842ee02d 100644 --- a/src/blueapi/utils/base_model.py +++ b/src/blueapi/utils/base_model.py @@ -1,5 +1,11 @@ -from pydantic import BaseModel, ConfigDict +import logging +from pydantic import ( + BaseModel, + ConfigDict, +) + +logger = logging.getLogger(__name__) # Pydantic config for blueapi API models with common config. BlueapiModelConfig = ConfigDict( extra="forbid", diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 2004a1a63..bfba07bd6 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -1,8 +1,10 @@ +import logging from collections.abc import Mapping from enum import StrEnum +from typing import Any, Literal, Self from bluesky.run_engine import RunEngineStateMachine -from pydantic import Field +from pydantic import Field, PydanticSchemaGenerationError, TypeAdapter from super_state_machine.extras import PropertyMachine, ProxyString from blueapi.utils import BlueapiBaseModel @@ -11,6 +13,8 @@ # RawRunEngineState = type[PropertyMachine | ProxyString | str] RawRunEngineState = PropertyMachine | ProxyString | str +log = logging.getLogger(__name__) + # NOTE this is interim until refactor class TaskStatusEnum(StrEnum): @@ -52,6 +56,47 @@ def from_bluesky_state(cls, bluesky_state: RawRunEngineState) -> "WorkerState": return WorkerState(str(bluesky_state).upper()) +class TaskResult(BlueapiBaseModel): + """ + Serializable wrapper around the result of a plan + + If the result is not serializable, the result will be None but the type + will be the name of the type. If the result is actually None, the type will + be 'NoneType'. + """ + + outcome: Literal["success"] = "success" + """Discriminant for serialization""" + result: Any = Field(None) + """The serialized result (or None if it is not serializable)""" + type: str + """The type of the result""" + + @classmethod + def from_result(cls, result: Any) -> Self: + type_str = type(result).__name__ + try: + value = TypeAdapter(type(result)).dump_python(result) + except PydanticSchemaGenerationError: + value = None + return cls(result=value, type=type_str) + + +class TaskError(BlueapiBaseModel): + """Wrapper around an exception raised by a plan""" + + outcome: Literal["error"] = "error" + """Discriminant for serialization""" + type: str + """The class of exception""" + message: str + """The message of the raised exception""" + + @classmethod + def from_exception(cls, err: Exception) -> Self: + return cls(type=type(err).__name__, message=str(err)) + + class StatusView(BlueapiBaseModel): """ A snapshot of a Status of an operation, optionally representing progress @@ -107,6 +152,7 @@ class TaskStatus(BlueapiBaseModel): """ task_id: str + result: TaskResult | TaskError | None = Field(None, discriminator="outcome") task_complete: bool task_failed: bool diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index 32060bc64..9ce373c76 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -38,7 +38,11 @@ def do_task(self, ctx: BlueskyContext) -> None: func = ctx.plan_functions[self.name] prepared_params = self.prepare_params(ctx) ctx.run_engine.md.update(self.metadata) - ctx.run_engine(func(**prepared_params)) + result = ctx.run_engine(func(**prepared_params)) + if isinstance(result, tuple): # pragma: no cover + # this is never true if the run_engine is configured correctly + return None + return result.plan_result def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel: diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index e1ed4970a..c30fc0220 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -39,6 +39,8 @@ ProgressEvent, RawRunEngineState, StatusView, + TaskError, + TaskResult, TaskStatus, TaskStatusEnum, WorkerEvent, @@ -69,6 +71,13 @@ class TrackableTask(BlueapiBaseModel): is_complete: bool = False is_pending: bool = True errors: list[str] = Field(default_factory=list) + outcome: TaskResult | TaskError | None = None + + def set_result(self, result: Any): + self.outcome = TaskResult.from_result(result) + + def set_exception(self, err: Exception): + self.outcome = TaskError.from_exception(err) class TaskWorker: @@ -427,7 +436,12 @@ def process_task(): LOGGER.info(f"Got new task: {next_task}") self._current = next_task self._current.is_pending = False - self._current.task.do_task(self._ctx) + try: + result = self._current.task.do_task(self._ctx) + self._current.set_result(result) + except Exception as e: + self._current.set_exception(e) + self._report_error(e) with plan_tag_filter_context(next_task.task.name, LOGGER): if self._current_task_otel_context is not None: @@ -528,6 +542,7 @@ def _report_status( task_id=self._current.task_id, task_complete=self._current.is_complete, task_failed=bool(self._current.errors), + result=self._current.outcome, ) correlation_id = self._current.task_id add_span_attributes( diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index f2d6395dc..9b26e22ec 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -12,7 +12,7 @@ from requests.exceptions import ConnectionError from blueapi.client import BlueapiClient -from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError +from blueapi.client.event_bus import AnyEvent from blueapi.client.rest import ( BlueskyRemoteControlError, BlueskyRequestError, @@ -31,7 +31,13 @@ TaskResponse, WorkerTask, ) -from blueapi.worker.event import TaskStatus, WorkerEvent, WorkerState +from blueapi.worker.event import ( + TaskError, + TaskResult, + TaskStatus, + WorkerEvent, + WorkerState, +) from blueapi.worker.task_worker import TrackableTask AUTHORIZED_INSTRUMENT_SESSION = "cm12345-1" @@ -418,6 +424,7 @@ def on_event(event: AnyEvent): task_id=task_id, task_complete=False, task_failed=False, + result=None, ), ), WorkerEvent( @@ -426,6 +433,7 @@ def on_event(event: AnyEvent): task_id=task_id, task_complete=False, task_failed=False, + result=None, ), ), WorkerEvent( @@ -434,6 +442,7 @@ def on_event(event: AnyEvent): task_id=task_id, task_complete=True, task_failed=False, + result=TaskResult(result=None, type="NoneType"), ), ), ] @@ -510,8 +519,9 @@ def on_event(event: AnyEvent) -> None: resource.put_nowait(event.doc) final_event = client_with_stomp.run_task(task, on_event) - assert final_event.is_complete() and not final_event.is_error() - assert final_event.state is WorkerState.IDLE + assert isinstance(final_event.result, TaskResult) + assert final_event.task_complete + assert not final_event.task_failed start_doc = start.get_nowait() assert start_doc["scan_id"] == scan_id @@ -558,8 +568,9 @@ def on_event(event: AnyEvent) -> None: ) def test_stub_runs(client_with_stomp: BlueapiClient, task: TaskRequest): final_event = client_with_stomp.run_task(task) - assert final_event.is_complete() and not final_event.is_error() - assert final_event.state is WorkerState.IDLE + assert isinstance(final_event.result, TaskResult) + assert final_event.task_complete + assert not final_event.task_failed @pytest.mark.parametrize( @@ -580,7 +591,7 @@ def test_stub_runs(client_with_stomp: BlueapiClient, task: TaskRequest): ), ], ) -def test_unauthozied_plan_run( +def test_unauthorized_plan_run( client_with_stomp: BlueapiClient, task: TaskRequest, scan_id: int ): resource = Queue(maxsize=1) @@ -593,8 +604,9 @@ def on_event(event: AnyEvent) -> None: if event.name == "stream_resource": resource.put_nowait(event.doc) - with pytest.raises( - BlueskyStreamingError, - match="404: No such entry", - ): - client_with_stomp.run_task(task, on_event) + outcome = client_with_stomp.run_task(task, on_event) + assert outcome.task_failed + assert outcome.task_complete + assert isinstance(outcome.result, TaskError) + assert outcome.result.type == "ClientError" + assert outcome.result.message.startswith("404: No such entry: [") diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index d6d2e1f22..8bed8927d 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -10,7 +10,7 @@ ) from blueapi.client import BlueapiClient -from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient +from blueapi.client.event_bus import AnyEvent, EventBusClient from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError from blueapi.config import MissingStompConfigurationError from blueapi.core import DataEvent @@ -26,7 +26,7 @@ WorkerTask, ) from blueapi.worker import ProgressEvent, Task, TrackableTask, WorkerEvent, WorkerState -from blueapi.worker.event import TaskStatus +from blueapi.worker.event import TaskError, TaskResult, TaskStatus PLANS = PlanResponse( plans=[ @@ -55,6 +55,7 @@ task_id="foo", task_complete=True, task_failed=False, + result=TaskResult(type="NoneType", result=None), ), ) FAILED_EVENT = WorkerEvent( @@ -63,6 +64,7 @@ task_id="foo", task_complete=True, task_failed=True, + result=TaskError(type="PlanFailure", message="The plan failed"), ), ) @@ -431,10 +433,15 @@ def test_run_task_fails_on_failing_event( mock_events.subscribe_to_all_events = lambda on_event: on_event(FAILED_EVENT, ctx) on_event = Mock() - with pytest.raises(BlueskyStreamingError): - client_with_events.run_task( - TaskRequest(name="foo", instrument_session="cm12345-1"), on_event=on_event - ) + outcome = client_with_events.run_task( + TaskRequest(name="foo", instrument_session="cm12345-1"), + on_event=on_event, + ) + assert outcome.task_failed + assert outcome.task_complete + assert isinstance(outcome.result, TaskError) + assert outcome.result.message == "The plan failed" + assert outcome.result.type == "PlanFailure" on_event.assert_called_with(FAILED_EVENT) @@ -448,6 +455,7 @@ def test_run_task_fails_on_failing_event( task_id="foo", task_complete=False, task_failed=False, + result=TaskError(type="ValueError", message="Task failed"), ), ), ProgressEvent(task_id="foo"), @@ -489,6 +497,7 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): task_id="bar", task_complete=False, task_failed=False, + result=None, ), ), ProgressEvent(task_id="bar"), diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index da6f1a4d9..77581e7b1 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -43,7 +43,13 @@ from blueapi.utils.invalid_config_error import InvalidConfigError from blueapi.utils.numtracker import NumtrackerClient from blueapi.utils.path_provider import StartDocumentPathProvider -from blueapi.worker.event import TaskStatus, TaskStatusEnum, WorkerEvent, WorkerState +from blueapi.worker.event import ( + TaskResult, + TaskStatus, + TaskStatusEnum, + WorkerEvent, + WorkerState, +) from blueapi.worker.task import Task from blueapi.worker.task_worker import TrackableTask @@ -392,6 +398,7 @@ def test_remove_tiled_subscriber(worker, context, from_uri, writer): task_id="foo_bar", task_complete=False, task_failed=False, + result=None, ), ), "c_id", @@ -406,6 +413,7 @@ def test_remove_tiled_subscriber(worker, context, from_uri, writer): task_id="foo_bar", task_complete=True, task_failed=False, + result=TaskResult(result=None, type="NoneType"), ), ), "c_id", diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index b0bc09dd3..d9b571692 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -336,6 +336,7 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: "params": {"time": 0.0}, "metadata": {}, }, + "outcome": None, "task_id": "0", }, { @@ -348,6 +349,7 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: "params": {}, "metadata": {}, }, + "outcome": None, "task_id": "1", }, ] @@ -379,6 +381,7 @@ def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None: "params": {}, "metadata": {}, }, + "outcome": None, "task_id": "3", } ] @@ -472,6 +475,7 @@ def test_get_task(mock_runner: Mock, client: TestClient): "foo": "bar", }, }, + "outcome": None, "task_id": f"{task_id}", } @@ -500,6 +504,7 @@ def test_get_all_tasks(mock_runner: Mock, client: TestClient): "is_complete": False, "is_pending": True, "request_id": None, + "outcome": None, "errors": [], } ] diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 2a49a1fe8..96e14dc3d 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -8,6 +8,7 @@ from pathlib import Path from textwrap import dedent from typing import Any, TypeVar +from unittest import mock from unittest.mock import Mock, patch import pytest @@ -52,7 +53,14 @@ TaskRequest, TaskResponse, ) -from blueapi.worker.event import ProgressEvent, TaskStatus, WorkerEvent, WorkerState +from blueapi.worker.event import ( + ProgressEvent, + TaskError, + TaskResult, + TaskStatus, + WorkerEvent, + WorkerState, +) @pytest.fixture(autouse=True) @@ -302,7 +310,10 @@ def mock_events(topic: MessageTopic, callback: Callable[[Any, Any], Any]): WorkerEvent( state=WorkerState.RUNNING, task_status=TaskStatus( - task_id=task_id, task_complete=False, task_failed=False + task_id=task_id, + task_complete=False, + task_failed=False, + result=None, ), ), ctx, @@ -313,7 +324,7 @@ def mock_events(topic: MessageTopic, callback: Callable[[Any, Any], Any]): WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id=task_id, task_complete=False, task_failed=False + task_id=task_id, task_complete=False, task_failed=False, result=None ), ), ctx, @@ -322,7 +333,10 @@ def mock_events(topic: MessageTopic, callback: Callable[[Any, Any], Any]): WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id=task_id, task_complete=True, task_failed=False + task_id=task_id, + task_complete=True, + task_failed=False, + result=TaskResult(result=None, type="NoneType"), ), ), ctx, @@ -351,6 +365,50 @@ def mock_events(topic: MessageTopic, callback: Callable[[Any, Any], Any]): assert run_response.call_count == 1 +@pytest.mark.parametrize( + "result,failed,message", + [ + (TaskResult(result=None, type="NoneType"), False, "Plan succeeded\n"), + (TaskResult(result=32, type="int"), False, "Plan succeeded: 32\n"), + ( + TaskResult(result=None, type="CustomType"), + False, + "Plan returned unserializable result of type 'CustomType'\n", + ), + ( + TaskError(type="ValueError", message="Error with value"), + True, + "Plan failed: ValueError: Error with value\n", + ), + ], +) +@patch("blueapi.cli.cli.BlueapiClient") +def test_run_plan_feedback( + mock_client: Mock, + runner: CliRunner, + result: TaskResult | TaskError | None, + failed: bool, + message: str, +): + bc = mock_client.from_config() + bc.run_task.return_value = TaskStatus( + task_id="foo_bar", + task_complete=True, + task_failed=failed, + result=result, + ) + res = runner.invoke( + main, + ["controller", "run", "-i", "cm12345-1", "name"], + ) + bc.run_task.assert_called_once_with( + TaskRequest(name="name", params={}, instrument_session="cm12345-1"), + on_event=mock.ANY, + ) + assert res.exit_code == 0 + assert res.stdout == message + + @responses.activate def test_run_plan_background_without_stomp(runner: CliRunner): submit_response = responses.post( @@ -897,7 +955,9 @@ def test_event_formatting(): ) worker = WorkerEvent( state=WorkerState.RUNNING, - task_status=TaskStatus(task_id="count", task_complete=False, task_failed=False), + task_status=TaskStatus( + task_id="count", task_complete=False, task_failed=False, result=None + ), errors=[], warnings=[], ) @@ -929,9 +989,12 @@ def test_event_formatting(): OutputFormat.JSON, worker, ( - """{"state": "RUNNING", "task_status": """ - """{"task_id": "count", "task_complete": false, "task_failed": false}, """ - """"errors": [], "warnings": []}\n""" + '{"state": "RUNNING", "task_status": {' + '"task_id": "count", ' + '"result": null, ' + '"task_complete": false, ' + '"task_failed": false' + '}, "errors": [], "warnings": []}\n' ), ) _assert_matching_formatting(OutputFormat.COMPACT, worker, "Worker Event: RUNNING\n") diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 2793e8807..5b1cb6f71 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -35,7 +35,7 @@ WorkerEvent, WorkerState, ) -from blueapi.worker.event import TaskStatusEnum +from blueapi.worker.event import TaskResult, TaskStatusEnum _SIMPLE_TASK = Task(name="sleep", params={"time": 0.0}) _LONG_TASK = Task(name="sleep", params={"time": 1.0}) @@ -79,6 +79,23 @@ def failing_plan() -> MsgGenerator: raise KeyError("I failed") +@dataclasses.dataclass +class ComplexReturn: + foo: int + bar: list[str] + + +class ModelReturn(pydantic.BaseModel): + foo: int + bar: list[str] + + +class Unreturnable: + def __init__(self, foo: int, bar: list[str]): + self.foo = foo + self.bar = bar + + @pytest.fixture def fake_device() -> FakeDevice: return FakeDevice() @@ -269,6 +286,42 @@ def test_begin_task_uses_plan_name_filter( filter_mock.assert_called_once() +@pytest.mark.parametrize( + "result,serial,type_name", + [ + (None, None, "NoneType"), + (42, 42, "int"), + ("helloWorld", "helloWorld", "str"), + (ComplexReturn(34, ["foo"]), {"foo": 34, "bar": ["foo"]}, "ComplexReturn"), + ( + ModelReturn(foo=42, bar=["helloWorld"]), + {"foo": 42, "bar": ["helloWorld"]}, + "ModelReturn", + ), + (Unreturnable(17, ["fizzbuzz"]), None, "Unreturnable"), + ], +) +def test_plan_result_serialized( + worker: TaskWorker, result: Any, serial: Any, type_name: str +): + def result_plan() -> MsgGenerator: + yield from [] + return result + + worker._ctx.register_plan(result_plan) + task_id = worker.submit_task(Task(name="result_plan")) + events = begin_task_and_wait_until_complete(worker, task_id) + ts = events[-1].task_status + assert ts is not None + assert ts.result == TaskResult(result=serial, type=type_name) + + assert ts.model_dump()["result"] == { + "outcome": "success", + "result": serial, + "type": type_name, + } + + def test_plan_failure_recorded_in_active_task(worker: TaskWorker) -> None: task_id = worker.submit_task(_FAILING_TASK) events_future: Future[list[WorkerEvent]] = take_events( @@ -313,7 +366,10 @@ def _sleep_events(task_id: str) -> list[WorkerEvent]: WorkerEvent( state=WorkerState.RUNNING, task_status=TaskStatus( - task_id=task_id, task_complete=False, task_failed=False + task_id=task_id, + task_complete=False, + task_failed=False, + result=None, ), errors=[], warnings=[], @@ -321,7 +377,10 @@ def _sleep_events(task_id: str) -> list[WorkerEvent]: WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id=task_id, task_complete=False, task_failed=False + task_id=task_id, + task_complete=False, + task_failed=False, + result=None, ), errors=[], warnings=[], @@ -329,7 +388,10 @@ def _sleep_events(task_id: str) -> list[WorkerEvent]: WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id=task_id, task_complete=True, task_failed=False + task_id=task_id, + task_complete=True, + task_failed=False, + result=TaskResult.from_result(None), ), errors=[], warnings=[], @@ -405,7 +467,10 @@ def test_worker_and_data_events_produce_in_order( WorkerEvent( state=WorkerState.RUNNING, task_status=TaskStatus( - task_id="count", task_complete=False, task_failed=False + task_id="count", + task_complete=False, + task_failed=False, + result=None, ), errors=[], warnings=[], @@ -417,7 +482,10 @@ def test_worker_and_data_events_produce_in_order( WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id="count", task_complete=False, task_failed=False + task_id="count", + task_complete=False, + task_failed=False, + result=None, ), errors=[], warnings=[], @@ -425,7 +493,10 @@ def test_worker_and_data_events_produce_in_order( WorkerEvent( state=WorkerState.IDLE, task_status=TaskStatus( - task_id="count", task_complete=True, task_failed=False + task_id="count", + task_complete=True, + task_failed=False, + result=TaskResult.from_result(None), ), errors=[], warnings=[], @@ -702,13 +773,16 @@ def test_cycle_without_otel_context(mock_logger: Mock, inert_worker: TaskWorker) task = TrackableTask(task_id="0", task=_SIMPLE_TASK) inert_worker._task_channel.put_nowait(task) inert_worker._pending_tasks["0"] = task + + # task changes during cycle so need to cache log message first + expected = f"Got new task: {task}" inert_worker._cycle() assert inert_worker._current_task_otel_context is None # Bad way to tell that this branch ahs been run, but I can't think of a better way # Have to set these values to match output task.is_complete = False task.is_pending = True - mock_logger.info.assert_called_with(f"Got new task: {task}") + mock_logger.info.assert_called_with(expected) class MyComposite(BlueapiBaseModel):