Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 51 additions & 17 deletions src/dstack/_internal/core/backends/vastai/api_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import threading
import time
from typing import List, Optional, Union
Expand All @@ -10,24 +11,30 @@
from dstack._internal.core.errors import NoCapacityError
from dstack._internal.core.models.common import RegistryAuth

# The v1 instances list endpoint is paginated; the server enforces a maximum of 25 results per page.
_V1_INSTANCES_PAGE_LIMIT = 25


class VastAIAPIClient:
def __init__(self, api_key: str):
self.api_url = "https://console.vast.ai/api/v0".rstrip("/")
self.api_url_v0 = "https://console.vast.ai/api/v0"
self.api_url_v1 = "https://console.vast.ai/api/v1"
self.api_key = api_key
self.s = requests.Session() # TODO: set adequate timeout everywhere the session is used
retries = Retry(
total=5,
backoff_factor=1,
status_forcelist=[429, 504],
)
self.s.mount(prefix=self._url("/instances/"), adapter=HTTPAdapter(max_retries=retries))
adapter = HTTPAdapter(max_retries=retries)
self.s.mount(prefix=f"{self.api_url_v0}/instances/", adapter=adapter)
self.s.mount(prefix=f"{self.api_url_v1}/instances/", adapter=adapter)
self.lock = threading.Lock()
self.instances_cache_ts: float = 0
self.instances_cache: List[dict] = []

