diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 0b27f43692..867e9a25b0 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -28,10 +28,11 @@ ) from _err from dash.fingerprint import check_fingerprint -from dash import _validate +from dash import _validate, get_app from dash.exceptions import PreventUpdate from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter from ._utils import format_traceback_html +import traceback if TYPE_CHECKING: # pragma: no cover - typing only from dash import Dash @@ -122,8 +123,12 @@ async def _initialize_dev_tools(self) -> None: self.dash_app.enable_dev_tools(**config, first_run=False) self._dev_tools_initialized = True - def _setup_timing(self, request: Request) -> None: + async def _setup_timing(self, request: Request) -> None: """Set up timing information for the request.""" + try: + request.state.json_body = await request.json() if request.headers.get("content-type", "").startswith("application/json") else None + except: + request.state.json_body = None if self.enable_timing: request.state.timing_information = { "__dash_server": {"dur": time.time(), "desc": None} @@ -179,6 +184,12 @@ async def _handle_error( async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # Handle lifespan events (startup/shutdown) if scope["type"] == "lifespan": + try: + dash_app = get_app() + dash_app.backend._setup_catchall() + except: + print("Error during catch-all setup:") + print(traceback.format_exc()) await self._initialize_dev_tools() await self.app(scope, receive, send) return @@ -193,7 +204,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: token = set_current_request(request) try: - self._setup_timing(request) + await self._setup_timing(request) await self._run_before_hooks() await self.app(scope, receive, send) @@ -275,11 +286,24 @@ async def index(_request: Request): dash_app._add_url("", index, methods=["GET"]) def setup_catchall(self, dash_app: Dash): - async def catchall(_request: Request): - return Response(content=dash_app.index(), media_type="text/html") + '''This is needed to ensure that all routes are handled by FastAPI + and passed through the middleware, which is necessary for features like authentication + and timing to work correctly on all routes. FastAPI will match this catch-all route + for any path that isn't matched by a more specific route, allowing the middleware to + process the request and then return the appropriate response (e.g., 404 if no Dash route matches).''' - # pylint: disable=protected-access - dash_app._add_url("{path:path}", catchall, methods=["GET"]) + + def _setup_catchall(self): + try: + print("Setting up catch-all route for unmatched paths") + dash_app = get_app() + async def catchall(_request: Request): + return Response(content=dash_app.index(), media_type="text/html") + + # pylint: disable=protected-access + self.add_url_rule("{path:path}", catchall, methods=["GET"]) + except: + print(traceback.format_exc()) def add_url_rule( self, @@ -289,6 +313,7 @@ def add_url_rule( methods: list[str] | None = None, include_in_schema: bool = False, ): + print(f"Adding URL rule: {rule} -> {view_func} (endpoint: {endpoint}, methods: {methods})") if rule == "": rule = "/" if isinstance(view_func, str): @@ -481,7 +506,7 @@ def add_redirect_rule(self, app, fullname, path): def serve_callback(self, dash_app: Dash): async def _dispatch(request: Request): # pylint: disable=protected-access - body = await request.json() + body = self.request_adapter().get_json() cb_ctx = dash_app._initialize_context( body ) # pylint: disable=protected-access @@ -641,5 +666,13 @@ def origin(self): def path(self): return self._request.url.path + async def _get_json(self, request: Request=None): + req = self._request + if not hasattr(req.state, "json_body"): + req.state.json_body = await request.json() + return req.state.json_body + def get_json(self): - return asyncio.run(self._request.json()) + if not hasattr(self, "_request") or self._request is None: + self._request = get_current_request() + return self._request.state.json_body