diff --git a/.gitignore b/.gitignore index ca4ef755b2..017b2e3abc 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ node_modules/ data/ temp/ WareHouse/ + diff --git a/frontend/src/locales/en.json b/frontend/src/locales/en.json index 051d972429..41b0feb5c4 100644 --- a/frontend/src/locales/en.json +++ b/frontend/src/locales/en.json @@ -182,7 +182,8 @@ "alert_download_failed": "Failed to download file, please try again.", "alert_download_logs_failed": "Download failed, please try again later", "no_initial_instructions": "No initial instructions provided", - "workflow_cancelled": "Workflow cancelled" + "workflow_cancelled": "Workflow cancelled", + "reconnected": "Reconnected to existing session" }, "form_generator": { "advanced_settings": "Advanced Settings", diff --git a/frontend/src/locales/zh.json b/frontend/src/locales/zh.json index 119d3900d0..f6cb3041fd 100644 --- a/frontend/src/locales/zh.json +++ b/frontend/src/locales/zh.json @@ -160,7 +160,8 @@ "alert_download_failed": "下载文件失败,请重试。", "alert_download_logs_failed": "下载失败,请稍后重试", "no_initial_instructions": "未提供初始说明", - "workflow_cancelled": "工作流已取消" + "workflow_cancelled": "工作流已取消", + "reconnected": "已重新连接到现有会话" }, "components": { "workflow_edge": { diff --git a/frontend/src/pages/LaunchView.vue b/frontend/src/pages/LaunchView.vue index 72426d57a3..ce7f5868ce 100755 --- a/frontend/src/pages/LaunchView.vue +++ b/frontend/src/pages/LaunchView.vue @@ -759,7 +759,7 @@ const clearUploadedAttachments = () => { } // Reset the WebSocket connection and related state -const resetConnectionState = ({ closeSocket = true } = {}) => { +const resetConnectionState = ({ closeSocket = true, keepSession = false } = {}) => { if (closeSocket && ws) { try { ws.close() @@ -769,20 +769,29 @@ const resetConnectionState = ({ closeSocket = true } = {}) => { } ws = null - sessionId = null isConnectionReady.value = false - shouldGlow.value = false - isWorkflowRunning.value = false - activeNodes.value = [] + + if (!keepSession) { + sessionId = null + isWorkflowRunning.value = false + activeNodes.value = [] + shouldGlow.value = false + clearUploadedAttachments() + chatMessages.value = [] + nodesLoadingMessagesMap.clear() + nameToSpriteMap.value.clear() + nodeSpriteMap.value.clear() + } + if (attachmentHoverTimeout) { clearTimeout(attachmentHoverTimeout) attachmentHoverTimeout = null } - clearUploadedAttachments() } // Button state management const isWorkflowRunning = ref(false) +const isReconnecting = ref(false) // Active node list const activeNodes = ref([]) @@ -1432,12 +1441,28 @@ const sendHumanInput = () => { } // Establish a WebSocket connection -const establishWebSocketConnection = () => { - // Reset any previous state before creating a new socket - resetConnectionState() +const establishWebSocketConnection = (options = {}) => { + let { sessionId: reconnectSid } = options + + // If no explicit sessionId, check URL for an existing session + if (!reconnectSid) { + const urlSession = route.query?.session + if (urlSession && typeof urlSession === 'string' && urlSession.trim()) { + reconnectSid = urlSession.trim() + } + } - if (!selectedFile.value) { - return + const reconnecting = !!reconnectSid + + if (reconnecting) { + isReconnecting.value = true + resetConnectionState({ closeSocket: true, keepSession: true }) + status.value = 'Connecting...' + } else { + resetConnectionState() + if (!selectedFile.value) { + return + } } const apiBase = import.meta.env.VITE_API_BASE_URL || '' @@ -1457,7 +1482,9 @@ const establishWebSocketConnection = () => { } } - const wsUrl = `${scheme}//${host}/ws` + const wsUrl = reconnecting + ? `${scheme}//${host}/ws?session_id=${encodeURIComponent(reconnectSid)}` + : `${scheme}//${host}/ws` const socket = new WebSocket(wsUrl) ws = socket @@ -1485,12 +1512,15 @@ const establishWebSocketConnection = () => { } isConnectionReady.value = true - shouldGlow.value = true - status.value = 'Waiting for launch...' - nextTick(() => { - taskInputRef.value?.focus() - }) + // For new connections, set initial state; reconnections are handled by session_resumed + if (!isReconnecting.value) { + shouldGlow.value = true + status.value = 'Waiting for launch...' + nextTick(() => { + taskInputRef.value?.focus() + }) + } } else { processMessage(msg) } @@ -1522,6 +1552,11 @@ const establishWebSocketConnection = () => { // Watch for file selection changes watch(selectedFile, (newFile) => { + // When reconnecting, selectedFile is set by session_resumed; skip the normal flow + if (isReconnecting.value) { + return + } + taskPrompt.value = '' fileSearchQuery.value = newFile || '' isFileSearchDirty.value = false @@ -1555,10 +1590,18 @@ watch( } ) -onMounted(() => { +onMounted(async () => { document.addEventListener('click', handleClickOutside) document.addEventListener('keydown', handleKeydown) - loadWorkflows() + await loadWorkflows() + // If URL contains a session id, the watch on selectedFile (triggered by + // applyWorkflowFromRoute inside loadWorkflows) will call establishWebSocketConnection, + // which auto-detects the session param and reconnects. + // Fallback: if session is present but no workflow was in URL, connect directly. + const sessionParam = route.query?.session + if (sessionParam && typeof sessionParam === 'string' && sessionParam.trim() && !selectedFile.value) { + establishWebSocketConnection({ sessionId: sessionParam.trim() }) + } }) onUnmounted(() => { @@ -1836,6 +1879,15 @@ const launchWorkflow = async () => { status.value = 'Running...' isWorkflowRunning.value = true + + // Persist session id in URL for reconnection after refresh + router.push({ + query: { + ...route.query, + workflow: selectedFile.value, + session: sessionId + } + }) } else { const error = await response.json().catch(() => ({})) console.error('Failed to launch workflow:', error) @@ -2025,6 +2077,68 @@ const animateSpriteAlongEdge = (edge) => { const processMessage = async (msg) => { console.log('Message: ', msg) + // Session resumed after reconnection — sync final UI state + if (msg.type === 'session_resumed') { + const data = msg.data + sessionId = data.session_id + + // Restore workflow selection without clearing chat (messages were already replayed) + // Set selectedFile BEFORE clearing isReconnecting so the watch skips + if (data.yaml_file) { + selectedFile.value = data.yaml_file + fileSearchQuery.value = data.yaml_file + // Load YAML data and sprites (but don't clear chat) + try { + const yamlContentString = await fetchWorkflowYAML(data.yaml_file) + const parsedYaml = yaml.load(yamlContentString) + workflowYaml.value = parsedYaml || {} + + const yamlNodes = Array.isArray(parsedYaml?.graph?.nodes) ? parsedYaml.graph.nodes : [] + for (const node of yamlNodes) { + if (node.id && !nodeSpriteMap.value.has(node.id)) { + const spritePath = spriteFetcher.fetchSprite(node.id, 'D', 1) + nodeSpriteMap.value.set(node.id, spritePath) + } + } + } catch (e) { + console.error('Failed to load YAML on reconnect:', e) + } + } + + isReconnecting.value = false + + // Restore workflow status + const statusMap = { + 'idle': 'Connected', + 'running': 'Running...', + 'waiting_for_input': 'Waiting for input...', + 'completed': 'Completed', + 'error': 'Error', + 'cancelled': 'Cancelled', + } + status.value = statusMap[data.status] || 'Connected' + + if (data.status === 'running' || data.status === 'waiting_for_input') { + isWorkflowRunning.value = true + } + + if (data.status === 'waiting_for_input') { + shouldGlow.value = true + } + + if (data.status === 'completed' || data.status === 'error' || data.status === 'cancelled') { + sessionIdToDownload = sessionId + } + + if (data.current_node_id && !activeNodes.value.includes(data.current_node_id)) { + activeNodes.value.push(data.current_node_id) + } + + isConnectionReady.value = true + addChatNotification(t('launch.reconnected')) + return + } + // Prompt for human input if (msg.type === 'human_input_required') { const fullMessage = msg.data.task_description + '\n\n' + msg.data.input @@ -2184,6 +2298,14 @@ const processMessage = async (msg) => { sessionIdToDownload = sessionId } + // Workflow cancelled (e.g., from server-side cancellation) + if (msg.type === 'workflow_cancelled') { + addChatNotification(msg.data?.message || t('launch.workflow_cancelled')) + status.value = 'Cancelled' + isWorkflowRunning.value = false + sessionIdToDownload = sessionId + } + // Handle direct error messages (e.g., workflow execution errors) if (msg.type === 'error') { const errorMessage = msg.data?.message || 'Unknown error occurred' @@ -2199,6 +2321,14 @@ const cancelWorkflow = () => { if (!isWorkflowRunning.value || !ws) { return } + + // Send cancel request through WebSocket so the server stops the workflow + try { + ws.send(JSON.stringify({ type: 'cancel' })) + } catch (sendError) { + console.warn('Failed to send cancel message:', sendError) + } + addChatNotification(t('launch.workflow_cancelled')) status.value = 'Cancelled' isWorkflowRunning.value = false @@ -2214,12 +2344,6 @@ const cancelWorkflow = () => { nodesLoadingMessagesMap.delete(nodeId) } } - - try { - ws.close() - } catch (closeError) { - console.warn('Failed to close WebSocket:', closeError) - } } // Download logs diff --git a/server/routes/websocket.py b/server/routes/websocket.py index 870634dbb6..89fe77d897 100755 --- a/server/routes/websocket.py +++ b/server/routes/websocket.py @@ -8,12 +8,12 @@ @router.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket): +async def websocket_endpoint(websocket: WebSocket, session_id: str = ""): manager = get_websocket_manager() - session_id = await manager.connect(websocket) + sid = await manager.connect(websocket, session_id=session_id or None) try: while True: message = await websocket.receive_text() - await manager.handle_message(session_id, message) + await manager.handle_message(sid, message) except WebSocketDisconnect: - manager.disconnect(session_id) + manager.disconnect(sid) diff --git a/server/services/message_handler.py b/server/services/message_handler.py index 75580b5816..bc523691c7 100755 --- a/server/services/message_handler.py +++ b/server/services/message_handler.py @@ -29,12 +29,22 @@ async def handle_message(self, session_id: str, data: Dict[str, Any], websocket_ await self._handle_ping(session_id, websocket_manager) elif message_type == "get_status": await self._handle_get_status(session_id, websocket_manager) + elif message_type == "cancel": + await self._handle_cancel(session_id, websocket_manager) else: await websocket_manager.send_message( session_id, {"type": "error", "data": {"message": f"Unknown message type: {message_type}"}}, ) + async def _handle_cancel(self, session_id: str, websocket_manager): + if self.workflow_run_service: + self.workflow_run_service.request_cancel(session_id, reason="User requested cancellation") + await websocket_manager.send_message( + session_id, + {"type": "input_received", "data": {"message": "Cancellation requested"}}, + ) + async def _handle_human_input(self, session_id: str, data: Dict[str, Any], websocket_manager): try: payload = data.get("data", {}) or {} diff --git a/server/services/session_store.py b/server/services/session_store.py index 433e3be2eb..1c31bd4ed4 100755 --- a/server/services/session_store.py +++ b/server/services/session_store.py @@ -56,6 +56,16 @@ class WorkflowSession: cancel_event: Event = field(default_factory=Event) cancel_reason: Optional[str] = None + # Message buffer for reconnection replay + message_buffer: list = field(default_factory=list) + + MAX_BUFFER_SIZE: int = 1000 + + def append_message(self, message: Dict[str, Any]) -> None: + if len(self.message_buffer) >= self.MAX_BUFFER_SIZE: + self.message_buffer.pop(0) + self.message_buffer.append(message) + class WorkflowSessionStore: """In-memory registry that tracks workflow session metadata.""" @@ -129,3 +139,20 @@ def list_sessions(self) -> Dict[str, Dict[str, Any]]: def get_artifact_queue(self, session_id: str) -> Optional[ArtifactEventQueue]: session = self._sessions.get(session_id) return session.artifact_queue if session else None + + def get_session_snapshot(self, session_id: str) -> Optional[Dict[str, Any]]: + session = self._sessions.get(session_id) + if not session: + return None + return { + "session_id": session.session_id, + "yaml_file": session.yaml_file, + "task_prompt": session.task_prompt, + "status": session.status.value, + "current_node_id": session.current_node_id, + "created_at": session.created_at, + "updated_at": session.updated_at, + "waiting_for_input": session.waiting_for_input, + "error_message": session.error_message, + "message_count": len(session.message_buffer), + } diff --git a/server/services/websocket_manager.py b/server/services/websocket_manager.py index e5d0df18fa..f41ab47554 100755 --- a/server/services/websocket_manager.py +++ b/server/services/websocket_manager.py @@ -39,6 +39,8 @@ def _encode_ws_message(message: Any) -> str: class WebSocketManager: + SESSION_TTL_SECONDS = 24 * 60 * 60 # 24 hours + def __init__( self, *, @@ -50,6 +52,7 @@ def __init__( self.active_connections: Dict[str, WebSocket] = {} self.connection_timestamps: Dict[str, float] = {} self._owner_loop: Optional[asyncio.AbstractEventLoop] = None + self._gc_task: Optional[asyncio.Task] = None self.session_store = session_store or WorkflowSessionStore() self.session_controller = session_controller or SessionExecutionController(self.session_store) self.attachment_service = attachment_service or AttachmentService() @@ -70,11 +73,55 @@ async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) # worker threads can safely schedule sends via run_coroutine_threadsafe. if self._owner_loop is None: self._owner_loop = asyncio.get_running_loop() + + # --- Reconnect to existing session --- + if session_id and self.session_store.has_session(session_id): + # If an old WebSocket is still tied to this session, close it first + if session_id in self.active_connections: + old_ws = self.active_connections[session_id] + try: + await old_ws.close(code=1000, reason="Replaced by new connection") + except Exception: + pass + + self.active_connections[session_id] = websocket + self.connection_timestamps[session_id] = time.time() + logging.info("WebSocket reconnected to existing session: %s", session_id) + + # Always start the GC loop (idempotent) + self._start_gc() + + # Send connection confirmation + await self._send_raw( + session_id, + {"type": "connection", "data": {"session_id": session_id, "status": "connected"}}, + ) + + # Replay all buffered messages (snapshot to avoid including messages + # that arrive during replay) + session = self.session_store.get_session(session_id) + if session: + messages_to_replay = list(session.message_buffer) + for msg in messages_to_replay: + await self._send_raw(session_id, msg) + + # Send session state snapshot + snapshot = self.session_store.get_session_snapshot(session_id) + if snapshot: + await self._send_raw(session_id, {"type": "session_resumed", "data": snapshot}) + + return session_id + + # --- New connection --- if not session_id: session_id = str(uuid.uuid4()) self.active_connections[session_id] = websocket self.connection_timestamps[session_id] = time.time() logging.info("WebSocket connected: %s", session_id) + + # Always start the GC loop (idempotent) + self._start_gc() + await self.send_message( session_id, { @@ -85,24 +132,19 @@ async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) return session_id def disconnect(self, session_id: str) -> None: - session = self.session_store.get_session(session_id) - if session and session.status in {SessionStatus.RUNNING, SessionStatus.WAITING_FOR_INPUT}: - self.workflow_run_service.request_cancel( - session_id, - reason="WebSocket disconnected", - ) if session_id in self.active_connections: del self.active_connections[session_id] if session_id in self.connection_timestamps: del self.connection_timestamps[session_id] - self.session_controller.cleanup_session(session_id) - remaining_session = self.session_store.get_session(session_id) - if remaining_session and remaining_session.executor is None: - self.session_store.pop_session(session_id) - self.attachment_service.cleanup_session(session_id) - logging.info("WebSocket disconnected: %s", session_id) + logging.info("WebSocket disconnected (session preserved): %s", session_id) async def send_message(self, session_id: str, message: Dict[str, Any]) -> None: + # Buffer business messages for reconnection replay (exclude transport messages) + if message.get("type") not in ("connection", "pong"): + session = self.session_store.get_session(session_id) + if session: + session.append_message(message) + if session_id in self.active_connections: websocket = self.active_connections[session_id] try: @@ -110,7 +152,16 @@ async def send_message(self, session_id: str, message: Dict[str, Any]) -> None: except Exception as exc: traceback.print_exc() logging.error("Failed to send message to %s: %s", session_id, exc) - # self.disconnect(session_id) + + async def _send_raw(self, session_id: str, message: Dict[str, Any]) -> None: + """Send a message without buffering. Used for replay and connection management.""" + if session_id in self.active_connections: + websocket = self.active_connections[session_id] + try: + await websocket.send_text(_encode_ws_message(message)) + except Exception as exc: + traceback.print_exc() + logging.error("Failed to send raw message to %s: %s", session_id, exc) def send_message_sync(self, session_id: str, message: Dict[str, Any]) -> None: """Send a WebSocket message from any thread (including worker threads). @@ -174,3 +225,26 @@ async def handle_message(self, session_id: str, message: str) -> None: session_id, {"type": "error", "data": {"message": str(exc)}}, ) + + def _start_gc(self) -> None: + """Start the background GC task if not already running.""" + if self._gc_task is not None and not self._gc_task.done(): + return + loop = asyncio.get_running_loop() + self._gc_task = loop.create_task(self._gc_loop()) + + async def _gc_loop(self) -> None: + """Periodically clean up terminal sessions older than TTL.""" + TERMINAL = {SessionStatus.COMPLETED, SessionStatus.ERROR, SessionStatus.CANCELLED} + while True: + await asyncio.sleep(3600) # run every hour + now = time.time() + to_remove = [] + for sid, session in self.session_store._sessions.items(): + if session.status in TERMINAL: + if now - session.updated_at > self.SESSION_TTL_SECONDS: + to_remove.append(sid) + for sid in to_remove: + self.session_store.pop_session(sid) + self.attachment_service.cleanup_session(sid) + logging.info("GC: removed expired session %s", sid) diff --git a/server/services/workflow_run_service.py b/server/services/workflow_run_service.py index 9b0689b904..a4e47c2c96 100755 --- a/server/services/workflow_run_service.py +++ b/server/services/workflow_run_service.py @@ -264,8 +264,6 @@ async def _execute_workflow_async( session_ref.executor = None session_ref.graph = None self.session_controller.cleanup_session(session_id) - if session_id not in websocket_manager.active_connections: - self.session_store.pop_session(session_id) def _build_initial_task_input( self,