Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,44 +170,44 @@ def __init__(
self._pool = requests.Session()

def get_plans(self) -> PlanResponse:
return self._request_and_deserialize("/plans", PlanResponse)
return self._request_and_deserialize("/api/v1/plans", PlanResponse)
Copy link
Copy Markdown
Contributor

@ZohebShaikh ZohebShaikh Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    def _request_and_deserialize(
        self,
        suffix: str,
        target_type: type[T],
        data: Mapping[str, Any] | None = None,
        method="GET",
        get_exception: Callable[[requests.Response], Exception | None] = _exception,
        params: Mapping[str, Any] | None = None,
        prefix:str="/api/v1"
    ) -> T:
        url = self._config.url.unicode_string().removesuffix("/") + prefix + suffix

Something like this looks more maintainable ?

This endpoint will need the /api/v1 prefix as well for this to work

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would make this a smaller change but does bake in the assumption that v1 will always be the default. If we move to v2 being the main version in future we still need to make this change or accept that we will forever be passing prefix="/api/v2" in every call.

This endpoint will need the /api/v1 prefix as well for this to work

Why couldn't the client pass prefix="" when it was called?


def get_plan(self, name: str) -> PlanModel:
return self._request_and_deserialize(f"/plans/{name}", PlanModel)
return self._request_and_deserialize(f"/api/v1/plans/{name}", PlanModel)

def get_devices(self) -> DeviceResponse:
return self._request_and_deserialize("/devices", DeviceResponse)
return self._request_and_deserialize("/api/v1/devices", DeviceResponse)

def get_device(self, name: str) -> DeviceModel:
return self._request_and_deserialize(f"/devices/{name}", DeviceModel)
return self._request_and_deserialize(f"/api/v1/devices/{name}", DeviceModel)

def get_state(self) -> WorkerState:
return self._request_and_deserialize("/worker/state", WorkerState)
return self._request_and_deserialize("/api/v1/worker/state", WorkerState)

def set_state(
self,
state: Literal[WorkerState.RUNNING, WorkerState.PAUSED],
defer: bool | None = False,
):
return self._request_and_deserialize(
"/worker/state",
"/api/v1/worker/state",
target_type=WorkerState,
method="PUT",
data={"new_state": state, "defer": defer},
)

def get_task(self, task_id: str) -> TrackableTask:
return self._request_and_deserialize(f"/tasks/{task_id}", TrackableTask)
return self._request_and_deserialize(f"/api/v1/tasks/{task_id}", TrackableTask)

def get_all_tasks(self) -> TasksListResponse:
return self._request_and_deserialize("/tasks", TasksListResponse)
return self._request_and_deserialize("/api/v1/tasks", TasksListResponse)

def get_active_task(self) -> WorkerTask:
return self._request_and_deserialize("/worker/task", WorkerTask)
return self._request_and_deserialize("/api/v1/worker/task", WorkerTask)

def create_task(self, task: TaskRequest) -> TaskResponse:
return self._request_and_deserialize(
"/tasks",
"/api/v1/tasks",
TaskResponse,
method="POST",
get_exception=_create_task_exceptions,
Expand All @@ -216,12 +216,12 @@ def create_task(self, task: TaskRequest) -> TaskResponse:

def clear_task(self, task_id: str) -> TaskResponse:
return self._request_and_deserialize(
f"/tasks/{task_id}", TaskResponse, method="DELETE"
f"/api/v1/tasks/{task_id}", TaskResponse, method="DELETE"
)

def update_worker_task(self, task: WorkerTask) -> WorkerTask:
return self._request_and_deserialize(
"/worker/task",
"/api/v1/worker/task",
WorkerTask,
method="PUT",
data=task.model_dump(),
Expand All @@ -233,18 +233,18 @@ def cancel_current_task(
reason: str | None = None,
):
return self._request_and_deserialize(
"/worker/state",
"/api/v1/worker/state",
target_type=WorkerState,
method="PUT",
data={"new_state": state, "reason": reason},
)

def get_environment(self) -> EnvironmentResponse:
return self._request_and_deserialize("/environment", EnvironmentResponse)
return self._request_and_deserialize("/api/v1/environment", EnvironmentResponse)

def delete_environment(self) -> EnvironmentResponse:
return self._request_and_deserialize(
"/environment", EnvironmentResponse, method="DELETE"
"/api/v1/environment", EnvironmentResponse, method="DELETE"
)

def get_oidc_config(self) -> OIDCConfig | None:
Expand All @@ -258,7 +258,7 @@ def get_python_environment(
self, name: str | None = None, source: SourceInfo | None = None
) -> PythonEnvironmentResponse:
return self._request_and_deserialize(
"/python_environment",
"/api/v1/python_environment",
PythonEnvironmentResponse,
params={"name": name, "source": source},
)
Expand Down
48 changes: 27 additions & 21 deletions tests/unit_tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_get_plans(runner: CliRunner):

response = responses.add(
responses.GET,
"http://localhost:8000/plans",
"http://localhost:8000/api/v1/plans",
json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(),
status=200,
)
Expand All @@ -176,7 +176,7 @@ def test_get_devices(runner: CliRunner):

response = responses.add(
responses.GET,
"http://localhost:8000/devices",
"http://localhost:8000/api/v1/devices",
json=DeviceResponse(devices=[DeviceModel.from_device(device)]).model_dump(),
status=200,
)
Expand Down Expand Up @@ -218,7 +218,7 @@ def test_submit_plan(runner: CliRunner):
}

response = responses.post(
url="http://a.fake.host:12345/tasks",
url="http://a.fake.host:12345/api/v1/tasks",
match=[matchers.json_params_matcher(body_data)],
)

Expand Down Expand Up @@ -268,7 +268,7 @@ def test_submit_plan_without_stomp(runner: CliRunner):
def test_run_plan(stomp_client: StompClient, runner: CliRunner):
task_id = "abcd-1234"
submit_response = responses.post(
url="http://a.fake.host:12345/tasks",
url="http://a.fake.host:12345/api/v1/tasks",
match=[
matchers.json_params_matcher(
{
Expand All @@ -282,7 +282,7 @@ def test_run_plan(stomp_client: StompClient, runner: CliRunner):
status=201,
)
run_response = responses.put(
url="http://a.fake.host:12345/worker/task",
url="http://a.fake.host:12345/api/v1/worker/task",
match=[matchers.json_params_matcher({"task_id": task_id})],
json={"task_id": task_id},
)
Expand Down Expand Up @@ -398,7 +398,7 @@ def test_run_plan_feedback(
@responses.activate
def test_run_plan_background_without_stomp(runner: CliRunner):
submit_response = responses.post(
url="http://a.fake.host:12345/tasks",
url="http://a.fake.host:12345/api/v1/tasks",
match=[
matchers.json_params_matcher(
{
Expand All @@ -412,7 +412,7 @@ def test_run_plan_background_without_stomp(runner: CliRunner):
status=201,
)
run_response = responses.put(
url="http://a.fake.host:12345/worker/task",
url="http://a.fake.host:12345/api/v1/worker/task",
match=[matchers.json_params_matcher({"task_id": "abcd-1234"})],
json={"task_id": "abcd-1234"},
)
Expand Down Expand Up @@ -541,7 +541,7 @@ def test_get_env(runner: CliRunner):
environment_id = uuid.uuid4()
responses.add(
responses.GET,
"http://localhost:8000/environment",
"http://localhost:8000/api/v1/environment",
json=EnvironmentResponse(
environment_id=environment_id, initialized=True
).model_dump(mode="json"),
Expand All @@ -559,7 +559,10 @@ def test_get_env(runner: CliRunner):
@responses.activate
def test_get_state(runner: CliRunner):
responses.add(
responses.GET, "http://localhost:8000/worker/state", json="IDLE", status=200
responses.GET,
"http://localhost:8000/api/v1/worker/state",
json="IDLE",
status=200,
)
state = runner.invoke(main, ["controller", "state"])
print(state.stderr)
Expand All @@ -576,7 +579,7 @@ def test_reset_env_client_behavior(
environment_id = uuid.uuid4()
responses.add(
responses.DELETE,
"http://localhost:8000/environment",
"http://localhost:8000/api/v1/environment",
json=EnvironmentResponse(
environment_id=environment_id, initialized=False
).model_dump(mode="json"),
Expand All @@ -588,7 +591,7 @@ def test_reset_env_client_behavior(
for state in env_state:
responses.add(
responses.GET,
"http://localhost:8000/environment",
"http://localhost:8000/api/v1/environment",
json=EnvironmentResponse(
environment_id=environment_id, initialized=state
).model_dump(mode="json"),
Expand All @@ -604,10 +607,10 @@ def test_reset_env_client_behavior(
for index, call in enumerate(responses.calls):
if index == 0:
assert call.request.method == "DELETE"
assert call.request.url == "http://localhost:8000/environment"
assert call.request.url == "http://localhost:8000/api/v1/environment"
else:
assert call.request.method == "GET"
assert call.request.url == "http://localhost:8000/environment"
assert call.request.url == "http://localhost:8000/api/v1/environment"

# Check if the final environment status is printed correctly
# assert "Environment is initialized." in result.output
Expand All @@ -625,7 +628,7 @@ def test_env_timeout(mock_sleep: Mock, runner: CliRunner):
environment_id = uuid.uuid4()
responses.add(
responses.DELETE,
"http://localhost:8000/environment",
"http://localhost:8000/api/v1/environment",
status=200,
json=EnvironmentResponse(
environment_id=environment_id, initialized=False
Expand All @@ -634,7 +637,7 @@ def test_env_timeout(mock_sleep: Mock, runner: CliRunner):
# Add responses for each polling attempt, all indicating not initialized
responses.add(
responses.GET,
"http://localhost:8000/environment",
"http://localhost:8000/api/v1/environment",
json=EnvironmentResponse(
environment_id=environment_id, initialized=False
).model_dump(mode="json"),
Expand All @@ -655,12 +658,12 @@ def test_env_timeout(mock_sleep: Mock, runner: CliRunner):

# First call should be DELETE
assert responses.calls[0].request.method == "DELETE"
assert responses.calls[0].request.url == "http://localhost:8000/environment"
assert responses.calls[0].request.url == "http://localhost:8000/api/v1/environment"

# Remaining calls should all be GET
for call in responses.calls[1:]: # Skip the first DELETE request # type: ignore
assert call.request.method == "GET"
assert call.request.url == "http://localhost:8000/environment"
assert call.request.url == "http://localhost:8000/api/v1/environment"

# Check the output for the timeout message
assert result.output == "Reloading environment\n"
Expand All @@ -673,7 +676,10 @@ def test_env_timeout(mock_sleep: Mock, runner: CliRunner):
def test_env_reload_server_side_error(runner: CliRunner):
# Setup mocked error response from the server
responses.add(
responses.DELETE, "http://localhost:8000/environment", status=500, json={}
responses.DELETE,
"http://localhost:8000/api/v1/environment",
status=500,
json={},
)

result = runner.invoke(main, ["controller", "env", "-r"])
Expand All @@ -687,7 +693,7 @@ def test_env_reload_server_side_error(runner: CliRunner):

# Only call should be DELETE
assert responses.calls[0].request.method == "DELETE"
assert responses.calls[0].request.url == "http://localhost:8000/environment"
assert responses.calls[0].request.url == "http://localhost:8000/api/v1/environment"

# Check the output for the timeout message
# TODO this seems wrong but this is the current behaviour
Expand Down Expand Up @@ -1279,7 +1285,7 @@ def test_get_python_environment(runner: CliRunner):
}
response = responses.add(
responses.GET,
"http://localhost:8000/python_environment",
"http://localhost:8000/api/v1/python_environment",
json=scratch_config,
status=200,
)
Expand All @@ -1302,7 +1308,7 @@ def test_get_python_env_with_empty_response(runner: CliRunner):
}
response = responses.add(
responses.GET,
"http://localhost:8000/python_environment",
"http://localhost:8000/api/v1/python_environment",
json=scratch_config,
status=200,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/client/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_auth_request_functionality(
environment_id = uuid.uuid4()
mock_authn_server.stop() # Cannot use multiple RequestsMock context manager
mock_get_env = mock_authn_server.get(
"http://localhost:8000/environment",
"http://localhost:8000/api/v1/environment",
json=EnvironmentResponse(
environment_id=environment_id, initialized=True
).model_dump(mode="json"),
Expand All @@ -143,7 +143,7 @@ def test_refresh_if_signature_expired(
environment_id = uuid.uuid4()
mock_authn_server.stop() # Cannot use multiple RequestsMock context manager
mock_get_env = mock_authn_server.get(
"http://localhost:8000/environment",
"http://localhost:8000/api/v1/environment",
json=EnvironmentResponse(
environment_id=environment_id, initialized=True
).model_dump(mode="json"),
Expand Down
Loading
Loading