Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/twinkle/server/gateway/twinkle_gateway_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) -> None:
"""Register all /twinkle/* routes on the given FastAPI app."""

@app.get('/twinkle/capacity_info', response_model=types.CapacityInfoResponse)
async def get_capacity_info(
    request: Request,
    self: GatewayServer = Depends(self_fn),
) -> types.CapacityInfoResponse:
    """Serve the capacity snapshot held by the gateway's server state.

    The mapping awaited from ``self.state.get_capacity_info()`` is expanded
    into the keyword arguments of ``types.CapacityInfoResponse``.

    Args:
        request: Incoming request; not read by this handler.
        self: Gateway server instance injected through ``Depends(self_fn)``.

    Returns:
        The response model built from the state's capacity mapping.
    """
    capacity = await self.state.get_capacity_info()
    response = types.CapacityInfoResponse(**capacity)
    return response

@app.get('/twinkle/healthz', response_model=types.HealthResponse)
async def healthz(request: Request) -> types.HealthResponse:
    """Answer every call with a ``HealthResponse`` whose status is ``'ok'``.

    Args:
        request: Incoming request; not read by this handler.
    """
    ok_response = types.HealthResponse(status='ok')
    return ok_response
Expand Down
4 changes: 4 additions & 0 deletions src/twinkle/server/model/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ def get_self() -> ModelManagement:

@asynccontextmanager
async def lifespan(app: FastAPI):
try:
await get_self()._ensure_replica_registered()
except Exception as e:
logger.warning(f'Failed to register replica at startup: {e}')
yield
try:
await get_self().shutdown()
Expand Down
3 changes: 2 additions & 1 deletion src/twinkle/server/model/tinker_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ async def create_model(
async def _create_adapter():
_model_id = None
try:
_model_id = await self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id)
_model_id = await self.state.register_model(
body.model_dump(), token=token, replica_id=self.replica_id, session_id=body.session_id)
if body.lora_config:
# TODO: Make LoraConfig more flexible
lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear')
Expand Down
16 changes: 14 additions & 2 deletions src/twinkle/server/model/twinkle_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,6 @@ async def _task():
config = deserialize_object(body.config)
extra_kwargs = body.model_extra or {}
training_run_manager = create_training_run_manager(token, client_type='twinkle')
self.register_resource(adapter_name, token, session_id)
self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs)

lora_config = None
if isinstance(config, LoraConfig):
Expand All @@ -507,6 +505,20 @@ async def _task():
lora_config=lora_config,
save_dir=resolved_save_dir,
user_metadata={'adapter_name': body.adapter_name})
await self.state.register_model(
run_config.model_dump(),
token=token,
model_id=adapter_name,
replica_id=self.replica_id,
session_id=session_id,
)
Comment thread
kevssim marked this conversation as resolved.
try:
self.register_resource(adapter_name, token, session_id)
self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs)
except Exception:
self.unregister_resource(adapter_name)
await self.state.unload_model(adapter_name)
raise
Comment thread
kevssim marked this conversation as resolved.
training_run_manager.save(adapter_name, run_config)
return {'status': 'ok', 'adapter_name': adapter_name}

Expand Down
14 changes: 14 additions & 0 deletions src/twinkle/server/utils/state/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@ def __init__(self, expiration_timeout: float, per_token_model_limit: int = 30) -
# replica_id -> max_loras limit declared at registration time
self._replica_max_loras: dict[str, int] = {}

def get_capacity_info(self) -> dict[str, int]:
    """Return global LoRA capacity across all registered replicas.

    Only replicas that have an entry in ``self._replica_max_loras`` are
    counted. A registered replica with no entry in ``self._replica_models``
    contributes zero used slots.

    Returns:
        Dict containing 'max_loras', 'used_loras', and 'free_loras'.
        'free_loras' is clamped at zero, so it never goes negative when the
        used count exceeds the declared limits.
    """
    total_max_loras = sum(self._replica_max_loras.values())
    # Iterate the dict itself (its keys) and fall back to an empty tuple so
    # no throwaway set is allocated for replicas that host no models yet.
    total_used_loras = sum(
        len(self._replica_models.get(rid, ())) for rid in self._replica_max_loras
    )
    return {
        'max_loras': total_max_loras,
        'used_loras': total_used_loras,
        'free_loras': max(0, total_max_loras - total_used_loras),
    }

# ----- Replica Registration -----

