From 6ea7f53a232c1504f2dde3d74b07afbbd11d87e2 Mon Sep 17 00:00:00 2001 From: choiyounghoon Date: Tue, 10 Feb 2026 23:11:26 +0900 Subject: [PATCH] =?UTF-8?q?=20fix(auth):=20OAuthClientProvider=20async=5Fa?= =?UTF-8?q?uth=5Fflow=20lock=20=EB=B2=84=EA=B7=B8=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit async with self.context.lock: 블록이 모든 yield 문을 감싸고 있어 generator suspend/resume 시 RuntimeError 발생. 해결: 각 yield 지점 전후로 lock acquire/release. 기존 기능 유지, 테스트 통과 (166 passed). --- src/mcp/client/auth/oauth2.py | 139 ++++++++++++++++++++-------------- 1 file changed, 84 insertions(+), 55 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 41aecc6f2..3bd826948 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -503,7 +503,14 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}") async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: - """HTTPX auth flow integration.""" + """HTTPX auth flow integration. + + Note: We release the lock around each yield point to avoid holding it + across generator suspensions, which causes "current task is not holding + this lock" errors when resumed in a different task context. + """ + # Phase 1: Initialize and check token validity + refresh_request = None async with self.context.lock: if not self._initialized: await self._initialize() # pragma: no cover @@ -514,33 +521,38 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token refresh_request = await self._refresh_token() # pragma: no cover - refresh_response = yield refresh_request # pragma: no cover - if not await self._handle_refresh_response(refresh_response): # pragma: no cover + # Phase 2: Refresh token if needed (yield without lock held) + if refresh_request is not None: # pragma: no cover + refresh_response = yield refresh_request + async with self.context.lock: + if not await self._handle_refresh_response(refresh_response): # Refresh failed, need full re-authentication self._initialized = False + # Phase 3: Add auth header if token is valid + async with self.context.lock: if self.context.is_token_valid(): self._add_auth_header(request) - response = yield request + response = yield request - if response.status_code == 401: - # Perform full OAuth flow - try: - # OAuth flow must be inline due to generator constraints - www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response) - - # Step 1: Discover protected resource metadata (SEP-985 with fallback support) - prm_discovery_urls = build_protected_resource_metadata_discovery_urls( - www_auth_resource_metadata_url, self.context.server_url - ) + if response.status_code == 401: + # Perform full OAuth flow (release lock around each yield) + try: + # OAuth flow must be inline due to generator constraints + www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response) - for url in prm_discovery_urls: # pragma: no branch - discovery_request = create_oauth_metadata_request(url) + # Step 1: Discover protected resource metadata (SEP-985 with fallback support) + prm_discovery_urls = build_protected_resource_metadata_discovery_urls( + www_auth_resource_metadata_url, self.context.server_url + ) - discovery_response = yield discovery_request # sending request + for url in prm_discovery_urls: # pragma: no branch + discovery_request = create_oauth_metadata_request(url) + discovery_response = yield discovery_request + async with self.context.lock: prm = await handle_protected_resource_response(discovery_response) if prm: # Validate PRM resource matches server URL (RFC 8707) @@ -553,36 +565,41 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. ) # this is always true as authorization_servers has a min length of 1 self.context.auth_server_url = str(prm.authorization_servers[0]) - break - else: - logger.debug(f"Protected resource metadata discovery failed: {url}") + if prm: + break + logger.debug(f"Protected resource metadata discovery failed: {url}") + + async with self.context.lock: asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( self.context.auth_server_url, self.context.server_url ) - # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) - for url in asm_discovery_urls: # pragma: no branch - oauth_metadata_request = create_oauth_metadata_request(url) - oauth_metadata_response = yield oauth_metadata_request + # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) + for url in asm_discovery_urls: # pragma: no branch + oauth_metadata_request = create_oauth_metadata_request(url) + oauth_metadata_response = yield oauth_metadata_request + async with self.context.lock: ok, asm = await handle_auth_metadata_response(oauth_metadata_response) if not ok: break - if ok and asm: + if asm: self.context.oauth_metadata = asm break - else: - logger.debug(f"OAuth metadata discovery failed: {url}") + logger.debug(f"OAuth metadata discovery failed: {url}") - # Step 3: Apply scope selection strategy + # Step 3: Apply scope selection strategy + async with self.context.lock: self.context.client_metadata.scope = get_client_metadata_scopes( extract_scope_from_www_auth(response), self.context.protected_resource_metadata, self.context.oauth_metadata, ) - # Step 4: Register client or use URL-based client ID (CIMD) + # Step 4: Register client or use URL-based client ID (CIMD) + registration_request = None + async with self.context.lock: if not self.context.client_info: if should_use_client_metadata_url( self.context.oauth_metadata, self.context.client_metadata_url @@ -602,40 +619,52 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.client_metadata, self.context.get_authorization_base_url(self.context.server_url), ) - registration_response = yield registration_request - client_information = await handle_registration_response(registration_response) - self.context.client_info = client_information - await self.context.storage.set_client_info(client_information) - # Step 5: Perform authorization and complete token exchange - token_response = yield await self._perform_authorization() + if registration_request is not None: + registration_response = yield registration_request + async with self.context.lock: + client_information = await handle_registration_response(registration_response) + self.context.client_info = client_information + await self.context.storage.set_client_info(client_information) + + # Step 5: Perform authorization and complete token exchange + async with self.context.lock: + authorization_request = await self._perform_authorization() + token_response = yield authorization_request + async with self.context.lock: await self._handle_token_response(token_response) - except Exception: # pragma: no cover - logger.exception("OAuth flow error") - raise + except Exception: # pragma: no cover + logger.exception("OAuth flow error") + raise - # Retry with new tokens + # Retry with new tokens + async with self.context.lock: self._add_auth_header(request) - yield request - elif response.status_code == 403: - # Step 1: Extract error field from WWW-Authenticate header - error = extract_field_from_www_auth(response, "error") - - # Step 2: Check if we need to step-up authorization - if error == "insufficient_scope": # pragma: no branch - try: - # Step 2a: Update the required scopes + yield request + elif response.status_code == 403: + # Step 1: Extract error field from WWW-Authenticate header + error = extract_field_from_www_auth(response, "error") + + # Step 2: Check if we need to step-up authorization + if error == "insufficient_scope": # pragma: no branch + try: + # Step 2a: Update the required scopes + async with self.context.lock: self.context.client_metadata.scope = get_client_metadata_scopes( extract_scope_from_www_auth(response), self.context.protected_resource_metadata ) - # Step 2b: Perform (re-)authorization and token exchange - token_response = yield await self._perform_authorization() + # Step 2b: Perform (re-)authorization and token exchange + async with self.context.lock: + authorization_request = await self._perform_authorization() + token_response = yield authorization_request + async with self.context.lock: await self._handle_token_response(token_response) - except Exception: # pragma: no cover - logger.exception("OAuth flow error") - raise + except Exception: # pragma: no cover + logger.exception("OAuth flow error") + raise - # Retry with new tokens + # Retry with new tokens + async with self.context.lock: self._add_auth_header(request) - yield request + yield request