From fd8263de247e33172245f17ae91fd2fdceefbef4 Mon Sep 17 00:00:00 2001 From: Atsushi Morimoto <74th.tech@gmail.com> Date: Fri, 13 Mar 2026 20:34:26 +0900 Subject: [PATCH] refactor: improve WebSocket proxy management with async locking and cleanup --- stackchan_server/app.py | 54 +++++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/stackchan_server/app.py b/stackchan_server/app.py index 03d6933..14496d2 100644 --- a/stackchan_server/app.py +++ b/stackchan_server/app.py @@ -36,6 +36,7 @@ def __init__( self._setup_fn: Optional[Callable[[WsProxy], Awaitable[None]]] = None self._talk_session_fn: Optional[Callable[[WsProxy], Awaitable[None]]] = None self._proxies: dict[str, WsProxy] = {} + self._proxies_lock = asyncio.Lock() @self.fastapi.get("/health") async def _health() -> dict[str, str]: @@ -47,28 +48,25 @@ async def _ws_audio(websocket: WebSocket): @self.fastapi.get("/v1/stackchan", response_model=list[StackChanInfo]) async def _list_stackchans(): - return [ - StackChanInfo(ip=ip, state=proxy.current_state.name.lower()) - for ip, proxy in self._proxies.items() - ] + return await self._list_stackchan_infos() @self.fastapi.get("/v1/stackchan/{stackchan_ip}", response_model=StackChanInfo) async def _get_stackchan(stackchan_ip: str): - proxy = self._proxies.get(stackchan_ip) + proxy = await self._get_proxy(stackchan_ip) if proxy is None: raise HTTPException(status_code=404, detail="stackchan not connected") return StackChanInfo(ip=stackchan_ip, state=proxy.current_state.name.lower()) @self.fastapi.post("/v1/stackchan/{stackchan_ip}/wakeword", status_code=204) async def _trigger_wakeword(stackchan_ip: str): - proxy = self._proxies.get(stackchan_ip) + proxy = await self._get_proxy(stackchan_ip) if proxy is None: raise HTTPException(status_code=404, detail="stackchan not connected") proxy.trigger_wakeword() @self.fastapi.post("/v1/stackchan/{stackchan_ip}/speak", status_code=204) async def _speak(stackchan_ip: str, body: SpeakRequest): - proxy = self._proxies.get(stackchan_ip) + proxy = await self._get_proxy(stackchan_ip) if proxy is None: raise HTTPException(status_code=404, detail="stackchan not connected") await proxy.speak(body.text) @@ -85,20 +83,18 @@ async def _handle_ws(self, websocket: WebSocket) -> None: await websocket.accept() client_ip = websocket.client.host if websocket.client else "unknown" - # 同一 IP からの既存接続があれば切断する - existing = self._proxies.get(client_ip) - if existing is not None: - logger.info("Duplicate connection from %s, closing old one", client_ip) - await existing.close() - self._proxies.pop(client_ip, None) - proxy = WsProxy( websocket, speech_recognizer=self.speech_recognizer, speech_synthesizer=self.speech_synthesizer, ) - self._proxies[client_ip] = proxy + existing = await self._register_proxy(client_ip, proxy) await proxy.start() + + if existing is not None and existing is not proxy: + logger.info("Replacing existing connection from %s", client_ip) + await existing.close() + try: if self._setup_fn: await self._setup_fn(proxy) @@ -131,7 +127,33 @@ async def _handle_ws(self, websocket: WebSocket) -> None: pass finally: await proxy.close() - self._proxies.pop(client_ip, None) + await self._unregister_proxy(client_ip, proxy) + + async def _list_stackchan_infos(self) -> list[StackChanInfo]: + async with self._proxies_lock: + return [ + StackChanInfo(ip=ip, state=proxy.current_state.name.lower()) + for ip, proxy in self._proxies.items() + if not proxy.closed + ] + + async def _get_proxy(self, client_ip: str) -> WsProxy | None: + async with self._proxies_lock: + proxy = self._proxies.get(client_ip) + if proxy is None or proxy.closed: + return None + return proxy + + async def _register_proxy(self, client_ip: str, proxy: WsProxy) -> WsProxy | None: + async with self._proxies_lock: + existing = self._proxies.get(client_ip) + self._proxies[client_ip] = proxy + return existing + + async def _unregister_proxy(self, client_ip: str, proxy: WsProxy) -> None: + async with self._proxies_lock: + if self._proxies.get(client_ip) is proxy: + self._proxies.pop(client_ip, None) def run(self, host: str = "0.0.0.0", port: int = 8000, reload: bool = True) -> None: import uvicorn