From 8a76f8c56582906ec4778bf0cb7c456b6d125286 Mon Sep 17 00:00:00 2001 From: philippe Date: Wed, 17 Jun 2026 16:26:33 -0400 Subject: [PATCH] Add per-connection threadpool websocket callback executor. --- CHANGELOG.md | 3 ++ dash/backends/_fastapi.py | 10 ++++-- dash/backends/_quart.py | 10 ++++-- dash/backends/base_server.py | 28 ++++++--------- dash/dash.py | 2 ++ tests/unit/test_websocket_executor.py | 51 +++++++++++++++++++++++++++ 6 files changed, 82 insertions(+), 22 deletions(-) create mode 100644 tests/unit/test_websocket_executor.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e18cb41c78..50e6614b4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ This project adheres to [Semantic Versioning](https://semver.org/). ## [UNRELEASED] +### Added +- Per-connection WebSocket callback thread pools. Each WebSocket connection now gets its own `ThreadPoolExecutor` instead of sharing a single app-wide pool, so long-lived (session-persistent) callbacks on one connection no longer limit the number of concurrent users. The per-connection size is configurable via the new `websocket_max_workers` argument to `Dash` (default `4`). + ### Fixed - [#3805](https://github.com/plotly/dash/pull/3805) Fix FastAPI POST routes deadlock caused by middleware consuming request body. Fixes [#3801](https://github.com/plotly/dash/issues/3801). - [#3813](https://github.com/plotly/dash/pull/3813) Fix websockets using incorrect path when deployed behind a proxy diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 97dce1379a..61235b6695 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -731,8 +731,12 @@ async def websocket_handler(websocket: WebSocket): pending_get_props: Dict[str, queue.Queue] = {} # Shutdown event to signal connection closure to worker threads shutdown_event = threading.Event() - # Get thread pool executor - executor = self.get_callback_executor() + # Create a per-connection thread pool executor so that long-lived + # callbacks on one connection cannot starve worker threads for others. + # pylint: disable=protected-access + executor = self.create_callback_executor( + getattr(dash_app, "_websocket_max_workers", 4) + ) # Track pending callback futures pending_callbacks: Dict[str, concurrent.futures.Future] = {} @@ -833,6 +837,8 @@ async def websocket_handler(websocket: WebSocket): # Cancel any pending futures for f in pending_callbacks.values(): f.cancel() + # Shut down this connection's executor (don't block the event loop) + executor.shutdown(wait=False) self.server.add_api_websocket_route(ws_path, websocket_handler) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index a6d09d1e1c..0cc8772a76 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -559,8 +559,12 @@ async def websocket_handler(): # pylint: disable=too-many-branches pending_get_props: Dict[str, queue.Queue] = {} # Shutdown event to signal connection closure to worker threads connection_shutdown_event = threading.Event() - # Get thread pool executor - executor = self.get_callback_executor() + # Create a per-connection thread pool executor so that long-lived + # callbacks on one connection cannot starve worker threads for others. + # pylint: disable=protected-access + executor = self.create_callback_executor( + getattr(dash_app, "_websocket_max_workers", 4) + ) # Track pending callback futures pending_callbacks: Dict[str, concurrent.futures.Future] = {} @@ -671,6 +675,8 @@ async def websocket_handler(): # pylint: disable=too-many-branches # Cancel any pending futures for f in pending_callbacks.values(): f.cancel() + # Shut down this connection's executor (don't block the event loop) + executor.shutdown(wait=False) class QuartRequestAdapter(RequestAdapter): diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 52443d4104..ed06663c14 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -189,12 +189,16 @@ def __init__(self, server: ServerType) -> None: """ super().__init__() self.server = server - self._callback_executor: ThreadPoolExecutor | None = None - def get_callback_executor( + def create_callback_executor( self, max_workers: int | None = None ) -> ThreadPoolExecutor: - """Get or create the thread pool executor for callback execution. + """Create a new thread pool executor for callback execution. + + A fresh executor is created per WebSocket connection so that long-lived + (session-persistent) callbacks on one connection cannot exhaust worker + threads shared with other connections. The executor should be shut down + when its connection closes. Args: max_workers: Maximum number of worker threads. If None, uses default. @@ -202,21 +206,9 @@ def get_callback_executor( Returns: ThreadPoolExecutor instance for running callbacks. """ - if self._callback_executor is None: - self._callback_executor = ThreadPoolExecutor( - max_workers=max_workers, thread_name_prefix="dash-callback-" - ) - return self._callback_executor - - def shutdown_executor(self, wait: bool = True) -> None: - """Shutdown the callback executor. - - Args: - wait: If True, wait for pending tasks to complete. - """ - if self._callback_executor is not None: - self._callback_executor.shutdown(wait=wait) - self._callback_executor = None + return ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix="dash-callback-" + ) def __call__(self, *args, **kwargs) -> Any: """Make the server wrapper callable as a WSGI/ASGI application. diff --git a/dash/dash.py b/dash/dash.py index f547b95b56..05f7700b4c 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -490,6 +490,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches websocket_inactivity_timeout: Optional[int] = 300000, websocket_heartbeat_interval: Optional[int] = 30000, websocket_batch_delay: Optional[float] = 0.005, + websocket_max_workers: Optional[int] = 4, **obsolete, ): @@ -651,6 +652,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._websocket_inactivity_timeout = websocket_inactivity_timeout self._websocket_heartbeat_interval = websocket_heartbeat_interval self._websocket_batch_delay = websocket_batch_delay + self._websocket_max_workers = websocket_max_workers self.logger = logging.getLogger(__name__) diff --git a/tests/unit/test_websocket_executor.py b/tests/unit/test_websocket_executor.py new file mode 100644 index 0000000000..b693469d1d --- /dev/null +++ b/tests/unit/test_websocket_executor.py @@ -0,0 +1,51 @@ +"""Unit tests for the per-connection WebSocket callback thread pool. + +These verify that each WebSocket connection gets its own ThreadPoolExecutor +(rather than a single shared, app-wide pool), so that long-lived +(session-persistent) callbacks on one connection cannot exhaust worker threads +shared with other connections, and that the per-connection size is configurable +via the ``websocket_max_workers`` argument to ``Dash``. +""" + +from concurrent.futures import ThreadPoolExecutor + +from dash import Dash + + +def test_websocket_max_workers_default(): + """websocket_max_workers defaults to 4.""" + app = Dash(__name__) + assert app._websocket_max_workers == 4 + + +def test_websocket_max_workers_custom(): + """websocket_max_workers is stored when provided.""" + app = Dash(__name__, websocket_max_workers=16) + assert app._websocket_max_workers == 16 + + +def test_create_callback_executor_is_per_connection(): + """Each call returns a fresh executor, not a cached shared one.""" + backend = Dash(__name__).backend + + ex1 = backend.create_callback_executor(4) + ex2 = backend.create_callback_executor(4) + try: + assert isinstance(ex1, ThreadPoolExecutor) + assert isinstance(ex2, ThreadPoolExecutor) + # Distinct instances => one connection's pool can't starve another's. + assert ex1 is not ex2 + finally: + ex1.shutdown(wait=False) + ex2.shutdown(wait=False) + + +def test_create_callback_executor_honors_max_workers(): + """max_workers is forwarded to the ThreadPoolExecutor.""" + backend = Dash(__name__).backend + + ex = backend.create_callback_executor(7) + try: + assert ex._max_workers == 7 + finally: + ex.shutdown(wait=False)