From 71647728e3fcabbb46fb4abc314488366861f411 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 22 May 2026 11:08:43 +0800 Subject: [PATCH 1/3] wip --- .../server/gateway/tinker_gateway_handlers.py | 8 ++++++++ .../server/gateway/twinkle_gateway_handlers.py | 8 ++++++++ src/twinkle/server/model/app.py | 4 ++++ src/twinkle/server/model/twinkle_handlers.py | 15 +++++++++++++-- src/twinkle/server/utils/state/model_manager.py | 14 ++++++++++++++ src/twinkle/server/utils/state/server_state.py | 6 ++++++ src/twinkle_client/manager.py | 15 +++++++++++++++ src/twinkle_client/types/__init__.py | 1 + src/twinkle_client/types/server.py | 7 +++++++ 9 files changed, 76 insertions(+), 2 deletions(-) diff --git a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py index 575ef82a..eede535d 100644 --- a/src/twinkle/server/gateway/tinker_gateway_handlers.py +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -33,6 +33,14 @@ def _register_tinker_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) It is wired in via ``Depends`` so it is resolved lazily at request time. """ + @app.get('/capacity_info') + async def get_capacity_info( + request: Request, + self: GatewayServer = Depends(self_fn), + ) -> dict: + info = await self.state.get_capacity_info() + return info + @app.get('/healthz') async def healthz(request: Request) -> types.HealthResponse: return types.HealthResponse(status='ok') diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index 9a2c2f27..af5b0939 100644 --- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -24,6 +24,14 @@ def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) -> None: """Register all /twinkle/* routes on the given FastAPI app.""" + @app.get('/twinkle/capacity_info', response_model=types.CapacityInfoResponse) + async def get_capacity_info( + request: Request, + self: GatewayServer = Depends(self_fn), + ) -> types.CapacityInfoResponse: + info = await self.state.get_capacity_info() + return types.CapacityInfoResponse(**info) + @app.get('/twinkle/healthz', response_model=types.HealthResponse) async def healthz(request: Request) -> types.HealthResponse: return types.HealthResponse(status='ok') diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index ecd841e3..5d0bc228 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -163,6 +163,10 @@ def get_self() -> ModelManagement: @asynccontextmanager async def lifespan(app: FastAPI): + try: + await get_self()._ensure_replica_registered() + except Exception as e: + logger.warning(f'Failed to register replica at startup: {e}') yield try: await get_self().shutdown() diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index a441387b..af659936 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -496,8 +496,6 @@ async def _task(): config = deserialize_object(body.config) extra_kwargs = body.model_extra or {} training_run_manager = create_training_run_manager(token, client_type='twinkle') - self.register_resource(adapter_name, token, session_id) - self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) lora_config = None if isinstance(config, LoraConfig): @@ -507,6 +505,19 @@ async def _task(): lora_config=lora_config, save_dir=resolved_save_dir, user_metadata={'adapter_name': body.adapter_name}) + await self.state.register_model( + run_config.model_dump(), + token=token, + model_id=adapter_name, + replica_id=self.replica_id, + ) + try: + self.register_resource(adapter_name, token, session_id) + self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) + except Exception: + self.unregister_resource(adapter_name) + await self.state.unload_model(adapter_name) + raise training_run_manager.save(adapter_name, run_config) return {'status': 'ok', 'adapter_name': adapter_name} diff --git a/src/twinkle/server/utils/state/model_manager.py b/src/twinkle/server/utils/state/model_manager.py index 586e4868..2d5345f7 100644 --- a/src/twinkle/server/utils/state/model_manager.py +++ b/src/twinkle/server/utils/state/model_manager.py @@ -28,6 +28,20 @@ def __init__(self, expiration_timeout: float, per_token_model_limit: int = 30) - # replica_id -> max_loras limit declared at registration time self._replica_max_loras: dict[str, int] = {} + def get_capacity_info(self) -> dict[str, int]: + """Return global LoRA capacity across all registered replicas. + + Returns: + Dict containing 'max_loras', 'used_loras', and 'free_loras'. + """ + total_max_loras = sum(self._replica_max_loras.values()) + total_used_loras = sum(len(self._replica_models.get(rid, set())) for rid in self._replica_max_loras.keys()) + return { + 'max_loras': total_max_loras, + 'used_loras': total_used_loras, + 'free_loras': max(0, total_max_loras - total_used_loras), + } + # ----- Replica Registration ----- def register_replica(self, replica_id: str, max_loras: int) -> None: diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index 8e7689b2..5cdf5218 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -57,6 +57,9 @@ def __init__( self._metrics_running = False self._metrics_update_interval: float = float(kwargs.get('metrics_update_interval', 15.0)) + async def get_capacity_info(self) -> dict[str, int]: + return self._model_mgr.get_capacity_info() + # ----- Session Management ----- async def create_session(self, payload: dict[str, Any]) -> str: @@ -374,6 +377,9 @@ class ServerStateProxy: def __init__(self, actor_handle) -> None: self._actor = actor_handle + async def get_capacity_info(self) -> dict[str, int]: + return await self._actor.get_capacity_info.remote() + # ----- Session Management ----- async def create_session(self, payload: dict[str, Any]) -> str: diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py index b9398997..f7c5d88d 100644 --- a/src/twinkle_client/manager.py +++ b/src/twinkle_client/manager.py @@ -76,6 +76,21 @@ def __init__( self._heartbeat_thread.start() atexit.register(self.close) + def get_capacity_info(self) -> dict: + """ + Get the server's global LoRA capacity information. + + Returns: + dict: Containing 'max_loras', 'used_loras', and 'free_loras'. + + Raises: + TwinkleClientError: If the request fails. + """ + from twinkle_client.types.server import CapacityInfoResponse + response = http_get(self._get_url('/capacity_info')) + data = self._handle_response(response) + return CapacityInfoResponse(**data).model_dump() + # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index d58bb00c..49673b0e 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -76,6 +76,7 @@ SupportedModel, WeightsInfoRequest, WeightsInfoResponse as ServerWeightsInfoResponse, + CapacityInfoResponse, ) from .session import CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, SessionHeartbeatResponse from .training import ( diff --git a/src/twinkle_client/types/server.py b/src/twinkle_client/types/server.py index 2d9233b4..1c7c992d 100644 --- a/src/twinkle_client/types/server.py +++ b/src/twinkle_client/types/server.py @@ -40,3 +40,10 @@ class CheckpointPathResponse(BaseModel): """Response body for the /checkpoint_path endpoint.""" path: str twinkle_path: str + + +class CapacityInfoResponse(BaseModel): + """Response body for the /capacity_info endpoint.""" + max_loras: int + used_loras: int + free_loras: int From cb6b88ff32a5af5c25291537099bd03b9de5af92 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 22 May 2026 14:19:06 +0800 Subject: [PATCH 2/3] Fix capacity info cold start registration --- src/twinkle/server/model/app.py | 8 ++++++++ src/twinkle/server/utils/state/server_state.py | 3 +++ 2 files changed, 11 insertions(+) diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 5d0bc228..2ebba7be 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -85,8 +85,16 @@ def __init__(self, # Initialize mixins self._init_task_queue(TaskQueueConfig.from_dict(queue_config), deployment_name='Model') self._init_adapter_manager(**(adapter_config or {})) + self._register_replica_at_startup() # Note: countdown task is started lazily in _ensure_sticky() + def _register_replica_at_startup(self) -> None: + try: + self.state.register_replica_blocking(self.replica_id, self.max_loras) + self._replica_registered = True + except Exception as e: + logger.warning(f'Failed to register replica at startup: {e}') + async def _ensure_replica_registered(self): """Lazily register replica on first async request.""" if not self._replica_registered: diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index 5cdf5218..4bbadb3b 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -411,6 +411,9 @@ async def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: async def register_replica(self, replica_id: str, max_loras: int) -> None: await self._actor.register_replica.remote(replica_id, max_loras) + def register_replica_blocking(self, replica_id: str, max_loras: int) -> None: + ray.get(self._actor.register_replica.remote(replica_id, max_loras)) + async def unregister_replica(self, replica_id: str) -> None: await self._actor.unregister_replica.remote(replica_id) From 012648af488f5f67df01683d0408789a57eb2a03 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 28 May 2026 15:39:19 +0800 Subject: [PATCH 3/3] fix --- .../server/gateway/tinker_gateway_handlers.py | 8 -------- src/twinkle/server/model/app.py | 8 -------- src/twinkle/server/model/tinker_handlers.py | 3 ++- src/twinkle/server/model/twinkle_handlers.py | 1 + src/twinkle/server/utils/state/server_state.py | 15 ++++++++------- src/twinkle_client/manager.py | 10 +++++----- 6 files changed, 16 insertions(+), 29 deletions(-) diff --git a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py index eede535d..575ef82a 100644 --- a/src/twinkle/server/gateway/tinker_gateway_handlers.py +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -33,14 +33,6 @@ def _register_tinker_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) It is wired in via ``Depends`` so it is resolved lazily at request time. """ - @app.get('/capacity_info') - async def get_capacity_info( - request: Request, - self: GatewayServer = Depends(self_fn), - ) -> dict: - info = await self.state.get_capacity_info() - return info - @app.get('/healthz') async def healthz(request: Request) -> types.HealthResponse: return types.HealthResponse(status='ok') diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 2ebba7be..5d0bc228 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -85,16 +85,8 @@ def __init__(self, # Initialize mixins self._init_task_queue(TaskQueueConfig.from_dict(queue_config), deployment_name='Model') self._init_adapter_manager(**(adapter_config or {})) - self._register_replica_at_startup() # Note: countdown task is started lazily in _ensure_sticky() - def _register_replica_at_startup(self) -> None: - try: - self.state.register_replica_blocking(self.replica_id, self.max_loras) - self._replica_registered = True - except Exception as e: - logger.warning(f'Failed to register replica at startup: {e}') - async def _ensure_replica_registered(self): """Lazily register replica on first async request.""" if not self._replica_registered: diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index e357d720..1f3d7ad9 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -41,7 +41,8 @@ async def create_model( async def _create_adapter(): _model_id = None try: - _model_id = await self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) + _model_id = await self.state.register_model( + body.model_dump(), token=token, replica_id=self.replica_id, session_id=body.session_id) if body.lora_config: # TODO: Make LoraConfig more flexible lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index af659936..ff161f5f 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -510,6 +510,7 @@ async def _task(): token=token, model_id=adapter_name, replica_id=self.replica_id, + session_id=session_id, ) try: self.register_resource(adapter_name, token, session_id) diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index 4bbadb3b..fd3a7626 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -102,7 +102,8 @@ async def register_model(self, payload: dict[str, Any], token: str, model_id: str | None = None, - replica_id: str | None = None) -> str: + replica_id: str | None = None, + session_id: str | None = None) -> str: """Register a new model with the server state. Args: @@ -110,6 +111,8 @@ async def register_model(self, token: User token that owns this model. Required. model_id: Optional explicit model_id; otherwise auto-generated. replica_id: Optional replica that is hosting this model. + session_id: Optional owning session; enables cascade cleanup when + the session expires. Falls back to ``payload['session_id']``. Returns: The model_id for the registered model. @@ -120,7 +123,7 @@ async def register_model(self, _model_id = re.sub(r'[^\w\-]', '_', _model_id) record = ModelRecord( - session_id=payload.get('session_id'), + session_id=session_id or payload.get('session_id'), model_seq_id=payload.get('model_seq_id'), base_model=payload.get('base_model'), user_metadata=payload.get('user_metadata') or {}, @@ -397,8 +400,9 @@ async def register_model(self, payload: dict[str, Any], token: str, model_id: str | None = None, - replica_id: str | None = None) -> str: - return await self._actor.register_model.remote(payload, token, model_id, replica_id) + replica_id: str | None = None, + session_id: str | None = None) -> str: + return await self._actor.register_model.remote(payload, token, model_id, replica_id, session_id) async def unload_model(self, model_id: str) -> bool: return await self._actor.unload_model.remote(model_id) @@ -411,9 +415,6 @@ async def get_model_metadata(self, model_id: str) -> dict[str, Any] | None: async def register_replica(self, replica_id: str, max_loras: int) -> None: await self._actor.register_replica.remote(replica_id, max_loras) - def register_replica_blocking(self, replica_id: str, max_loras: int) -> None: - ray.get(self._actor.register_replica.remote(replica_id, max_loras)) - async def unregister_replica(self, replica_id: str) -> None: await self._actor.unregister_replica.remote(replica_id) diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py index f7c5d88d..12257d63 100644 --- a/src/twinkle_client/manager.py +++ b/src/twinkle_client/manager.py @@ -5,7 +5,7 @@ import threading from typing import Any, Dict, List, Optional, Tuple from twinkle import get_logger -from twinkle_client.types.server import (DeleteCheckpointResponse, GetServerCapabilitiesResponse) +from twinkle_client.types.server import (CapacityInfoResponse, DeleteCheckpointResponse, GetServerCapabilitiesResponse) from twinkle_client.types.session import (CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, SessionHeartbeatResponse) from twinkle_client.types.training import (Checkpoint, Cursor, ParsedCheckpointTwinklePath, TrainingRun, @@ -76,20 +76,20 @@ def __init__( self._heartbeat_thread.start() atexit.register(self.close) - def get_capacity_info(self) -> dict: + def get_capacity_info(self) -> CapacityInfoResponse: """ Get the server's global LoRA capacity information. Returns: - dict: Containing 'max_loras', 'used_loras', and 'free_loras'. + :class:`~twinkle_client.types.server.CapacityInfoResponse` with + ``max_loras``, ``used_loras``, and ``free_loras`` fields. Raises: TwinkleClientError: If the request fails. """ - from twinkle_client.types.server import CapacityInfoResponse response = http_get(self._get_url('/capacity_info')) data = self._handle_response(response) - return CapacityInfoResponse(**data).model_dump() + return CapacityInfoResponse(**data) # ------------------------------------------------------------------ # Internal helpers