def get_bundle(self, bundle_id: Union[str, int]) -> Optional[dict]:
resp = self.s.post(self._url("/bundles/"), json={"id": {"eq": bundle_id}})
resp = self.s.post(self._url_v0("/bundles/"), json={"id": {"eq": bundle_id}})
resp.raise_for_status()
data = resp.json()
offers = data["offers"]
Expand Down Expand Up @@ -80,7 +87,7 @@ def create_instance(
"create_from": None,
"force": False,
}
resp = self.s.put(self._url(f"/asks/{bundle_id}/"), json=payload)
resp = self.s.put(self._url_v0(f"/asks/{bundle_id}/"), json=payload)
if resp.status_code != 200 or not (data := resp.json())["success"]:
raise NoCapacityError(resp.text)
self._invalidate_cache()
Expand All @@ -94,7 +101,7 @@ def destroy_instance(self, instance_id: Union[str, int]) -> bool:
Returns:
True if instance was destroyed successfully
"""
resp = self.s.delete(self._url(f"/instances/{instance_id}/"))
resp = self.s.delete(self._url_v0(f"/instances/{instance_id}/"))
if resp.status_code != 200 or not resp.json()["success"]:
return False
self._invalidate_cache()
Expand All @@ -103,23 +110,19 @@ def destroy_instance(self, instance_id: Union[str, int]) -> bool:
def get_instances(self, cache_ttl: float = 3.0) -> List[dict]:
with self.lock:
if time.time() - self.instances_cache_ts > cache_ttl:
resp = self.s.get(self._url("/instances/"))
resp.raise_for_status()
data = resp.json()
self.instances_cache = self._list_instances_v1()
self.instances_cache_ts = time.time()
self.instances_cache = data["instances"]
return self.instances_cache

def get_instance(self, instance_id: Union[str, int]) -> Optional[dict]:
instances = self.get_instances()
for instance in instances:
if instance["id"] == int(instance_id):
return instance
return None
instances = self._list_instances_v1(
select_filters={"id": {"eq": int(instance_id)}}, limit=1
)
return instances[0] if instances else None

def request_logs(self, instance_id: Union[str, int]) -> dict:
resp = self.s.put(
self._url(f"/instances/request_logs/{instance_id}/"), json={"tail": "1000"}
self._url_v0(f"/instances/request_logs/{instance_id}/"), json={"tail": "1000"}
)
resp.raise_for_status()
data = resp.json()
Expand All @@ -134,8 +137,39 @@ def auth_test(self) -> bool:
except requests.HTTPError:
return False

def _url(self, path):
return f"{self.api_url}/{path.lstrip('/')}?api_key={self.api_key}"
def _list_instances_v1(
    self,
    select_filters: Optional[dict] = None,
    limit: int = _V1_INSTANCES_PAGE_LIMIT,
) -> List[dict]:
    """Page through the v1 instances endpoint and return all matches.

    The v1 endpoint enforces keyset pagination with a max of 25 results
    per response, so we follow `next_token` until the server stops
    returning one.

    Args:
        select_filters: Filter expression sent JSON-encoded as the
            `select_filters` query parameter; `None` is sent as `{}`.
        limit: Requested page size, clamped to the range
            1.._V1_INSTANCES_PAGE_LIMIT. This is a per-request size only:
            it does NOT cap the total number of instances returned,
            because every page the server offers is still fetched.

    Returns:
        The concatenated `instances` entries of all fetched pages.

    Raises:
        requests.HTTPError: if any page request returns an error status.
    """
    page_limit = max(1, min(limit, _V1_INSTANCES_PAGE_LIMIT))
    params: dict = {
        "limit": page_limit,
        "select_filters": json.dumps(select_filters or {}),
    }
    instances: List[dict] = []
    while True:
        resp = self.s.get(self._url_v1("/instances/"), params=params)
        resp.raise_for_status()
        data = resp.json()
        # `or []` also covers an explicit `"instances": null` in the payload,
        # which `dict.get(key, default)` would pass through as None.
        instances.extend(data.get("instances") or [])
        next_token = data.get("next_token")
        # Stop when there is no further page, or when the server hands back
        # the token we just sent: no progress would mean an endless loop.
        if not next_token or next_token == params.get("after_token"):
            break
        params["after_token"] = next_token
    return instances

def _url_v0(self, path: str) -> str:
    """Build a v0 API URL for `path`, appending the API key as a query parameter.

    Leading slashes on `path` are dropped so the result has exactly one
    slash between the base URL and the path.
    """
    relative_path = path.lstrip("/")
    return f"{self.api_url_v0}/{relative_path}?api_key={self.api_key}"

def _url_v1(self, path: str) -> str:
    """Build a v1 API URL for `path`, appending the API key as a query parameter.

    Leading slashes on `path` are dropped so the result has exactly one
    slash between the base URL and the path.
    """
    relative_path = path.lstrip("/")
    return f"{self.api_url_v1}/{relative_path}?api_key={self.api_key}"

def _invalidate_cache(self):
with self.lock:
Expand Down
108 changes: 108 additions & 0 deletions src/tests/_internal/core/backends/vastai/test_api_client.py

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's drop this new test file completely – the tests don't seem very useful.

Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import json
from urllib.parse import parse_qs, urlparse

import pytest

from dstack._internal.core.backends.vastai.api_client import VastAIAPIClient


class _FakeResponse:
    """Minimal stand-in for an HTTP response object used by these tests.

    Exposes only `status_code`, `text`, `json()` and `raise_for_status()`.
    """

    def __init__(self, payload, status_code=200):
        # `text` mirrors the payload as a JSON string.
        self.text = json.dumps(payload)
        self.status_code = status_code
        self._payload = payload

    def json(self):
        """Return the payload object exactly as it was passed in."""
        return self._payload

    def raise_for_status(self):
        """Fail the test with AssertionError on any status of 400 or above."""
        if self.status_code < 400:
            return
        raise AssertionError(f"unexpected status {self.status_code}")


@pytest.fixture
def client():
    """Provide a VastAIAPIClient constructed with a dummy API key."""
    api_client = VastAIAPIClient(api_key="test-key")
    return api_client


def _parse_call(call):
    """Return (path, query_dict) for a recorded get() call.

    The URL is taken from the first positional argument, or from the
    `url` keyword when there are no positional arguments. Query values
    embedded in the URL are reduced to their first occurrence; entries of
    a truthy `params` keyword are then overlaid as strings.
    """
    url = call.kwargs["url"] if not call.args else call.args[0]
    parts = urlparse(url)
    query = {name: values[0] for name, values in parse_qs(parts.query).items()}
    extra_params = call.kwargs.get("params")
    if extra_params:
        query.update((name, str(value)) for name, value in extra_params.items())
    return parts.path, query


def test_get_instances_uses_v1_endpoint_and_paginates(client, monkeypatch):
    """get_instances() must query the v1 endpoint and follow next_token."""
    pages = [
        {"instances": [{"id": 1}, {"id": 2}], "next_token": "tok-1"},
        {"instances": [{"id": 3}], "next_token": None},
    ]
    recorded = []

    def fake_get(url, params=None):
        # Snapshot params: the client reuses and mutates one dict across pages.
        recorded.append((url, dict(params or {})))
        return _FakeResponse(pages[len(recorded) - 1])

    monkeypatch.setattr(client.s, "get", fake_get)

    # cache_ttl=0 forces a refresh instead of serving the cached list.
    instances = client.get_instances(cache_ttl=0)

    returned_ids = [instance["id"] for instance in instances]
    assert returned_ids == [1, 2, 3]
    assert len(recorded) == 2

    # First request: v1 URL, empty filter, full page size, no token yet.
    first_url, first_params = recorded[0]
    assert "/api/v1/instances/" in first_url
    assert first_params["select_filters"] == "{}"
    assert first_params["limit"] == 25
    assert "after_token" not in first_params

    # Second request: carries the token returned with the first page.
    second_params = recorded[1][1]
    assert second_params["after_token"] == "tok-1"


def test_get_instance_uses_v1_select_filters(client, monkeypatch):
    """get_instance() must issue one v1 request filtered by instance id."""
    expected = {"id": 42, "actual_status": "running"}
    recorded = []

    def fake_get(url, params=None):
        recorded.append((url, params))
        payload = {
            "instances": [{"id": 42, "actual_status": "running"}],
            "next_token": None,
        }
        return _FakeResponse(payload)

    monkeypatch.setattr(client.s, "get", fake_get)

    found = client.get_instance(42)

    assert found == expected
    assert len(recorded) == 1
    request_url, request_params = recorded[0]
    assert "/api/v1/instances/" in request_url
    sent_filters = json.loads(request_params["select_filters"])
    assert sent_filters == {"id": {"eq": 42}}
    assert request_params["limit"] == 1


def test_get_instance_returns_none_when_missing(client, monkeypatch):
    """get_instance() must return None when the server reports no match."""

    def fake_get(url, params=None):
        return _FakeResponse({"instances": [], "next_token": None})

    monkeypatch.setattr(client.s, "get", fake_get)

    result = client.get_instance(99)
    assert result is None


def test_destroy_instance_still_uses_v0(client, monkeypatch):
    """destroy_instance() must keep sending its DELETE to the v0 endpoint."""
    deleted_urls = []

    def fake_delete(url):
        deleted_urls.append(url)
        return _FakeResponse({"success": True})

    # Replace cache invalidation with a no-op for this test.
    monkeypatch.setattr(client, "_invalidate_cache", lambda: None)
    monkeypatch.setattr(client.s, "delete", fake_delete)

    destroyed = client.destroy_instance(7)

    assert destroyed is True
    assert "/api/v0/instances/7/" in deleted_urls[0]
Loading