diff --git a/tests/unit_tests/client/test_rest.py b/tests/unit_tests/client/test_rest.py index 804fb5ee5..19d912694 100644 --- a/tests/unit_tests/client/test_rest.py +++ b/tests/unit_tests/client/test_rest.py @@ -1,11 +1,13 @@ import uuid from pathlib import Path +from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest import requests import responses from packaging.version import Version +from responses import DELETE, GET, PUT, matchers from blueapi import __version__ from blueapi.client.rest import ( @@ -20,7 +22,17 @@ ) from blueapi.config import OIDCConfig from blueapi.service.authentication import SessionCacheManager, SessionManager -from blueapi.service.model import EnvironmentResponse +from blueapi.service.model import ( + DeviceModel, + EnvironmentResponse, + PlanModel, + TaskResponse, + TasksListResponse, + WorkerTask, +) +from blueapi.worker.event import WorkerState +from blueapi.worker.task import Task +from blueapi.worker.task_worker import TrackableTask @pytest.fixture @@ -231,3 +243,116 @@ def test_server_and_client_versions( ) else: mock_logger.assert_not_called() + + +@pytest.mark.parametrize( + "method_name,args,http_method,path,response,result", + [ + ( + "get_plan", + ("foo",), + GET, + "/plans/foo", + '{"name": "foo"}', + PlanModel(name="foo"), + ), + ( + "get_device", + ("foo",), + GET, + "/devices/foo", + '{"name": "foo", "protocols": []}', + DeviceModel(name="foo", protocols=[]), + ), + ( + "get_task", + ("foo",), + GET, + "/tasks/foo", + '{"task_id": "foo", "task": {"name": "bar"}}', + TrackableTask(task_id="foo", task=Task(name="bar")), + ), + ( + "get_all_tasks", + (), + GET, + "/tasks", + '{"tasks": [{"task_id": "foo", "task": {"name": "bar"}}]}', + TasksListResponse( + tasks=[TrackableTask(task_id="foo", task=Task(name="bar"))] + ), + ), + ( + "get_active_task", + (), + GET, + "/worker/task", + '{"task_id": "foo"}', + WorkerTask(task_id="foo"), + ), + ( + "clear_task", + ("foo",), + DELETE, + "/tasks/foo", + '{"task_id": "foo"}', + TaskResponse(task_id="foo"), + ), + ], +) +@responses.activate +def test_individual_endpoints( + rest: BlueapiRestClient, + # input args + method_name: str, + args: tuple[Any], + # setup args + http_method: str, + path: str, + response: str, + result: Any, +): + responses.add(http_method, "http://localhost:8000" + path, body=response) + + method = getattr(rest, method_name) + actual = method(*args) + assert actual == result + + +@pytest.mark.parametrize( + "method_name,args,data,response,result", + [ + ( + "set_state", + (WorkerState.PAUSED,), + {"new_state": "PAUSED", "defer": False}, + "PAUSED", + WorkerState.PAUSED, + ), + ( + "cancel_current_task", + (WorkerState.ABORTING, "no reason"), + {"new_state": "ABORTING", "reason": "no reason"}, + "ABORTING", + WorkerState.ABORTING, + ), + ], +) +@responses.activate +def test_set_state( + rest: BlueapiRestClient, + method_name: str, + args: tuple[Any], + data: Any, + response: str, + result: Any, +): + responses.add( + PUT, + "http://localhost:8000/worker/state", + match=[matchers.json_params_matcher(data)], + json=response, + ) + method = getattr(rest, method_name) + res = method(*args) + assert res == result