def register_replica(self, replica_id: str, max_loras: int) -> None:
Expand Down
18 changes: 14 additions & 4 deletions src/twinkle/server/utils/state/server_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def __init__(
self._metrics_running = False
self._metrics_update_interval: float = float(kwargs.get('metrics_update_interval', 15.0))

async def get_capacity_info(self) -> dict[str, int]:
    """Return the capacity mapping produced by the model manager.

    Async wrapper around ``self._model_mgr.get_capacity_info()``; the
    manager call itself is synchronous and is not awaited.
    """
    capacity = self._model_mgr.get_capacity_info()
    return capacity

# ----- Session Management -----

async def create_session(self, payload: dict[str, Any]) -> str:
Expand Down Expand Up @@ -99,14 +102,17 @@ async def register_model(self,
payload: dict[str, Any],
token: str,
model_id: str | None = None,
replica_id: str | None = None) -> str:
replica_id: str | None = None,
session_id: str | None = None) -> str:
"""Register a new model with the server state.

Args:
payload: Model configuration containing base_model, lora_config, etc.
token: User token that owns this model. Required.
model_id: Optional explicit model_id; otherwise auto-generated.
replica_id: Optional replica that is hosting this model.
session_id: Optional owning session; enables cascade cleanup when
the session expires. Falls back to ``payload['session_id']``.

Returns:
The model_id for the registered model.
Expand All @@ -117,7 +123,7 @@ async def register_model(self,
_model_id = re.sub(r'[^\w\-]', '_', _model_id)

record = ModelRecord(
session_id=payload.get('session_id'),
session_id=session_id or payload.get('session_id'),
model_seq_id=payload.get('model_seq_id'),
base_model=payload.get('base_model'),
user_metadata=payload.get('user_metadata') or {},
Expand Down Expand Up @@ -374,6 +380,9 @@ class ServerStateProxy:
def __init__(self, actor_handle) -> None:
self._actor = actor_handle

async def get_capacity_info(self) -> dict[str, int]:
    """Forward the capacity query to the actor and await its reply."""
    pending = self._actor.get_capacity_info.remote()
    return await pending

# ----- Session Management -----

async def create_session(self, payload: dict[str, Any]) -> str:
Expand All @@ -391,8 +400,9 @@ async def register_model(self,
payload: dict[str, Any],
token: str,
model_id: str | None = None,
replica_id: str | None = None) -> str:
return await self._actor.register_model.remote(payload, token, model_id, replica_id)
replica_id: str | None = None,
session_id: str | None = None) -> str:
return await self._actor.register_model.remote(payload, token, model_id, replica_id, session_id)

async def unload_model(self, model_id: str) -> bool:
    """Ask the actor to unload ``model_id`` and return the actor's answer."""
    outcome = await self._actor.unload_model.remote(model_id)
    return outcome
Expand Down
17 changes: 16 additions & 1 deletion src/twinkle_client/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
from typing import Any, Dict, List, Optional, Tuple
from twinkle import get_logger
from twinkle_client.types.server import (DeleteCheckpointResponse, GetServerCapabilitiesResponse)
from twinkle_client.types.server import (CapacityInfoResponse, DeleteCheckpointResponse, GetServerCapabilitiesResponse)
from twinkle_client.types.session import (CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest,
SessionHeartbeatResponse)
from twinkle_client.types.training import (Checkpoint, Cursor, ParsedCheckpointTwinklePath, TrainingRun,
Expand Down Expand Up @@ -76,6 +76,21 @@ def __init__(
self._heartbeat_thread.start()
atexit.register(self.close)

def get_capacity_info(self) -> CapacityInfoResponse:
    """
    Query the server's ``/capacity_info`` endpoint.

    The URL is built with ``self._get_url``, fetched with ``http_get``, and
    the raw response is passed through ``self._handle_response``. The
    resulting mapping is expanded into a ``CapacityInfoResponse``.

    Returns:
        :class:`~twinkle_client.types.server.CapacityInfoResponse` with
        ``max_loras``, ``used_loras``, and ``free_loras`` fields.

    Raises:
        TwinkleClientError: Presumably raised by ``_handle_response`` when
            the request fails -- confirm against that helper.
    """
    endpoint = self._get_url('/capacity_info')
    payload = self._handle_response(http_get(endpoint))
    return CapacityInfoResponse(**payload)

# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions src/twinkle_client/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
SupportedModel,
WeightsInfoRequest,
WeightsInfoResponse as ServerWeightsInfoResponse,
CapacityInfoResponse,
)
from .session import CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, SessionHeartbeatResponse
from .training import (
Expand Down
7 changes: 7 additions & 0 deletions src/twinkle_client/types/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,10 @@ class CheckpointPathResponse(BaseModel):
"""Response body for the /checkpoint_path endpoint."""
path: str
twinkle_path: str


class CapacityInfoResponse(BaseModel):
    """Payload returned by the /capacity_info endpoint.

    All three fields are required integers.
    """
    # LoRA capacity limit reported by the server.
    max_loras: int
    # LoRA slots the server reports as in use.
    used_loras: int
    # LoRA slots the server reports as still available.
    free_loras: int
Loading