@@ -2622,114 +2622,199 @@ async def callback_handler() -> tuple[str, str | None]:
26222622 pass
26232623
26242624
2625- class TestConcurrentRequestsDoNotDeadlock :
2626- """Regression tests for #1326.
2625+ @pytest .mark .anyio
2626+ async def test_concurrent_request_not_blocked_by_pending_long_running_request (
2627+ oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
2628+ ):
2629+ """Regression for #1326: a second request reaches its yield while the
2630+ first is still suspended (= simulating a server-side long-poll).
26272631
2628- Ensures that ``OAuthClientProvider. async_auth_flow`` does not serialize
2629- concurrent unrelated requests behind a long-running one (e.g. GET SSE
2630- long-poll). The fix narrows ``context.lock`` to state mutation only; the
2631- actual ``yield request`` runs outside any lock.
2632+ Before the lock-scope fix, `` async_auth_flow`` held ``context.lock``
2633+ across ``yield request``. A GET SSE long-poll would therefore hold the
2634+ lock for the entire SSE lifetime, blocking any concurrent request
2635+ waiting on the same provider's lock.
26322636 """
2637+ # Set up valid tokens so neither refresh (Phase 2) nor full OAuth
2638+ # flow (Phase 4) is triggered — we exercise the steady-state Phase 3
2639+ # yield path that previously held the lock.
2640+ oauth_provider .context .current_tokens = valid_tokens
2641+ oauth_provider .context .token_expiry_time = time .time () + 1800
2642+ oauth_provider .context .client_info = OAuthClientInformationFull (
2643+ client_id = "test_client_id" ,
2644+ client_secret = "test_client_secret" ,
2645+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2646+ )
2647+ oauth_provider ._initialized = True
26332648
2634- @pytest .mark .anyio
2635- async def test_concurrent_request_not_blocked_by_pending_long_running_request (
2636- self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
2637- ):
2638- """A second request must reach its yield while the first is still
2639- suspended at its yield (= simulating a server-side long-poll).
2640-
2641- Before this fix, ``async_auth_flow`` held ``context.lock`` across
2642- ``yield request``. A GET SSE long-poll would therefore hold the lock
2643- for the entire SSE lifetime, blocking any concurrent request waiting
2644- on the same provider's lock and producing a multi-second stall.
2645- """
2646- # Set up valid tokens so neither refresh (Phase 2) nor full OAuth
2647- # flow (Phase 4) is triggered — we want to exercise the steady-state
2648- # Phase 3 yield path that previously held the lock.
2649- oauth_provider .context .current_tokens = valid_tokens
2650- oauth_provider .context .token_expiry_time = time .time () + 1800
2651- oauth_provider .context .client_info = OAuthClientInformationFull (
2652- client_id = "test_client_id" ,
2653- client_secret = "test_client_secret" ,
2654- redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2655- )
2656- oauth_provider ._initialized = True
2649+ # Flow 1: drive to yield, then leave suspended (simulating long-poll).
2650+ slow_request = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
2651+ slow_flow = oauth_provider .async_auth_flow (slow_request )
2652+ yielded_slow = await slow_flow .__anext__ ()
2653+ assert yielded_slow .headers .get ("Authorization" ) == "Bearer test_access_token"
26572654
2658- # Flow 1: simulate a slow request. Drive it to its yield, then
2659- # deliberately do not send a response — it stays suspended at the
2660- # yield, just like a GET SSE long-poll waiting for the next event.
2661- slow_request = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
2662- slow_flow = oauth_provider .async_auth_flow (slow_request )
2663- yielded_slow = await slow_flow .__anext__ ()
2664- assert yielded_slow .headers .get ("Authorization" ) == "Bearer test_access_token"
2665-
2666- # Flow 2: a concurrent request on the same provider. With the fix,
2667- # context.lock is not held during Flow 1's yield, so Flow 2 reaches
2668- # its yield almost immediately. Without the fix, this would block
2669- # until Flow 1 receives a response — i.e., it would hit the timeout.
2670- fast_request = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2671- fast_flow = oauth_provider .async_auth_flow (fast_request )
2672- with anyio .fail_after (1.0 ):
2673- yielded_fast = await fast_flow .__anext__ ()
2674- assert yielded_fast .headers .get ("Authorization" ) == "Bearer test_access_token"
2675-
2676- # Clean up both generators in deterministic order.
2677- with contextlib .suppress (StopAsyncIteration ):
2678- await fast_flow .asend (httpx .Response (200 , request = yielded_fast ))
2679- with contextlib .suppress (StopAsyncIteration ):
2680- await slow_flow .asend (httpx .Response (200 , request = yielded_slow ))
2655+ # Flow 2: concurrent request. With the fix this reaches its yield
2656+ # immediately; without the fix it would block on context.lock.
2657+ fast_request = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2658+ fast_flow = oauth_provider .async_auth_flow (fast_request )
2659+ with anyio .fail_after (5 ):
2660+ yielded_fast = await fast_flow .__anext__ ()
2661+ assert yielded_fast .headers .get ("Authorization" ) == "Bearer test_access_token"
26812662
2682- @pytest .mark .anyio
2683- async def test_concurrent_token_refresh_is_single_flight (
2684- self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
2685- ):
2686- """When concurrent requests both observe an expired token, only one
2687- refresh request is sent: ``refresh_lock`` provides single-flight
2688- semantics so the second waiter re-checks state and proceeds without
2689- re-triggering refresh.
2690- """
2691- # Mark the token as expired so the next auth_flow run enters Phase 2.
2692- oauth_provider .context .current_tokens = valid_tokens
2693- oauth_provider .context .token_expiry_time = time .time () - 100 # expired
2694- oauth_provider .context .client_info = OAuthClientInformationFull (
2695- client_id = "test_client_id" ,
2696- client_secret = "test_client_secret" ,
2697- redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2698- )
2699- oauth_provider ._initialized = True
2663+ with contextlib .suppress (StopAsyncIteration ):
2664+ await fast_flow .asend (httpx .Response (200 , request = yielded_fast ))
2665+ with contextlib .suppress (StopAsyncIteration ):
2666+ await slow_flow .asend (httpx .Response (200 , request = yielded_slow ))
27002667
2701- # Flow A: drive it to the refresh yield and suspend there.
2702- request_a = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
2703- flow_a = oauth_provider .async_auth_flow (request_a )
2704- refresh_a = await flow_a .__anext__ ()
2705- assert "grant_type=refresh_token" in refresh_a .read ().decode ()
27062668
2707- # Complete Flow A's refresh with a fresh token.
2708- refresh_response = httpx .Response (
2709- 200 ,
2710- content = (
2711- b'{"access_token": "new_access_token", "token_type": "Bearer", '
2712- b'"expires_in": 3600, "refresh_token": "new_refresh_token"}'
2713- ),
2714- request = refresh_a ,
2715- )
2716- request_a_post = await flow_a .asend (refresh_response )
2717- assert request_a_post .headers .get ("Authorization" ) == "Bearer new_access_token"
2718-
2719- # Flow B starts after Flow A's refresh has completed. Because token
2720- # state was updated under context.lock, Flow B observes the fresh
2721- # token in Phase 1, skips Phase 2 entirely, and reaches its yield
2722- # directly. No second refresh request is sent.
2723- request_b = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2724- flow_b = oauth_provider .async_auth_flow (request_b )
2725- with anyio .fail_after (1.0 ):
2726- request_b_yielded = await flow_b .__anext__ ()
2727- assert request_b_yielded .headers .get ("Authorization" ) == "Bearer new_access_token"
2728- # Confirm Flow B yielded the original POST, not a refresh request.
2729- assert request_b_yielded .method == "POST"
2730-
2731- # Clean up.
2732- with contextlib .suppress (StopAsyncIteration ):
2733- await flow_b .asend (httpx .Response (200 , request = request_b_yielded ))
2669+ @pytest .mark .anyio
2670+ async def test_refresh_lock_double_check_skips_redundant_refresh (
2671+ oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
2672+ ):
2673+ """Two flows enter Phase 2 with an expired token. After the first
2674+ completes a refresh, the second observes the fresh token via the
2675+ Phase 2 double-check inside ``refresh_lock`` (or directly in Phase 1
2676+ if it arrives late) and skips its own refresh.
2677+ """
2678+ oauth_provider .context .current_tokens = valid_tokens
2679+ oauth_provider .context .token_expiry_time = time .time () - 100 # expired
2680+ oauth_provider .context .client_info = OAuthClientInformationFull (
2681+ client_id = "test_client_id" ,
2682+ client_secret = "test_client_secret" ,
2683+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2684+ )
2685+ oauth_provider ._initialized = True
2686+
2687+ # Flow A: drive to refresh yield, then complete refresh.
2688+ request_a = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
2689+ flow_a = oauth_provider .async_auth_flow (request_a )
2690+ refresh_a = await flow_a .__anext__ ()
2691+ assert "grant_type=refresh_token" in refresh_a .read ().decode ()
2692+
2693+ refresh_response = httpx .Response (
2694+ 200 ,
2695+ content = (
2696+ b'{"access_token": "new_access_token", "token_type": "Bearer", '
2697+ b'"expires_in": 3600, "refresh_token": "new_refresh_token"}'
2698+ ),
2699+ request = refresh_a ,
2700+ )
2701+ request_a_post = await flow_a .asend (refresh_response )
2702+ assert request_a_post .headers .get ("Authorization" ) == "Bearer new_access_token"
2703+
2704+ # Flow B: state already refreshed; Phase 1 sees valid token, skips Phase 2.
2705+ request_b = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2706+ flow_b = oauth_provider .async_auth_flow (request_b )
2707+ with anyio .fail_after (5 ):
2708+ request_b_yielded = await flow_b .__anext__ ()
2709+ assert request_b_yielded .method == "POST"
2710+ assert request_b_yielded .headers .get ("Authorization" ) == "Bearer new_access_token"
2711+
2712+ with contextlib .suppress (StopAsyncIteration ):
2713+ await flow_b .asend (httpx .Response (200 , request = request_b_yielded ))
2714+ with contextlib .suppress (StopAsyncIteration ):
2715+ await flow_a .asend (httpx .Response (200 , request = request_a_post ))
2716+
2717+
2718+ @pytest .mark .anyio
2719+ async def test_refresh_with_failed_status_clears_tokens (oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken ):
2720+ """A non-2xx refresh response clears stored tokens and marks the provider
2721+ uninitialized so the next request triggers a full OAuth flow.
2722+ """
2723+ oauth_provider .context .current_tokens = valid_tokens
2724+ oauth_provider .context .token_expiry_time = time .time () - 100
2725+ oauth_provider .context .client_info = OAuthClientInformationFull (
2726+ client_id = "test_client_id" ,
2727+ client_secret = "test_client_secret" ,
2728+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2729+ )
2730+ oauth_provider ._initialized = True
2731+
2732+ request = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2733+ flow = oauth_provider .async_auth_flow (request )
2734+ refresh_request = await flow .__anext__ ()
2735+ assert "grant_type=refresh_token" in refresh_request .read ().decode ()
2736+
2737+ # Refresh server returns 401.
2738+ refresh_response = httpx .Response (401 , content = b'{"error": "invalid_grant"}' , request = refresh_request )
2739+ with contextlib .suppress (StopAsyncIteration ):
2740+ # After failed refresh, the flow proceeds to Phase 3 yielding the
2741+ # original request without a fresh Authorization header. We don't
2742+ # exercise the subsequent 401/full OAuth path here.
2743+ await flow .asend (refresh_response )
2744+
2745+ assert oauth_provider .context .current_tokens is None
2746+
2747+
2748+ @pytest .mark .anyio
2749+ async def test_refresh_with_invalid_json_clears_tokens (oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken ):
2750+ """A 200 refresh response with a malformed body clears stored tokens —
2751+ the pydantic ValidationError branch is taken.
2752+ """
2753+ oauth_provider .context .current_tokens = valid_tokens
2754+ oauth_provider .context .token_expiry_time = time .time () - 100
2755+ oauth_provider .context .client_info = OAuthClientInformationFull (
2756+ client_id = "test_client_id" ,
2757+ client_secret = "test_client_secret" ,
2758+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2759+ )
2760+ oauth_provider ._initialized = True
2761+
2762+ request = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2763+ flow = oauth_provider .async_auth_flow (request )
2764+ refresh_request = await flow .__anext__ ()
2765+
2766+ # Body does not parse as OAuthToken.
2767+ refresh_response = httpx .Response (200 , content = b"not json" , request = refresh_request )
2768+ with contextlib .suppress (StopAsyncIteration ):
2769+ await flow .asend (refresh_response )
2770+
2771+ assert oauth_provider .context .current_tokens is None
2772+
2773+
2774+ @pytest .mark .anyio
2775+ async def test_double_check_inside_refresh_lock_skips_second_refresh (
2776+ oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken , monkeypatch : pytest .MonkeyPatch
2777+ ):
2778+ """Exercise the double-check branch inside ``refresh_lock``: ``is_token_valid``
2779+ returns False in Phase 1 (= the flow decides to refresh) but True inside
2780+ the inner ``context.lock`` block (= another coroutine refreshed while we
2781+ were waiting on ``refresh_lock``). The flow must skip ``_refresh_token``
2782+ and proceed straight to Phase 3.
2783+ """
2784+ oauth_provider .context .current_tokens = valid_tokens
2785+ oauth_provider .context .token_expiry_time = time .time () - 100 # expired
2786+ oauth_provider .context .client_info = OAuthClientInformationFull (
2787+ client_id = "test_client_id" ,
2788+ client_secret = "test_client_secret" ,
2789+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2790+ )
2791+ oauth_provider ._initialized = True
2792+
2793+ # Toggle is_token_valid: False on the first call (Phase 1 decision),
2794+ # True on the second (double-check inside refresh_lock).
2795+ call_count = {"n" : 0 }
2796+ original_is_valid = oauth_provider .context .__class__ .is_token_valid
2797+
2798+ def fake_is_token_valid (self : object ) -> bool :
2799+ call_count ["n" ] += 1
2800+ if call_count ["n" ] == 1 :
2801+ return False
2802+ # By the second call, "another coroutine" refreshed; reset token expiry
2803+ # so callers downstream see a valid token.
2804+ oauth_provider .context .token_expiry_time = time .time () + 1800
2805+ return True
2806+
2807+ monkeypatch .setattr (oauth_provider .context .__class__ , "is_token_valid" , fake_is_token_valid )
2808+ try :
2809+ request = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2810+ flow = oauth_provider .async_auth_flow (request )
2811+ # No refresh yield is expected — the flow goes directly to its own
2812+ # request yield with the (now-valid) token header attached.
2813+ with anyio .fail_after (5 ):
2814+ yielded = await flow .__anext__ ()
2815+ assert yielded .method == "POST"
2816+ assert yielded .headers .get ("Authorization" ) == "Bearer test_access_token"
27342817 with contextlib .suppress (StopAsyncIteration ):
2735- await flow_a .asend (httpx .Response (200 , request = request_a_post ))
2818+ await flow .asend (httpx .Response (200 , request = yielded ))
2819+ finally :
2820+ monkeypatch .setattr (oauth_provider .context .__class__ , "is_token_valid" , original_is_valid )
0 commit comments