Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 126 additions & 1 deletion tests/unit_tests/client/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import uuid
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, Mock, patch

import pytest
import requests
import responses
from packaging.version import Version
from responses import DELETE, GET, PUT, matchers

from blueapi import __version__
from blueapi.client.rest import (
Expand All @@ -20,7 +22,17 @@
)
from blueapi.config import OIDCConfig
from blueapi.service.authentication import SessionCacheManager, SessionManager
from blueapi.service.model import EnvironmentResponse
from blueapi.service.model import (
DeviceModel,
EnvironmentResponse,
PlanModel,
TaskResponse,
TasksListResponse,
WorkerTask,
)
from blueapi.worker.event import WorkerState
from blueapi.worker.task import Task
from blueapi.worker.task_worker import TrackableTask


@pytest.fixture
Expand Down Expand Up @@ -231,3 +243,116 @@ def test_server_and_client_versions(
)
else:
mock_logger.assert_not_called()


@pytest.mark.parametrize(
"method_name,args,http_method,path,response,result",
[
(
"get_plan",
("foo",),
GET,
"/plans/foo",
'{"name": "foo"}',
PlanModel(name="foo"),
),
(
"get_device",
("foo",),
GET,
"/devices/foo",
'{"name": "foo", "protocols": []}',
DeviceModel(name="foo", protocols=[]),
),
(
"get_task",
("foo",),
GET,
"/tasks/foo",
'{"task_id": "foo", "task": {"name": "bar"}}',
TrackableTask(task_id="foo", task=Task(name="bar")),
),
(
"get_all_tasks",
(),
GET,
"/tasks",
'{"tasks": [{"task_id": "foo", "task": {"name": "bar"}}]}',
TasksListResponse(
tasks=[TrackableTask(task_id="foo", task=Task(name="bar"))]
),
),
(
"get_active_task",
(),
GET,
"/worker/task",
'{"task_id": "foo"}',
WorkerTask(task_id="foo"),
),
(
"clear_task",
("foo",),
DELETE,
"/tasks/foo",
'{"task_id": "foo"}',
TaskResponse(task_id="foo"),
),
],
)
@responses.activate
def test_individual_endpoints(
rest: BlueapiRestClient,
# input args
method_name: str,
args: tuple[Any],
# setup args
http_method: str,
path: str,
response: str,
result: Any,
):
responses.add(http_method, "http://localhost:8000" + path, body=response)

method = getattr(rest, method_name)
actual = method(*args)
assert actual == result


@pytest.mark.parametrize(
"method_name,args,data,response,result",
[
(
"set_state",
(WorkerState.PAUSED,),
{"new_state": "PAUSED", "defer": False},
"PAUSED",
WorkerState.PAUSED,
),
(
"cancel_current_task",
(WorkerState.ABORTING, "no reason"),
{"new_state": "ABORTING", "reason": "no reason"},
"ABORTING",
WorkerState.ABORTING,
),
],
)
@responses.activate
def test_set_state(
rest: BlueapiRestClient,
method_name: str,
args: tuple[Any],
data: Any,
response: str,
result: Any,
):
responses.add(
PUT,
"http://localhost:8000/worker/state",
match=[matchers.json_params_matcher(data)],
json=response,
)
method = getattr(rest, method_name)
res = method(*args)
assert res == result
Loading