diff --git a/src/uipath_mcp/_cli/_runtime/_runtime.py b/src/uipath_mcp/_cli/_runtime/_runtime.py index 9c303c3..cf2a0eb 100644 --- a/src/uipath_mcp/_cli/_runtime/_runtime.py +++ b/src/uipath_mcp/_cli/_runtime/_runtime.py @@ -38,6 +38,7 @@ from ._context import UiPathServerType from ._exception import McpErrorCode, UiPathMcpRuntimeError from ._session import BaseSessionServer, StdioSessionServer, StreamableHttpSessionServer +from ._token_refresh import TokenRefresher logger = logging.getLogger(__name__) tracer = trace.get_tracer(__name__) @@ -85,6 +86,7 @@ def __init__( self._http_stderr_drain_task: asyncio.Task[None] | None = None self._http_server_stderr_lines: list[str] = [] self._uipath = UiPath() + self._token_refresher: TokenRefresher | None = None self._cleanup_done = False # Context fields from UiPathConfig @@ -206,15 +208,19 @@ async def _run_server(self) -> UiPathRuntimeResult: root_span.set_attribute("command", str(self._server.command)) root_span.set_attribute("args", json.dumps(self._server.args)) root_span.set_attribute("span_type", "MCP Server") - bearer_token = self._uipath._config.secret + + signalr_headers = { + "X-UiPath-Internal-TenantId": str(self._tenant_id), + "X-UiPath-Internal-AccountId": str(self._org_id), + "X-UIPATH-FolderKey": self._folder_key, + "Authorization": f"Bearer {self._uipath._config.secret}", + } + + self._token_refresher = TokenRefresher(self._uipath, signalr_headers) + self._signalr_client = SignalRClient( signalr_url, - headers={ - "X-UiPath-Internal-TenantId": str(self._tenant_id), - "X-UiPath-Internal-AccountId": str(self._org_id), - "X-UIPATH-FolderKey": self._folder_key, - "Authorization": f"Bearer {bearer_token}", - }, + headers=signalr_headers, ) self._signalr_client.on("MessageReceived", self._handle_signalr_message) self._signalr_client.on( @@ -236,6 +242,7 @@ async def _run_server(self) -> UiPathRuntimeResult: run_task = asyncio.create_task(self._signalr_client.run()) cancel_task = asyncio.create_task(self._cancel_event.wait()) self._keep_alive_task = asyncio.create_task(self._keep_alive()) + self._token_refresher.start() try: # Wait for either the run to complete or cancellation @@ -297,6 +304,9 @@ async def _cleanup(self) -> None: await self._on_runtime_abort() + if self._token_refresher: + await self._token_refresher.stop() + if self._keep_alive_task: self._keep_alive_task.cancel() try: @@ -374,11 +384,11 @@ async def _handle_signalr_message(self, args: list[str]) -> None: session_server: BaseSessionServer if self._server.is_streamable_http: session_server = StreamableHttpSessionServer( - self._server, self.slug, session_id + self._server, self.slug, session_id, self._uipath ) else: session_server = StdioSessionServer( - self._server, self.slug, session_id + self._server, self.slug, session_id, self._uipath ) try: await session_server.start() diff --git a/src/uipath_mcp/_cli/_runtime/_session.py b/src/uipath_mcp/_cli/_runtime/_session.py index 6e3840b..061ae2f 100644 --- a/src/uipath_mcp/_cli/_runtime/_session.py +++ b/src/uipath_mcp/_cli/_runtime/_session.py @@ -31,7 +31,13 @@ class BaseSessionServer(ABC): """Base class with transport-agnostic message relay logic.""" - def __init__(self, server_config: McpServer, server_slug: str, session_id: str): + def __init__( + self, + server_config: McpServer, + server_slug: str, + session_id: str, + uipath: UiPath, + ): self._server_config = server_config self._server_slug = server_slug self._session_id = session_id @@ -42,7 +48,7 @@ def __init__(self, server_config: McpServer, server_slug: str, session_id: str): self._active_requests: dict[str, str] = {} self._last_request_id: str | None = None self._last_message_id: str | None = None - self._uipath = UiPath() + self._uipath = uipath self._mcp_tracer = McpTracer(tracer, logger) @property @@ -284,8 +290,14 @@ def _get_message_id(self, message: JSONRPCMessage) -> str: class StdioSessionServer(BaseSessionServer): """Manages a stdio server process for a specific session.""" - def __init__(self, server_config: McpServer, server_slug: str, session_id: str): - super().__init__(server_config, server_slug, session_id) + def __init__( + self, + server_config: McpServer, + server_slug: str, + session_id: str, + uipath: UiPath, + ): + super().__init__(server_config, server_slug, session_id, uipath) self._server_stderr_output: str | None = None @property diff --git a/src/uipath_mcp/_cli/_runtime/_token_refresh.py b/src/uipath_mcp/_cli/_runtime/_token_refresh.py new file mode 100644 index 0000000..ea5cbc1 --- /dev/null +++ b/src/uipath_mcp/_cli/_runtime/_token_refresh.py @@ -0,0 +1,297 @@ +import asyncio +import json +import logging +import os +import time +from enum import Enum +from pathlib import Path +from urllib.parse import urlparse + +import httpx +from uipath._utils._auth import parse_access_token +from uipath._utils._ssl_context import get_httpx_client_kwargs +from uipath._utils.constants import ENV_UIPATH_ACCESS_TOKEN +from uipath.platform import UiPath +from uipath.platform.common import TokenData +from uipath.platform.common._config import UiPathApiConfig + +logger = logging.getLogger(__name__) + +REFRESH_MARGIN_SECONDS = 300 # Refresh 5 minutes before expiry +FALLBACK_REFRESH_INTERVAL = 45 * 60 # 45 minutes when exp claim is unavailable +MAX_RETRY_ATTEMPTS = 3 +RETRY_BASE_DELAY = 5 # seconds +RETRY_FALLBACK_INTERVAL = 60 # seconds to wait after all retries fail + + +class AuthStrategy(Enum): + OAUTH = "oauth" + CLIENT_CREDENTIALS = "client_credentials" + NONE = "none" + + +class TokenRefresher: + """Manages token refresh for long-lived MCP runtime connections. + + Detects the authentication strategy (OAuth or client credentials) and + refreshes the token before it expires. + """ + + def __init__(self, uipath: UiPath, signalr_headers: dict[str, str] | None = None): + self._uipath = uipath + self._signalr_headers = signalr_headers + self._refresh_task: asyncio.Task[None] | None = None + self._cancel_event = asyncio.Event() + + # Client credentials config + self._client_id: str | None = os.environ.get("UIPATH_CLIENT_ID") + self._client_secret: str | None = os.environ.get("UIPATH_CLIENT_SECRET") + + # Detect strategy and resolve token URL + self._strategy = self._detect_strategy() + self._base_url: str = uipath._config.base_url + self._token_url: str | None = self._resolve_token_url() + + def _detect_strategy(self) -> AuthStrategy: + """Detect which auth flow is available for token refresh.""" + if self._client_id and self._client_secret: + return AuthStrategy.CLIENT_CREDENTIALS + + try: + auth_file = Path.cwd() / ".uipath" / ".auth.json" + if auth_file.exists(): + with open(auth_file) as f: + auth_data = json.load(f) + if auth_data.get("refresh_token"): + return AuthStrategy.OAUTH + except Exception as e: + logger.debug(f"Could not read auth file for strategy detection: {e}") + + return AuthStrategy.NONE + + def _resolve_token_url(self) -> str | None: + """Derive the identity token endpoint from base_url. + + Falls back to AuthStrategy.NONE if the URL cannot be resolved. + """ + if self._strategy == AuthStrategy.NONE: + return None + + try: + parsed = urlparse(self._base_url) + domain = f"{parsed.scheme}://{parsed.hostname}" + if parsed.port: + domain += f":{parsed.port}" + return f"{domain}/identity_/connect/token" + except Exception as e: + logger.error( + f"Could not resolve token URL from base_url '{self._base_url}': {e}; " + "token refresh will be disabled" + ) + self._strategy = AuthStrategy.NONE + return None + + @property + def strategy(self) -> AuthStrategy: + return self._strategy + + def start(self) -> None: + """Start the background refresh task.""" + if self._strategy == AuthStrategy.NONE: + logger.info("No token refresh strategy available; refresh disabled") + return + + self._cancel_event.clear() + self._refresh_task = asyncio.create_task(self._refresh_loop()) + logger.info("Token refresh background task started") + + async def stop(self) -> None: + """Stop the background refresh task.""" + self._cancel_event.set() + if self._refresh_task and not self._refresh_task.done(): + self._refresh_task.cancel() + try: + await asyncio.wait_for(self._refresh_task, timeout=5.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + self._refresh_task = None + logger.info("Token refresh stopped") + + async def _refresh_loop(self) -> None: + """Background loop that refreshes the token before expiry.""" + try: + while not self._cancel_event.is_set(): + wait_seconds = self._seconds_until_refresh() + + if wait_seconds > 0: + try: + await asyncio.wait_for( + self._cancel_event.wait(), + timeout=wait_seconds, + ) + break # cancel_event was set + except asyncio.TimeoutError: + pass # time to refresh + + refreshed = await self._try_refresh() + if not refreshed and not self._cancel_event.is_set(): + logger.error( + "All token refresh attempts failed. " + f"Will retry in {RETRY_FALLBACK_INTERVAL}s. " + "The token may expire causing failures." + ) + try: + await asyncio.wait_for( + self._cancel_event.wait(), + timeout=RETRY_FALLBACK_INTERVAL, + ) + break + except asyncio.TimeoutError: + pass + + except asyncio.CancelledError: + logger.info("Token refresh loop cancelled") + raise + + async def _try_refresh(self) -> bool: + """Attempt to refresh the token with retries. Returns True on success.""" + for attempt in range(MAX_RETRY_ATTEMPTS): + try: + if self._strategy == AuthStrategy.OAUTH: + token_data = await self._refresh_oauth() + elif self._strategy == AuthStrategy.CLIENT_CREDENTIALS: + token_data = await self._refresh_client_credentials() + else: + return False + + self._propagate_token(token_data) + + exp = self._get_token_expiry() + if exp: + logger.info( + f"Token refreshed successfully. " + f"New expiry in {(exp - time.time()) / 60:.1f} min" + ) + else: + logger.info("Token refreshed successfully.") + return True + + except Exception as e: + retry_delay = RETRY_BASE_DELAY * (2**attempt) + logger.error( + f"Token refresh attempt {attempt + 1}/{MAX_RETRY_ATTEMPTS} " + f"failed: {e}" + ) + if attempt < MAX_RETRY_ATTEMPTS - 1: + logger.info(f"Retrying in {retry_delay}s...") + try: + await asyncio.wait_for( + self._cancel_event.wait(), + timeout=retry_delay, + ) + return False # cancel_event was set + except asyncio.TimeoutError: + continue + + return False + + async def _refresh_oauth(self) -> TokenData: + """Refresh using OAuth refresh_token grant.""" + auth_file = Path.cwd() / ".uipath" / ".auth.json" + with open(auth_file) as f: + auth_data = json.load(f) + + refresh_token = auth_data.get("refresh_token") + if not refresh_token: + raise ValueError("No refresh_token found in .uipath/.auth.json") + + from uipath._cli._auth._oidc_utils import OidcUtils + + parsed = urlparse(self._base_url) + domain = f"{parsed.scheme}://{parsed.hostname}" + if parsed.port: + domain += f":{parsed.port}" + + auth_config = OidcUtils.get_auth_config(domain) + client_id = auth_config.get("client_id") + + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": client_id, + } + + async with httpx.AsyncClient(**get_httpx_client_kwargs()) as client: + response = await client.post( + self._token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + response.raise_for_status() + return TokenData.model_validate(response.json()) + + async def _refresh_client_credentials(self) -> TokenData: + """Refresh using client_credentials grant.""" + data = { + "grant_type": "client_credentials", + "client_id": self._client_id, + "client_secret": self._client_secret, + } + + async with httpx.AsyncClient(**get_httpx_client_kwargs()) as client: + response = await client.post( + self._token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + response.raise_for_status() + return TokenData.model_validate(response.json()) + + def _propagate_token(self, token_data: TokenData) -> None: + """Update all token consumers after a successful refresh.""" + new_token = token_data.access_token + + # Replace UiPath config + self._uipath._config = UiPathApiConfig( + base_url=self._uipath._config.base_url, + secret=new_token, + ) + + # Update SignalR headers so reconnect negotiate uses the fresh token + if self._signalr_headers is not None: + self._signalr_headers["Authorization"] = f"Bearer {new_token}" + + # Update environment variable + os.environ[ENV_UIPATH_ACCESS_TOKEN] = new_token + + # Persist to .auth.json + if self._strategy == AuthStrategy.OAUTH: + try: + from uipath._cli._auth._utils import update_auth_file + + update_auth_file(token_data) + except Exception as e: + logger.warning(f"Failed to update .auth.json: {e}") + + def _get_token_expiry(self) -> float | None: + """Parse the JWT exp claim to get the expiry timestamp.""" + try: + claims = parse_access_token(self._uipath._config.secret) + exp = claims.get("exp") + if exp is not None: + return float(exp) + except Exception as e: + logger.warning(f"Failed to parse token expiry: {e}") + return None + + def _seconds_until_refresh(self) -> float: + """Calculate seconds to wait before next refresh attempt.""" + exp = self._get_token_expiry() + if exp is None: + return FALLBACK_REFRESH_INTERVAL + + remaining = exp - time.time() + if remaining <= REFRESH_MARGIN_SECONDS: + return 0 + + return remaining - REFRESH_MARGIN_SECONDS diff --git a/src/uipath_mcp/_cli/_utils/_config.py b/src/uipath_mcp/_cli/_utils/_config.py index ec8a0e0..0143dd7 100644 --- a/src/uipath_mcp/_cli/_utils/_config.py +++ b/src/uipath_mcp/_cli/_utils/_config.py @@ -91,14 +91,7 @@ def get_servers(self) -> list[McpServer]: def get_server(self, name: str) -> McpServer | None: """ Get a server model by name. - If there's only one server available, return that one regardless of name. - Otherwise, look up the server by the provided name. """ - # If there's only one server, return it - if len(self._servers) == 1: - return next(iter(self._servers.values())) - - # Otherwise, fall back to looking up by name return self._servers.get(name) def get_server_names(self) -> list[str]: