From c26542491756b8bc6d053cfc41ffe406931126ee Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 26 May 2026 16:52:49 -0400 Subject: [PATCH 1/4] improved websocket connection management --- .../src/WebSocketManager.ts | 8 + @plotly/dash-websocket-worker/src/types.ts | 1 + @plotly/dash-websocket-worker/src/worker.ts | 7 + dash/backends/_fastapi.py | 32 +- dash/backends/_quart.py | 32 +- dash/backends/_ws_registry.py | 166 +++++++ dash/backends/ws.py | 157 ++++-- .../src/observers/websocketObserver.ts | 60 +++ dash/dash-renderer/src/utils/workerClient.ts | 14 + tests/websocket/test_ws_reconnect.py | 449 ++++++++++++++++++ 10 files changed, 879 insertions(+), 47 deletions(-) create mode 100644 dash/backends/_ws_registry.py create mode 100644 tests/websocket/test_ws_reconnect.py diff --git a/@plotly/dash-websocket-worker/src/WebSocketManager.ts b/@plotly/dash-websocket-worker/src/WebSocketManager.ts index d96a0d8e68..942917f075 100644 --- a/@plotly/dash-websocket-worker/src/WebSocketManager.ts +++ b/@plotly/dash-websocket-worker/src/WebSocketManager.ts @@ -144,6 +144,14 @@ export class WebSocketManager { return this.ws !== null && this.ws.readyState === WebSocket.OPEN; } + /** + * Reset the activity timer. + * Call this when a tab becomes visible to prevent inactivity timeout. + */ + public resetActivity(): void { + this.lastActivityTime = Date.now(); + } + private createConnection(): void { if (!this.serverUrl) { return; diff --git a/@plotly/dash-websocket-worker/src/types.ts b/@plotly/dash-websocket-worker/src/types.ts index 5d1ff80bf0..3d88af068a 100644 --- a/@plotly/dash-websocket-worker/src/types.ts +++ b/@plotly/dash-websocket-worker/src/types.ts @@ -7,6 +7,7 @@ export enum WorkerMessageType { DISCONNECT = 'disconnect', CALLBACK_REQUEST = 'callback_request', GET_PROPS_RESPONSE = 'get_props_response', + TAB_VISIBLE = 'tab_visible', // Worker -> Renderer CONNECTED = 'connected', diff --git a/@plotly/dash-websocket-worker/src/worker.ts b/@plotly/dash-websocket-worker/src/worker.ts index 0e68f0b09a..e28d0c1583 100644 --- a/@plotly/dash-websocket-worker/src/worker.ts +++ b/@plotly/dash-websocket-worker/src/worker.ts @@ -122,6 +122,13 @@ self.onconnect = (event: MessageEvent) => { break; } + case WorkerMessageType.TAB_VISIBLE: { + // Reset activity timer when tab becomes visible + // This prevents inactivity timeout while user is viewing the tab + wsManager.resetActivity(); + break; + } + default: // Forward other messages through the router router.handleRendererMessage(message.rendererId, message); diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index c46fb4ffc5..b619c040c9 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -50,6 +50,7 @@ SHUTDOWN_SIGNAL, DISCONNECTED, ) +from ._ws_registry import ActiveCallbackRegistry from ._utils import format_traceback_html if TYPE_CHECKING: # pragma: no cover - typing only @@ -677,6 +678,12 @@ def serve_websocket_callback(self, dash_app: "Dash"): dash_app, "_websocket_allowed_origins", [] ) # pylint: disable=protected-access + # Initialize registry on dash_app if not present + # pylint: disable=protected-access + if not hasattr(dash_app, "_ws_callback_registry"): + dash_app._ws_callback_registry = ActiveCallbackRegistry() + registry: ActiveCallbackRegistry = dash_app._ws_callback_registry + def validate_origin(origin: str | None, host: str | None) -> str | None: """Validate WebSocket origin. Returns error message or None if valid.""" if not origin: @@ -723,6 +730,8 @@ async def websocket_handler(websocket: WebSocket): executor = self.get_callback_executor() # Track pending callback futures pending_callbacks: Dict[str, concurrent.futures.Future] = {} + # Track current renderer ID for this connection + current_renderer_id: str | None = None # Start sender task to drain outbound queue (sends pre-serialized text) # pylint: disable=protected-access @@ -753,6 +762,22 @@ async def websocket_handler(websocket: WebSocket): renderer_id = message.get("rendererId", "") payload = message.get("payload", {}) + # Update current renderer ID for cleanup + current_renderer_id = renderer_id + + # Adopt connection for this renderer (allows reconnection) + # Called for every callback to ensure registry entry exists + # (entry may have been cleaned up after previous callback) + registry.adopt_connection( + renderer_id, + outbound_queue, + pending_get_props, + shutdown_event, + ) + + # Register this callback with the registry + registry.register_callback(renderer_id) + # Validate that the callback is allowed to use WebSocket transport # pylint: disable=protected-access _validate.validate_websocket_callback_request( @@ -761,12 +786,13 @@ async def websocket_handler(websocket: WebSocket): dash_app._websocket_callbacks, ) - # Create WebSocket callback instance with outbound queue + # Create WebSocket callback instance with registry ws_cb = DashWebsocketCallback( pending_get_props, renderer_id, outbound_queue, shutdown_event, + registry=registry, ) # Submit callback to executor @@ -786,6 +812,7 @@ async def websocket_handler(websocket: WebSocket): request_id, renderer_id, shutdown_event, + registry=registry, ) ) pending_callbacks[request_id] = future @@ -821,6 +848,9 @@ async def websocket_handler(websocket: WebSocket): # Cancel any pending futures for f in pending_callbacks.values(): f.cancel() + # Cleanup registry entry if no active callbacks + if current_renderer_id is not None: + registry.cleanup_renderer(current_renderer_id) self.server.add_api_websocket_route(ws_path, websocket_handler) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 881fd6466f..31d9668d36 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -55,6 +55,7 @@ SHUTDOWN_SIGNAL, DISCONNECTED, ) +from ._ws_registry import ActiveCallbackRegistry from ._utils import format_traceback_html if TYPE_CHECKING: @@ -521,6 +522,12 @@ def serve_websocket_callback(self, dash_app: "Dash"): # pylint: disable=protected-access allowed_origins = getattr(dash_app, "_websocket_allowed_origins", []) + # Initialize registry on dash_app if not present + # pylint: disable=protected-access + if not hasattr(dash_app, "_ws_callback_registry"): + dash_app._ws_callback_registry = ActiveCallbackRegistry() + registry: ActiveCallbackRegistry = dash_app._ws_callback_registry + @self.server.websocket(ws_path) async def websocket_handler(): # pylint: disable=too-many-branches ws = websocket @@ -564,6 +571,8 @@ async def websocket_handler(): # pylint: disable=too-many-branches executor = self.get_callback_executor() # Track pending callback futures pending_callbacks: Dict[str, concurrent.futures.Future] = {} + # Track current renderer ID for this connection + current_renderer_id: str | None = None # Start sender task to drain outbound queue (sends pre-serialized text) # pylint: disable=protected-access @@ -601,6 +610,22 @@ async def websocket_handler(): # pylint: disable=too-many-branches renderer_id = message.get("rendererId", "") payload = message.get("payload", {}) + # Update current renderer ID for cleanup + current_renderer_id = renderer_id + + # Adopt connection for this renderer (allows reconnection) + # Called for every callback to ensure registry entry exists + # (entry may have been cleaned up after previous callback) + registry.adopt_connection( + renderer_id, + outbound_queue, + pending_get_props, + connection_shutdown_event, + ) + + # Register this callback with the registry + registry.register_callback(renderer_id) + # Validate that the callback is allowed to use WebSocket transport # pylint: disable=protected-access _validate.validate_websocket_callback_request( @@ -609,12 +634,13 @@ async def websocket_handler(): # pylint: disable=too-many-branches dash_app._websocket_callbacks, ) - # Create WebSocket callback instance with outbound queue + # Create WebSocket callback instance with registry ws_cb = DashWebsocketCallback( pending_get_props, renderer_id, outbound_queue, connection_shutdown_event, + registry=registry, ) # Submit callback to executor @@ -634,6 +660,7 @@ async def websocket_handler(): # pylint: disable=too-many-branches request_id, renderer_id, connection_shutdown_event, + registry=registry, ) ) pending_callbacks[request_id] = future @@ -672,6 +699,9 @@ async def websocket_handler(): # pylint: disable=too-many-branches # Cancel any pending futures for f in pending_callbacks.values(): f.cancel() + # Cleanup registry entry if no active callbacks + if current_renderer_id is not None: + registry.cleanup_renderer(current_renderer_id) class QuartRequestAdapter(RequestAdapter): diff --git a/dash/backends/_ws_registry.py b/dash/backends/_ws_registry.py new file mode 100644 index 0000000000..4ab1132d6a --- /dev/null +++ b/dash/backends/_ws_registry.py @@ -0,0 +1,166 @@ +"""WebSocket callback registry for handling reconnections. + +This module provides a registry that tracks active callbacks per renderer, +allowing callbacks to persist across WebSocket reconnections. +""" + +from __future__ import annotations + +import threading +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, Optional + +if TYPE_CHECKING: + import queue + import janus + + +@dataclass +class RendererState: + """State for a single renderer's WebSocket connection.""" + + outbound_queue: "janus.Queue[str]" + pending_get_props: Dict[str, "queue.Queue[Any]"] + shutdown_event: threading.Event + active_callback_count: int = 0 + lock: threading.Lock = field(default_factory=threading.Lock) + + +class ActiveCallbackRegistry: + """Registry for active WebSocket callbacks that persists across reconnections. + + When a WebSocket disconnects and reconnects, callbacks that are still running + can "adopt" the new connection's queues to continue sending updates. + + Thread-safe for access from both the main event loop and worker threads. + """ + + def __init__(self) -> None: + self._renderers: Dict[str, RendererState] = {} + self._lock = threading.Lock() + + def adopt_connection( + self, + renderer_id: str, + outbound_queue: "janus.Queue[str]", + pending_get_props: Dict[str, "queue.Queue[Any]"], + shutdown_event: threading.Event, + ) -> None: + """Associate new connection with existing callbacks for this renderer. + + When a WebSocket reconnects, this method updates the queues and shutdown + event so that running callbacks can use the new connection. + + Args: + renderer_id: The renderer ID for this connection + outbound_queue: janus.Queue for sending messages + pending_get_props: Dict to track pending get_props requests + shutdown_event: Event signaling connection closure + """ + with self._lock: + if renderer_id in self._renderers: + state = self._renderers[renderer_id] + with state.lock: + state.outbound_queue = outbound_queue + state.pending_get_props = pending_get_props + state.shutdown_event = shutdown_event + else: + self._renderers[renderer_id] = RendererState( + outbound_queue=outbound_queue, + pending_get_props=pending_get_props, + shutdown_event=shutdown_event, + active_callback_count=0, + ) + + def register_callback(self, renderer_id: str) -> None: + """Register a new active callback for this renderer. + + Args: + renderer_id: The renderer ID + """ + with self._lock: + if renderer_id in self._renderers: + state = self._renderers[renderer_id] + with state.lock: + state.active_callback_count += 1 + + def unregister_callback(self, renderer_id: str) -> None: + """Unregister a completed callback for this renderer. + + If no active callbacks remain, the renderer state is cleaned up. + + Args: + renderer_id: The renderer ID + """ + with self._lock: + if renderer_id in self._renderers: + state = self._renderers[renderer_id] + with state.lock: + state.active_callback_count -= 1 + if state.active_callback_count <= 0: + del self._renderers[renderer_id] + + def get_queue(self, renderer_id: str) -> Optional["janus.Queue[str]"]: + """Get current outbound queue for renderer (thread-safe). + + Args: + renderer_id: The renderer ID + + Returns: + The current outbound queue, or None if renderer not found + """ + with self._lock: + state = self._renderers.get(renderer_id) + if state is None: + return None + with state.lock: + return state.outbound_queue + + def get_pending_get_props( + self, renderer_id: str + ) -> Optional[Dict[str, "queue.Queue[Any]"]]: + """Get current pending_get_props dict for renderer (thread-safe). + + Args: + renderer_id: The renderer ID + + Returns: + The current pending_get_props dict, or None if renderer not found + """ + with self._lock: + state = self._renderers.get(renderer_id) + if state is None: + return None + with state.lock: + return state.pending_get_props + + def is_shutdown(self, renderer_id: str) -> bool: + """Check if current connection is shutdown. + + Args: + renderer_id: The renderer ID + + Returns: + True if shutdown event is set or renderer not found, False otherwise + """ + with self._lock: + state = self._renderers.get(renderer_id) + if state is None: + return True + with state.lock: + return state.shutdown_event.is_set() + + def cleanup_renderer(self, renderer_id: str) -> None: + """Clean up renderer state when connection closes. + + Only removes if no active callbacks remain. + + Args: + renderer_id: The renderer ID to clean up + """ + with self._lock: + state = self._renderers.get(renderer_id) + if state is not None: + with state.lock: + if state.active_callback_count <= 0: + del self._renderers[renderer_id] diff --git a/dash/backends/ws.py b/dash/backends/ws.py index f44913d873..adebdf2c4a 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -25,6 +25,7 @@ if TYPE_CHECKING: import dash from .base_server import ResponseAdapter + from ._ws_registry import ActiveCallbackRegistry SHUTDOWN_SIGNAL = "__shutdown__" @@ -41,6 +42,10 @@ class DashWebsocketCallback: Uses janus.Queue for outbound messages (serialized with to_json) and queue.Queue for get_props responses, enabling thread-safe communication between worker threads and the main event loop. + + Supports two modes: + 1. Registry mode: Uses ActiveCallbackRegistry to allow queue adoption on reconnect + 2. Direct mode: Uses direct queue references (legacy, for backwards compatibility) """ def __init__( @@ -49,6 +54,7 @@ def __init__( renderer_id: str, outbound_queue: janus.Queue[str], shutdown_event: "threading.Event", + registry: "ActiveCallbackRegistry | None" = None, ): """Initialize the WebSocket callback interface. @@ -58,26 +64,46 @@ def __init__( renderer_id: The renderer ID for routing messages back to the correct client outbound_queue: janus.Queue for thread-safe outbound messaging. shutdown_event: Event signaling the websocket connection has closed. + registry: Optional registry for handling reconnections. If provided, + the callback will use the registry to get current queues, allowing + it to survive reconnections. """ self._pending_get_props = pending_get_props self._renderer_id = renderer_id self._outbound_queue = outbound_queue self._shutdown_event = shutdown_event + self._registry = registry @property def is_shutdown(self) -> bool: """Check if the websocket connection has been shut down.""" + if self._registry is not None: + return self._registry.is_shutdown(self._renderer_id) return self._shutdown_event.is_set() + def _get_outbound_queue(self) -> janus.Queue[str] | None: + """Get the current outbound queue (may be updated on reconnect).""" + if self._registry is not None: + return self._registry.get_queue(self._renderer_id) + return self._outbound_queue + + def _get_pending_get_props(self) -> Dict[str, queue.Queue[Any]] | None: + """Get the current pending_get_props dict (may be updated on reconnect).""" + if self._registry is not None: + return self._registry.get_pending_get_props(self._renderer_id) + return self._pending_get_props + def _queue_message(self, msg: dict) -> None: """Serialize and queue message for sending (thread-safe, non-blocking). Uses to_json for proper serialization of Dash components. Does nothing if the connection has been shut down. """ - if self._shutdown_event.is_set(): + if self.is_shutdown: return - self._outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) + outbound_queue = self._get_outbound_queue() + if outbound_queue is not None: + outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: """Send immediate prop update to the client via WebSocket. @@ -115,7 +141,11 @@ async def get_prop( WebsocketDisconnected: If the websocket connection has been closed. TimeoutError: If the response doesn't arrive within the timeout. """ - if self._shutdown_event.is_set(): + if self.is_shutdown: + raise WebsocketDisconnected() + + pending_get_props = self._get_pending_get_props() + if pending_get_props is None: raise WebsocketDisconnected() request_id = str(uuid.uuid4()) @@ -128,7 +158,7 @@ async def get_prop( # Use standard queue.Queue for response response_queue: queue.Queue = queue.Queue() - self._pending_get_props[request_id] = response_queue + pending_get_props[request_id] = response_queue # Queue the outbound request via janus sync interface self._queue_message(msg) @@ -146,7 +176,10 @@ async def get_prop( f"Timeout waiting for {component_id}.{prop_name}" ) from exc finally: - self._pending_get_props.pop(request_id, None) + # Get fresh reference in case of reconnection + current_pending = self._get_pending_get_props() + if current_pending is not None: + current_pending.pop(request_id, None) def create_ws_context( @@ -219,21 +252,26 @@ async def run_ws_sender( return if msg == FLUSH_SIGNAL: if messages: - await _send_batched(send_text, messages) + if not await _send_batched(send_text, messages): + return # Connection closed messages = [] continue if not batch_delay: - await send_text(msg) + try: + await send_text(msg) + except Exception: # WebSocketDisconnect, RuntimeError, etc. + return # Connection closed else: messages.append(msg) except asyncio.TimeoutError: - await _send_batched(send_text, messages) + if not await _send_batched(send_text, messages): + return # Connection closed messages = [] except asyncio.CancelledError: pass -async def _send_batched(send_text: Callable[[str], Any], messages: list) -> None: +async def _send_batched(send_text: Callable[[str], Any], messages: list) -> bool: """Send messages as a batch. Single messages are sent as-is. Multiple messages are wrapped @@ -242,12 +280,19 @@ async def _send_batched(send_text: Callable[[str], Any], messages: list) -> None Args: send_text: Async function to send text data over WebSocket messages: List of pre-serialized JSON message strings + + Returns: + True if send succeeded, False if connection was closed """ - if len(messages) == 1: - await send_text(messages[0]) - else: - # Wrap in array: "[msg1,msg2,msg3]" - await send_text("[" + ",".join(messages) + "]") + try: + if len(messages) == 1: + await send_text(messages[0]) + else: + # Wrap in array: "[msg1,msg2,msg3]" + await send_text("[" + ",".join(messages) + "]") + return True + except Exception: # WebSocketDisconnect, RuntimeError, etc. + return False # Connection closed, cleanup handled by main loop def make_callback_done_handler( @@ -256,6 +301,7 @@ def make_callback_done_handler( request_id: str, renderer_id: str, shutdown_event: threading.Event, + registry: "ActiveCallbackRegistry | None" = None, ) -> Callable[[concurrent.futures.Future], None]: """Create a done callback handler for executor futures. @@ -268,52 +314,73 @@ def make_callback_done_handler( request_id: The request ID for the callback response renderer_id: The renderer ID for routing the response shutdown_event: Event signaling the websocket connection has closed. + registry: Optional registry for managing callback lifecycle. Returns: A callback function suitable for Future.add_done_callback() """ + def _is_shutdown() -> bool: + """Check if connection is shutdown (registry-aware).""" + if registry is not None: + return registry.is_shutdown(renderer_id) + return shutdown_event.is_set() + + def _get_queue() -> janus.Queue[str] | None: + """Get current outbound queue (may change on reconnect).""" + if registry is not None: + return registry.get_queue(renderer_id) + return outbound_queue + def on_done(f: concurrent.futures.Future) -> None: try: - if shutdown_event.is_set(): + if _is_shutdown(): return result = f.result() - outbound_queue.sync_q.put_nowait( - cast( - str, - to_json( - { - "type": "callback_response", - "rendererId": renderer_id, - "requestId": request_id, - "payload": result, - } - ), + current_queue = _get_queue() + if current_queue is not None: + current_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": result, + } + ), + ) ) - ) except Exception as e: # pylint: disable=broad-exception-caught - if shutdown_event.is_set(): + if _is_shutdown(): return - outbound_queue.sync_q.put_nowait( - cast( - str, - to_json( - { - "type": "callback_response", - "rendererId": renderer_id, - "requestId": request_id, - "payload": { - "status": "error", - "message": str(e), - }, - } - ), + current_queue = _get_queue() + if current_queue is not None: + current_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": { + "status": "error", + "message": str(e), + }, + } + ), + ) ) - ) finally: pending_callbacks.pop(request_id, None) - if not shutdown_event.is_set(): - outbound_queue.sync_q.put_nowait(FLUSH_SIGNAL) + if registry is not None: + registry.unregister_callback(renderer_id) + if not _is_shutdown(): + current_queue = _get_queue() + if current_queue is not None: + current_queue.sync_q.put_nowait(FLUSH_SIGNAL) return on_done diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts index daa3238773..24dc5a39d4 100644 --- a/dash/dash-renderer/src/observers/websocketObserver.ts +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -18,6 +18,8 @@ import { GetPropsRequestPayload } from '../utils/workerClient'; import {DashConfig} from '../config'; +import {addRequestedCallbacks} from '../actions/callbacks'; +import {makeResolvedCallback, resolveDeps} from '../actions/dependencies_ts'; /** * Parse a component ID that may be a stringified JSON object. @@ -175,13 +177,53 @@ export async function initializeWebSocket( workerClient.sendGetPropsResponse(requestId, result); }; + // Track connection state for reconnection handling + let wasDisconnected = false; + // Handle connection events workerClient.onConnected = () => { console.log('[Dash] WebSocket connected'); + + // On reconnect (not initial connect), re-trigger persistent callbacks + if (wasDisconnected) { + console.log( + '[Dash] Reconnected - re-triggering persistent callbacks' + ); + const state = store.getState(); + const {graphs} = state; + + if (graphs?.callbacks) { + const persistentCallbacks = graphs.callbacks.reduce( + (acc: any[], cb: any) => { + // Only re-trigger no-output callbacks with no inputs + // These are the "persistent" callbacks that should restart + if (cb.noOutput && cb.inputs.length === 0) { + const resolved = makeResolvedCallback( + cb, + resolveDeps(), + '' + ); + resolved.initialCall = true; + acc.push(resolved); + } + return acc; + }, + [] + ); + + if (persistentCallbacks.length > 0) { + console.log( + `[Dash] Re-triggering ${persistentCallbacks.length} persistent callback(s)` + ); + store.dispatch(addRequestedCallbacks(persistentCallbacks)); + } + } + } }; workerClient.onDisconnected = (reason?: string) => { console.log(`[Dash] WebSocket disconnected: ${reason}`); + wasDisconnected = true; }; workerClient.onError = (message: string, code?: string) => { @@ -201,6 +243,24 @@ export async function initializeWebSocket( } catch (error) { console.error('[Dash] Failed to connect to WebSocket worker:', error); } + + // Handle tab visibility changes + document.addEventListener('visibilitychange', () => { + if (document.visibilityState === 'visible') { + if (workerClient.connected) { + // Tab visible and connected - reset inactivity timer + workerClient.notifyTabVisible(); + } else { + // Tab visible but disconnected - reconnect + console.log('[Dash] Tab visible, reconnecting WebSocket...'); + workerClient + .ensureConnected(config) + .catch(err => + console.error('[Dash] Failed to reconnect:', err) + ); + } + } + }); } /** diff --git a/dash/dash-renderer/src/utils/workerClient.ts b/dash/dash-renderer/src/utils/workerClient.ts index 01584bf20c..ebfb1223eb 100644 --- a/dash/dash-renderer/src/utils/workerClient.ts +++ b/dash/dash-renderer/src/utils/workerClient.ts @@ -10,6 +10,7 @@ export enum WorkerMessageType { DISCONNECT = 'disconnect', CALLBACK_REQUEST = 'callback_request', GET_PROPS_RESPONSE = 'get_props_response', + TAB_VISIBLE = 'tab_visible', CONNECTED = 'connected', DISCONNECTED = 'disconnected', CALLBACK_RESPONSE = 'callback_response', @@ -251,6 +252,19 @@ class WorkerClient { return this.isConnected; } + /** + * Notify the worker that the tab is now visible. + * This resets the inactivity timer to prevent timeout while user is viewing. + */ + public notifyTabVisible(): void { + if (this.worker && this.isConnected) { + this.worker.port.postMessage({ + type: WorkerMessageType.TAB_VISIBLE, + rendererId: this.rendererId + }); + } + } + private handleMessage(event: MessageEvent): void { const message = event.data; diff --git a/tests/websocket/test_ws_reconnect.py b/tests/websocket/test_ws_reconnect.py new file mode 100644 index 0000000000..64826ab1e6 --- /dev/null +++ b/tests/websocket/test_ws_reconnect.py @@ -0,0 +1,449 @@ +""" +WebSocket reconnection and disconnect handling tests. + +Tests: +- Callback continuity after WebSocket reconnection +- Registry tracks active callbacks correctly +- Disconnect handling doesn't cause error spam +- Long-running callbacks survive reconnection +""" + +import asyncio +import time +import threading + +from dash import Dash, html, Input, Output, set_props +from dash.backends._ws_registry import ActiveCallbackRegistry + + +class TestActiveCallbackRegistry: + """Unit tests for the ActiveCallbackRegistry class.""" + + def test_registry_adopt_creates_entry(self): + """Test that adopt_connection creates a new registry entry.""" + registry = ActiveCallbackRegistry() + + # Mock queue-like object + class MockQueue: + def __init__(self): + self.sync_q = None + + outbound_queue = MockQueue() + pending_get_props = {} + shutdown_event = threading.Event() + + registry.adopt_connection( + "renderer1", outbound_queue, pending_get_props, shutdown_event + ) + + assert registry.get_queue("renderer1") == outbound_queue + assert registry.get_pending_get_props("renderer1") == pending_get_props + assert not registry.is_shutdown("renderer1") + + def test_registry_callback_lifecycle(self): + """Test register/unregister callback with cleanup.""" + registry = ActiveCallbackRegistry() + + class MockQueue: + def __init__(self): + self.sync_q = None + + outbound_queue = MockQueue() + shutdown_event = threading.Event() + + registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) + + # Register callback + registry.register_callback("renderer1") + assert not registry.is_shutdown("renderer1") + + # Unregister - should clean up entry since count becomes 0 + registry.unregister_callback("renderer1") + assert registry.is_shutdown("renderer1") # Returns True when not found + + def test_registry_multiple_callbacks(self): + """Test that multiple callbacks keep entry alive.""" + registry = ActiveCallbackRegistry() + + class MockQueue: + def __init__(self): + self.sync_q = None + + outbound_queue = MockQueue() + shutdown_event = threading.Event() + + registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) + + # Register two callbacks + registry.register_callback("renderer1") + registry.register_callback("renderer1") + + # Unregister one - entry should still exist + registry.unregister_callback("renderer1") + assert not registry.is_shutdown("renderer1") + + # Unregister second - now should be cleaned up + registry.unregister_callback("renderer1") + assert registry.is_shutdown("renderer1") + + def test_registry_adopt_after_cleanup(self): + """Test that adopt_connection works after cleanup.""" + registry = ActiveCallbackRegistry() + + class MockQueue: + def __init__(self): + self.sync_q = None + + outbound_queue = MockQueue() + shutdown_event = threading.Event() + + # First connection + registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) + registry.register_callback("renderer1") + registry.unregister_callback("renderer1") # Cleans up + + # Re-adopt after cleanup + registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) + assert not registry.is_shutdown("renderer1") + + def test_registry_adopt_updates_existing(self): + """Test that adopt_connection updates queues for existing entry.""" + registry = ActiveCallbackRegistry() + + class MockQueue: + def __init__(self, name): + self.name = name + self.sync_q = None + + old_queue = MockQueue("old") + new_queue = MockQueue("new") + old_shutdown = threading.Event() + new_shutdown = threading.Event() + + registry.adopt_connection("renderer1", old_queue, {}, old_shutdown) + registry.register_callback("renderer1") # Keep entry alive + + assert registry.get_queue("renderer1").name == "old" + + # Simulate reconnection + registry.adopt_connection("renderer1", new_queue, {}, new_shutdown) + + assert registry.get_queue("renderer1").name == "new" + + def test_registry_shutdown_event_respected(self): + """Test that shutdown event is checked correctly.""" + registry = ActiveCallbackRegistry() + + class MockQueue: + def __init__(self): + self.sync_q = None + + outbound_queue = MockQueue() + shutdown_event = threading.Event() + + registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) + + assert not registry.is_shutdown("renderer1") + + shutdown_event.set() + + assert registry.is_shutdown("renderer1") + + def test_registry_unknown_renderer_is_shutdown(self): + """Test that unknown renderer IDs report as shutdown.""" + registry = ActiveCallbackRegistry() + + assert registry.is_shutdown("unknown_renderer") + assert registry.get_queue("unknown_renderer") is None + assert registry.get_pending_get_props("unknown_renderer") is None + + +def test_ws030_multiple_callbacks_same_connection(dash_duo): + """Test multiple sequential callbacks on the same WebSocket connection.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn", n_clicks=0), + html.Div("0", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return str(n_clicks or 0) + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "0") + + # Multiple clicks - each should work via the same connection + for i in range(1, 6): + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", str(i)) + + assert dash_duo.get_logs() == [] + + +def test_ws031_rapid_callbacks_registry_handling(dash_duo): + """Test that rapid callbacks are handled correctly by registry.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Rapid Click", id="btn", n_clicks=0), + html.Div("0", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return str(n_clicks or 0) + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "0") + + # Rapid clicks without waiting + for _ in range(10): + dash_duo.find_element("#btn").click() + time.sleep(0.05) # 50ms between clicks + + # Should eventually reach 10 + dash_duo.wait_for_text_to_equal("#output", "10", timeout=10) + + assert dash_duo.get_logs() == [] + + +def test_ws032_long_callback_with_set_props(dash_duo): + """Test long-running callback with intermediate set_props updates.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Start", id="btn", n_clicks=0), + html.Div("ready", id="status"), + html.Div("0", id="progress"), + ] + ) + + @app.callback( + Output("status", "children"), + Input("btn", "n_clicks"), + prevent_initial_call=True, + ) + async def long_task(n_clicks): + set_props("status", {"children": "running"}) + + # Simulate progress updates + for i in range(1, 6): + set_props("progress", {"children": str(i * 20)}) + await asyncio.sleep(0.1) + + return "done" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#status", "ready") + + dash_duo.find_element("#btn").click() + + # Should see intermediate updates + dash_duo.wait_for_text_to_equal("#status", "done", timeout=10) + dash_duo.wait_for_text_to_equal("#progress", "100") + + assert dash_duo.get_logs() == [] + + +def test_ws033_callback_after_reconnect(dash_duo): + """Test that callbacks work after WebSocket reconnection.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=2000, # 2 seconds + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn", n_clicks=0), + html.Div("0", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return str(n_clicks or 0) + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "0") + + # First click + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "1") + + # Wait for connection to timeout + time.sleep(3) + + # Click after reconnection - should still work + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "2") + + # Multiple clicks after reconnection + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "3") + + assert dash_duo.get_logs() == [] + + +def test_ws034_concurrent_callbacks(dash_duo): + """Test multiple concurrent callbacks from different inputs.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Button A", id="btn-a", n_clicks=0), + html.Button("Button B", id="btn-b", n_clicks=0), + html.Div("a:0", id="output-a"), + html.Div("b:0", id="output-b"), + ] + ) + + @app.callback(Output("output-a", "children"), Input("btn-a", "n_clicks")) + async def on_click_a(n_clicks): + await asyncio.sleep(0.1) # Small delay to ensure overlap + return f"a:{n_clicks or 0}" + + @app.callback(Output("output-b", "children"), Input("btn-b", "n_clicks")) + async def on_click_b(n_clicks): + await asyncio.sleep(0.1) + return f"b:{n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output-a", "a:0") + dash_duo.wait_for_text_to_equal("#output-b", "b:0") + + # Click both buttons rapidly + dash_duo.find_element("#btn-a").click() + dash_duo.find_element("#btn-b").click() + + dash_duo.wait_for_text_to_equal("#output-a", "a:1") + dash_duo.wait_for_text_to_equal("#output-b", "b:1") + + # More concurrent clicks + dash_duo.find_element("#btn-a").click() + dash_duo.find_element("#btn-b").click() + dash_duo.find_element("#btn-a").click() + + dash_duo.wait_for_text_to_equal("#output-a", "a:3") + dash_duo.wait_for_text_to_equal("#output-b", "b:2") + + assert dash_duo.get_logs() == [] + + +def test_ws035_callback_survives_inactivity_timeout(dash_duo): + """Test that long callback completes even when inactivity timeout triggers mid-execution. + + This is the key test for Issue #3788: when a callback runs longer than the + inactivity timeout without sending updates, the WebSocket disconnects and + reconnects. The callback should still complete and send its result via the + new connection. + """ + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=2000, # 2 seconds + ) + + app.layout = html.Div( + [ + html.Button("Start", id="btn", n_clicks=0), + html.Div("ready", id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn", "n_clicks"), + prevent_initial_call=True, + ) + async def silent_long_task(n_clicks): + # Wait longer than inactivity timeout WITHOUT sending any updates + # This will trigger WebSocket disconnect/reconnect mid-callback + await asyncio.sleep(5) + return f"completed:{n_clicks}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "ready") + + # Start the long task + dash_duo.find_element("#btn").click() + + # Should complete despite inactivity timeout triggering during execution + dash_duo.wait_for_text_to_equal("#output", "completed:1", timeout=15) + + # Verify subsequent callbacks also work + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "completed:2", timeout=15) + + assert dash_duo.get_logs() == [] + + +def test_ws036_set_props_after_reconnect(dash_duo): + """Test that set_props works after WebSocket reconnects mid-callback. + + This tests the registry's ability to adopt new queues so that + set_props calls use the new connection after reconnection. + """ + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=2000, # 2 seconds + ) + + app.layout = html.Div( + [ + html.Button("Start", id="btn", n_clicks=0), + html.Div("ready", id="status"), + html.Div("0", id="progress"), + ] + ) + + @app.callback( + Output("status", "children"), + Input("btn", "n_clicks"), + prevent_initial_call=True, + ) + async def task_with_late_set_props(n_clicks): + set_props("status", {"children": "started"}) + set_props("progress", {"children": "10"}) + + # Wait long enough for inactivity timeout to trigger + await asyncio.sleep(5) + + # These set_props calls happen AFTER reconnection + # They should still work via the adopted queue + set_props("progress", {"children": "100"}) + + return "done" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#status", "ready") + + dash_duo.find_element("#btn").click() + + # Should see initial updates + dash_duo.wait_for_text_to_equal("#status", "started", timeout=5) + dash_duo.wait_for_text_to_equal("#progress", "10", timeout=5) + + # Should see final update after reconnection + dash_duo.wait_for_text_to_equal("#progress", "100", timeout=15) + dash_duo.wait_for_text_to_equal("#status", "done", timeout=5) + + assert dash_duo.get_logs() == [] From 21278e0ffd2a05d6c9586427d33fc40bf832a261 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 26 May 2026 17:05:55 -0400 Subject: [PATCH 2/4] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62d49513c5..a262be18b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#3669](https://github.com/plotly/dash/pull/3669) Selection for DataTable cleared with custom action settings - [#3680](https://github.com/plotly/dash/pull/3680) Added `search_order` prop to `Dropdown` to allow users to preserve original option order during search - Added `csrf_token_name` and `csrf_header_name` config options to allow configuring the CSRF cookie and header names. Fixes [#729](https://github.com/plotly/dash/issues/729) +- [#3797](https://github.com/plotly/dash/pull/3797) Improved websocket callback management. ## Added - [#3523](https://github.com/plotly/dash/pull/3523) Fall back to background callback function names if source cannot be found From 6f6442f0b185f32d876900b4d55f8ce89e06c62a Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 26 May 2026 18:52:22 -0400 Subject: [PATCH 3/4] lint fix --- dash/backends/ws.py | 81 ++++++++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 27 deletions(-) diff --git a/dash/backends/ws.py b/dash/backends/ws.py index adebdf2c4a..807a167ab7 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -242,35 +242,62 @@ async def run_ws_sender( messages: list[str] = [] try: while True: - # Wait indefinitely for first message, then use timeout for batching - timeout = batch_delay if messages else None - try: - msg = await asyncio.wait_for(q.get(), timeout=timeout) - if msg == SHUTDOWN_SIGNAL: - if messages: - await _send_batched(send_text, messages) - return - if msg == FLUSH_SIGNAL: - if messages: - if not await _send_batched(send_text, messages): - return # Connection closed - messages = [] - continue - if not batch_delay: - try: - await send_text(msg) - except Exception: # WebSocketDisconnect, RuntimeError, etc. - return # Connection closed - else: - messages.append(msg) - except asyncio.TimeoutError: - if not await _send_batched(send_text, messages): - return # Connection closed - messages = [] + result = await _process_ws_message(q, send_text, messages, batch_delay) + if result is False: + return except asyncio.CancelledError: pass +async def _process_ws_message( + q: "janus._AsyncQueueProxy[str]", + send_text: Callable[[str], Any], + messages: list[str], + batch_delay: float, +) -> bool | None: + """Process a single WebSocket message from the queue. + + Args: + q: The async queue to read from + send_text: Async function to send text data over WebSocket + messages: List to accumulate messages for batching (mutated in place) + batch_delay: Batch delay in seconds + + Returns: + True to continue processing, False to stop the sender loop, + None to continue (same as True but used for continue semantics). + """ + timeout = batch_delay if messages else None + try: + msg = await asyncio.wait_for(q.get(), timeout=timeout) + except asyncio.TimeoutError: + if not await _send_batched(send_text, messages): + return False + messages.clear() + return True + + if msg == SHUTDOWN_SIGNAL: + if messages: + await _send_batched(send_text, messages) + return False + + if msg == FLUSH_SIGNAL: + if messages and not await _send_batched(send_text, messages): + return False + messages.clear() + return None + + if not batch_delay: + try: + await send_text(msg) + except Exception: # pylint: disable=broad-exception-caught + return False # WebSocketDisconnect, RuntimeError, etc. + else: + messages.append(msg) + + return True + + async def _send_batched(send_text: Callable[[str], Any], messages: list) -> bool: """Send messages as a batch. @@ -291,8 +318,8 @@ async def _send_batched(send_text: Callable[[str], Any], messages: list) -> bool # Wrap in array: "[msg1,msg2,msg3]" await send_text("[" + ",".join(messages) + "]") return True - except Exception: # WebSocketDisconnect, RuntimeError, etc. - return False # Connection closed, cleanup handled by main loop + except Exception: # pylint: disable=broad-exception-caught + return False # WebSocketDisconnect, RuntimeError, etc. def make_callback_done_handler( From 2f42201532cb2d5892197f18f0703d6f42c4c2a6 Mon Sep 17 00:00:00 2001 From: philippe Date: Wed, 27 May 2026 16:39:25 -0400 Subject: [PATCH 4/4] fix lint --- dash/backends/ws.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/dash/backends/ws.py b/dash/backends/ws.py index 807a167ab7..8ee2cf11c1 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -254,7 +254,7 @@ async def _process_ws_message( send_text: Callable[[str], Any], messages: list[str], batch_delay: float, -) -> bool | None: +) -> bool: """Process a single WebSocket message from the queue. Args: @@ -264,17 +264,15 @@ async def _process_ws_message( batch_delay: Batch delay in seconds Returns: - True to continue processing, False to stop the sender loop, - None to continue (same as True but used for continue semantics). + True to continue processing, False to stop the sender loop. """ timeout = batch_delay if messages else None try: msg = await asyncio.wait_for(q.get(), timeout=timeout) except asyncio.TimeoutError: - if not await _send_batched(send_text, messages): - return False + success = await _send_batched(send_text, messages) messages.clear() - return True + return success if msg == SHUTDOWN_SIGNAL: if messages: @@ -282,10 +280,9 @@ async def _process_ws_message( return False if msg == FLUSH_SIGNAL: - if messages and not await _send_batched(send_text, messages): - return False + success = not messages or await _send_batched(send_text, messages) messages.clear() - return None + return success if not batch_delay: try: