diff --git a/src/dstack/_internal/core/backends/vastai/api_client.py b/src/dstack/_internal/core/backends/vastai/api_client.py index 1d8e4d8f3..f56a2f7ce 100644 --- a/src/dstack/_internal/core/backends/vastai/api_client.py +++ b/src/dstack/_internal/core/backends/vastai/api_client.py @@ -1,3 +1,4 @@ +import json import threading import time from typing import List, Optional, Union @@ -10,10 +11,14 @@ from dstack._internal.core.errors import NoCapacityError from dstack._internal.core.models.common import RegistryAuth +# v1 instances list is paginated; max enforced server-side is 25 per page. +_V1_INSTANCES_PAGE_LIMIT = 25 + class VastAIAPIClient: def __init__(self, api_key: str): - self.api_url = "https://console.vast.ai/api/v0".rstrip("/") + self.api_url_v0 = "https://console.vast.ai/api/v0" + self.api_url_v1 = "https://console.vast.ai/api/v1" self.api_key = api_key self.s = requests.Session() # TODO: set adequate timeout everywhere the session is used retries = Retry( @@ -21,13 +26,15 @@ def __init__(self, api_key: str): backoff_factor=1, status_forcelist=[429, 504], ) - self.s.mount(prefix=self._url("/instances/"), adapter=HTTPAdapter(max_retries=retries)) + adapter = HTTPAdapter(max_retries=retries) + self.s.mount(prefix=f"{self.api_url_v0}/instances/", adapter=adapter) + self.s.mount(prefix=f"{self.api_url_v1}/instances/", adapter=adapter) self.lock = threading.Lock() self.instances_cache_ts: float = 0 self.instances_cache: List[dict] = [] def get_bundle(self, bundle_id: Union[str, int]) -> Optional[dict]: - resp = self.s.post(self._url("/bundles/"), json={"id": {"eq": bundle_id}}) + resp = self.s.post(self._url_v0("/bundles/"), json={"id": {"eq": bundle_id}}) resp.raise_for_status() data = resp.json() offers = data["offers"] @@ -80,7 +87,7 @@ def create_instance( "create_from": None, "force": False, } - resp = self.s.put(self._url(f"/asks/{bundle_id}/"), json=payload) + resp = self.s.put(self._url_v0(f"/asks/{bundle_id}/"), json=payload) if resp.status_code != 200 or not (data := resp.json())["success"]: raise NoCapacityError(resp.text) self._invalidate_cache() @@ -94,7 +101,7 @@ def destroy_instance(self, instance_id: Union[str, int]) -> bool: Returns: True if instance was destroyed successfully """ - resp = self.s.delete(self._url(f"/instances/{instance_id}/")) + resp = self.s.delete(self._url_v0(f"/instances/{instance_id}/")) if resp.status_code != 200 or not resp.json()["success"]: return False self._invalidate_cache() @@ -103,23 +110,19 @@ def destroy_instance(self, instance_id: Union[str, int]) -> bool: def get_instances(self, cache_ttl: float = 3.0) -> List[dict]: with self.lock: if time.time() - self.instances_cache_ts > cache_ttl: - resp = self.s.get(self._url("/instances/")) - resp.raise_for_status() - data = resp.json() + self.instances_cache = self._list_instances_v1() self.instances_cache_ts = time.time() - self.instances_cache = data["instances"] return self.instances_cache def get_instance(self, instance_id: Union[str, int]) -> Optional[dict]: - instances = self.get_instances() - for instance in instances: - if instance["id"] == int(instance_id): - return instance - return None + instances = self._list_instances_v1( + select_filters={"id": {"eq": int(instance_id)}}, limit=1 + ) + return instances[0] if instances else None def request_logs(self, instance_id: Union[str, int]) -> dict: resp = self.s.put( - self._url(f"/instances/request_logs/{instance_id}/"), json={"tail": "1000"} + self._url_v0(f"/instances/request_logs/{instance_id}/"), json={"tail": "1000"} ) resp.raise_for_status() data = resp.json() @@ -134,8 +137,39 @@ def auth_test(self) -> bool: except requests.HTTPError: return False - def _url(self, path): - return f"{self.api_url}/{path.lstrip('/')}?api_key={self.api_key}" + def _list_instances_v1( + self, + select_filters: Optional[dict] = None, + limit: int = _V1_INSTANCES_PAGE_LIMIT, + ) -> List[dict]: + """Page through the v1 instances endpoint and return all matches. + + The v1 endpoint enforces keyset pagination with a max of 25 results + per response, so we follow `next_token` until the server stops + returning one. + """ + page_limit = max(1, min(limit, _V1_INSTANCES_PAGE_LIMIT)) + params: dict = { + "limit": page_limit, + "select_filters": json.dumps(select_filters or {}), + } + instances: List[dict] = [] + while True: + resp = self.s.get(self._url_v1("/instances/"), params=params) + resp.raise_for_status() + data = resp.json() + instances.extend(data.get("instances", [])) + next_token = data.get("next_token") + if not next_token: + break + params["after_token"] = next_token + return instances + + def _url_v0(self, path: str) -> str: + return f"{self.api_url_v0}/{path.lstrip('/')}?api_key={self.api_key}" + + def _url_v1(self, path: str) -> str: + return f"{self.api_url_v1}/{path.lstrip('/')}?api_key={self.api_key}" def _invalidate_cache(self): with self.lock: diff --git a/src/tests/_internal/core/backends/vastai/test_api_client.py b/src/tests/_internal/core/backends/vastai/test_api_client.py new file mode 100644 index 000000000..be9d2f469 --- /dev/null +++ b/src/tests/_internal/core/backends/vastai/test_api_client.py @@ -0,0 +1,108 @@ +import json +from urllib.parse import parse_qs, urlparse + +import pytest + +from dstack._internal.core.backends.vastai.api_client import VastAIAPIClient + + +class _FakeResponse: + def __init__(self, payload, status_code=200): + self._payload = payload + self.status_code = status_code + self.text = json.dumps(payload) + + def json(self): + return self._payload + + def raise_for_status(self): + if self.status_code >= 400: + raise AssertionError(f"unexpected status {self.status_code}") + + +@pytest.fixture +def client(): + return VastAIAPIClient(api_key="test-key") + + +def _parse_call(call): + """Return (path, query_dict) for a recorded get() call.""" + url = call.args[0] if call.args else call.kwargs["url"] + parsed = urlparse(url) + query = {k: v[0] for k, v in parse_qs(parsed.query).items()} + if "params" in call.kwargs and call.kwargs["params"]: + for k, v in call.kwargs["params"].items(): + query[k] = str(v) + return parsed.path, query + + +def test_get_instances_uses_v1_endpoint_and_paginates(client, monkeypatch): + pages = [ + {"instances": [{"id": 1}, {"id": 2}], "next_token": "tok-1"}, + {"instances": [{"id": 3}], "next_token": None}, + ] + calls = [] + + def fake_get(url, params=None): + calls.append((url, dict(params or {}))) + return _FakeResponse(pages[len(calls) - 1]) + + monkeypatch.setattr(client.s, "get", fake_get) + + instances = client.get_instances(cache_ttl=0) + + assert [i["id"] for i in instances] == [1, 2, 3] + assert len(calls) == 2 + # First call hits v1 and includes select_filters={} and a limit. + first_url, first_params = calls[0] + assert "/api/v1/instances/" in first_url + assert first_params["select_filters"] == "{}" + assert first_params["limit"] == 25 + assert "after_token" not in first_params + # Second call carries the next_token from the prior response. + _, second_params = calls[1] + assert second_params["after_token"] == "tok-1" + + +def test_get_instance_uses_v1_select_filters(client, monkeypatch): + calls = [] + + def fake_get(url, params=None): + calls.append((url, params)) + return _FakeResponse( + {"instances": [{"id": 42, "actual_status": "running"}], "next_token": None} + ) + + monkeypatch.setattr(client.s, "get", fake_get) + + instance = client.get_instance(42) + + assert instance == {"id": 42, "actual_status": "running"} + assert len(calls) == 1 + url, params = calls[0] + assert "/api/v1/instances/" in url + assert json.loads(params["select_filters"]) == {"id": {"eq": 42}} + assert params["limit"] == 1 + + +def test_get_instance_returns_none_when_missing(client, monkeypatch): + monkeypatch.setattr( + client.s, + "get", + lambda url, params=None: _FakeResponse({"instances": [], "next_token": None}), + ) + assert client.get_instance(99) is None + + +def test_destroy_instance_still_uses_v0(client, monkeypatch): + calls = [] + + def fake_delete(url): + calls.append(url) + return _FakeResponse({"success": True}) + + monkeypatch.setattr(client.s, "delete", fake_delete) + monkeypatch.setattr(client, "_invalidate_cache", lambda: None) + + assert client.destroy_instance(7) is True + assert "/api/v0/instances/7/" in calls[0]