Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions @plotly/dash-websocket-worker/src/WebSocketManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ export class WebSocketManager {
return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
}

/**
* Reset the activity timer.
* Call this when a tab becomes visible to prevent inactivity timeout.
*/
public resetActivity(): void {
this.lastActivityTime = Date.now();
}

private createConnection(): void {
if (!this.serverUrl) {
return;
Expand Down
1 change: 1 addition & 0 deletions @plotly/dash-websocket-worker/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export enum WorkerMessageType {
DISCONNECT = 'disconnect',
CALLBACK_REQUEST = 'callback_request',
GET_PROPS_RESPONSE = 'get_props_response',
TAB_VISIBLE = 'tab_visible',

// Worker -> Renderer
CONNECTED = 'connected',
Expand Down
7 changes: 7 additions & 0 deletions @plotly/dash-websocket-worker/src/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ self.onconnect = (event: MessageEvent) => {
break;
}

case WorkerMessageType.TAB_VISIBLE: {
// Reset activity timer when tab becomes visible
// This prevents inactivity timeout while user is viewing the tab
wsManager.resetActivity();
break;
}

default:
// Forward other messages through the router
router.handleRendererMessage(message.rendererId, message);
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ This project adheres to [Semantic Versioning](https://semver.org/).
- [#3669](https://github.com/plotly/dash/pull/3669) Selection for DataTable cleared with custom action settings
- [#3680](https://github.com/plotly/dash/pull/3680) Added `search_order` prop to `Dropdown` to allow users to preserve original option order during search
- Added `csrf_token_name` and `csrf_header_name` config options to allow configuring the CSRF cookie and header names. Fixes [#729](https://github.com/plotly/dash/issues/729)
- [#3797](https://github.com/plotly/dash/pull/3797) Improved websocket callback management.

## Added
- [#3523](https://github.com/plotly/dash/pull/3523) Fall back to background callback function names if source cannot be found
Expand Down
32 changes: 31 additions & 1 deletion dash/backends/_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
SHUTDOWN_SIGNAL,
DISCONNECTED,
)
from ._ws_registry import ActiveCallbackRegistry
from ._utils import format_traceback_html

if TYPE_CHECKING: # pragma: no cover - typing only
Expand Down Expand Up @@ -677,6 +678,12 @@ def serve_websocket_callback(self, dash_app: "Dash"):
dash_app, "_websocket_allowed_origins", []
) # pylint: disable=protected-access

# Initialize registry on dash_app if not present
# pylint: disable=protected-access
if not hasattr(dash_app, "_ws_callback_registry"):
dash_app._ws_callback_registry = ActiveCallbackRegistry()
registry: ActiveCallbackRegistry = dash_app._ws_callback_registry

def validate_origin(origin: str | None, host: str | None) -> str | None:
"""Validate WebSocket origin. Returns error message or None if valid."""
if not origin:
Expand Down Expand Up @@ -723,6 +730,8 @@ async def websocket_handler(websocket: WebSocket):
executor = self.get_callback_executor()
# Track pending callback futures
pending_callbacks: Dict[str, concurrent.futures.Future] = {}
# Track current renderer ID for this connection
current_renderer_id: str | None = None

# Start sender task to drain outbound queue (sends pre-serialized text)
# pylint: disable=protected-access
Expand Down Expand Up @@ -753,6 +762,22 @@ async def websocket_handler(websocket: WebSocket):
renderer_id = message.get("rendererId", "")
payload = message.get("payload", {})

# Update current renderer ID for cleanup
current_renderer_id = renderer_id

# Adopt connection for this renderer (allows reconnection)
# Called for every callback to ensure registry entry exists
# (entry may have been cleaned up after previous callback)
registry.adopt_connection(
renderer_id,
outbound_queue,
pending_get_props,
shutdown_event,
)

# Register this callback with the registry
registry.register_callback(renderer_id)

# Validate that the callback is allowed to use WebSocket transport
# pylint: disable=protected-access
_validate.validate_websocket_callback_request(
Expand All @@ -761,12 +786,13 @@ async def websocket_handler(websocket: WebSocket):
dash_app._websocket_callbacks,
)

# Create WebSocket callback instance with outbound queue
# Create WebSocket callback instance with registry
ws_cb = DashWebsocketCallback(
pending_get_props,
renderer_id,
outbound_queue,
shutdown_event,
registry=registry,
)

# Submit callback to executor
Expand All @@ -786,6 +812,7 @@ async def websocket_handler(websocket: WebSocket):
request_id,
renderer_id,
shutdown_event,
registry=registry,
)
)
pending_callbacks[request_id] = future
Expand Down Expand Up @@ -821,6 +848,9 @@ async def websocket_handler(websocket: WebSocket):
# Cancel any pending futures
for f in pending_callbacks.values():
f.cancel()
# Cleanup registry entry if no active callbacks
if current_renderer_id is not None:
registry.cleanup_renderer(current_renderer_id)

self.server.add_api_websocket_route(ws_path, websocket_handler)

Expand Down
32 changes: 31 additions & 1 deletion dash/backends/_quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
SHUTDOWN_SIGNAL,
DISCONNECTED,
)
from ._ws_registry import ActiveCallbackRegistry
from ._utils import format_traceback_html

if TYPE_CHECKING:
Expand Down Expand Up @@ -521,6 +522,12 @@ def serve_websocket_callback(self, dash_app: "Dash"):
# pylint: disable=protected-access
allowed_origins = getattr(dash_app, "_websocket_allowed_origins", [])

# Initialize registry on dash_app if not present
# pylint: disable=protected-access
if not hasattr(dash_app, "_ws_callback_registry"):
dash_app._ws_callback_registry = ActiveCallbackRegistry()
registry: ActiveCallbackRegistry = dash_app._ws_callback_registry

@self.server.websocket(ws_path)
async def websocket_handler(): # pylint: disable=too-many-branches
ws = websocket
Expand Down Expand Up @@ -564,6 +571,8 @@ async def websocket_handler(): # pylint: disable=too-many-branches
executor = self.get_callback_executor()
# Track pending callback futures
pending_callbacks: Dict[str, concurrent.futures.Future] = {}
# Track current renderer ID for this connection
current_renderer_id: str | None = None

# Start sender task to drain outbound queue (sends pre-serialized text)
# pylint: disable=protected-access
Expand Down Expand Up @@ -601,6 +610,22 @@ async def websocket_handler(): # pylint: disable=too-many-branches
renderer_id = message.get("rendererId", "")
payload = message.get("payload", {})

# Update current renderer ID for cleanup
current_renderer_id = renderer_id

# Adopt connection for this renderer (allows reconnection)
# Called for every callback to ensure registry entry exists
# (entry may have been cleaned up after previous callback)
registry.adopt_connection(
renderer_id,
outbound_queue,
pending_get_props,
connection_shutdown_event,
)

# Register this callback with the registry
registry.register_callback(renderer_id)

# Validate that the callback is allowed to use WebSocket transport
# pylint: disable=protected-access
_validate.validate_websocket_callback_request(
Expand All @@ -609,12 +634,13 @@ async def websocket_handler(): # pylint: disable=too-many-branches
dash_app._websocket_callbacks,
)

# Create WebSocket callback instance with outbound queue
# Create WebSocket callback instance with registry
ws_cb = DashWebsocketCallback(
pending_get_props,
renderer_id,
outbound_queue,
connection_shutdown_event,
registry=registry,
)

# Submit callback to executor
Expand All @@ -634,6 +660,7 @@ async def websocket_handler(): # pylint: disable=too-many-branches
request_id,
renderer_id,
connection_shutdown_event,
registry=registry,
)
)
pending_callbacks[request_id] = future
Expand Down Expand Up @@ -672,6 +699,9 @@ async def websocket_handler(): # pylint: disable=too-many-branches
# Cancel any pending futures
for f in pending_callbacks.values():
f.cancel()
# Cleanup registry entry if no active callbacks
if current_renderer_id is not None:
registry.cleanup_renderer(current_renderer_id)


class QuartRequestAdapter(RequestAdapter):
Expand Down
166 changes: 166 additions & 0 deletions dash/backends/_ws_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""WebSocket callback registry for handling reconnections.

This module provides a registry that tracks active callbacks per renderer,
allowing callbacks to persist across WebSocket reconnections.
"""

from __future__ import annotations

import threading
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional

if TYPE_CHECKING:
import queue
import janus


@dataclass
class RendererState:
"""State for a single renderer's WebSocket connection."""

outbound_queue: "janus.Queue[str]"
pending_get_props: Dict[str, "queue.Queue[Any]"]
shutdown_event: threading.Event
active_callback_count: int = 0
lock: threading.Lock = field(default_factory=threading.Lock)


class ActiveCallbackRegistry:
"""Registry for active WebSocket callbacks that persists across reconnections.

When a WebSocket disconnects and reconnects, callbacks that are still running
can "adopt" the new connection's queues to continue sending updates.

Thread-safe for access from both the main event loop and worker threads.
"""

def __init__(self) -> None:
self._renderers: Dict[str, RendererState] = {}
self._lock = threading.Lock()

def adopt_connection(
self,
renderer_id: str,
outbound_queue: "janus.Queue[str]",
pending_get_props: Dict[str, "queue.Queue[Any]"],
shutdown_event: threading.Event,
) -> None:
"""Associate new connection with existing callbacks for this renderer.

When a WebSocket reconnects, this method updates the queues and shutdown
event so that running callbacks can use the new connection.

Args:
renderer_id: The renderer ID for this connection
outbound_queue: janus.Queue for sending messages
pending_get_props: Dict to track pending get_props requests
shutdown_event: Event signaling connection closure
"""
with self._lock:
if renderer_id in self._renderers:
state = self._renderers[renderer_id]
with state.lock:
state.outbound_queue = outbound_queue
state.pending_get_props = pending_get_props
state.shutdown_event = shutdown_event
else:
self._renderers[renderer_id] = RendererState(
outbound_queue=outbound_queue,
pending_get_props=pending_get_props,
shutdown_event=shutdown_event,
active_callback_count=0,
)

def register_callback(self, renderer_id: str) -> None:
"""Register a new active callback for this renderer.

Args:
renderer_id: The renderer ID
"""
with self._lock:
if renderer_id in self._renderers:
state = self._renderers[renderer_id]
with state.lock:
state.active_callback_count += 1

def unregister_callback(self, renderer_id: str) -> None:
"""Unregister a completed callback for this renderer.

If no active callbacks remain, the renderer state is cleaned up.

Args:
renderer_id: The renderer ID
"""
with self._lock:
if renderer_id in self._renderers:
state = self._renderers[renderer_id]
with state.lock:
state.active_callback_count -= 1
if state.active_callback_count <= 0:
del self._renderers[renderer_id]

def get_queue(self, renderer_id: str) -> Optional["janus.Queue[str]"]:
"""Get current outbound queue for renderer (thread-safe).

Args:
renderer_id: The renderer ID

Returns:
The current outbound queue, or None if renderer not found
"""
with self._lock:
state = self._renderers.get(renderer_id)
if state is None:
return None
with state.lock:
return state.outbound_queue

def get_pending_get_props(
self, renderer_id: str
) -> Optional[Dict[str, "queue.Queue[Any]"]]:
"""Get current pending_get_props dict for renderer (thread-safe).

Args:
renderer_id: The renderer ID

Returns:
The current pending_get_props dict, or None if renderer not found
"""
with self._lock:
state = self._renderers.get(renderer_id)
if state is None:
return None
with state.lock:
return state.pending_get_props

def is_shutdown(self, renderer_id: str) -> bool:
"""Check if current connection is shutdown.

Args:
renderer_id: The renderer ID

Returns:
True if shutdown event is set or renderer not found, False otherwise
"""
with self._lock:
state = self._renderers.get(renderer_id)
if state is None:
return True
with state.lock:
return state.shutdown_event.is_set()

def cleanup_renderer(self, renderer_id: str) -> None:
"""Clean up renderer state when connection closes.

Only removes if no active callbacks remain.

Args:
renderer_id: The renderer ID to clean up
"""
with self._lock:
state = self._renderers.get(renderer_id)
if state is not None:
with state.lock:
if state.active_callback_count <= 0:
del self._renderers[renderer_id]
Loading
Loading