From 52351a7cd488020c8ada4370a1e290f713bbc539 Mon Sep 17 00:00:00 2001 From: Aaron Jacobs Date: Thu, 28 May 2026 08:58:34 -0400 Subject: [PATCH] feat: Add `login` and `logout` subcommands for OAuth flows. Connect version 2026.02 and later supports dynamic OAuth client registration and the authorization code flow, which is a nice way to sidestep the need for API keys. This commit wires up these flows to new `login` and `logout` subcommands. As a bonus, we store the resulting credentials in the system keyring -- when available -- for improved security. As a secondary bonus: we also support the device code flow (via `--use-device-code`) if the Connect server advertises its support for it. Indicators for the OAuth client and keyring use have also been added to the `rsconnect list` output. Unit tests are included. Closes #759. Signed-off-by: Aaron Jacobs --- docs/CHANGELOG.md | 1 + docs/commands/login.md | 3 + docs/commands/logout.md | 3 + mkdocs.yml | 2 + pyproject.toml | 1 + rsconnect/api.py | 141 +++++++++++ rsconnect/main.py | 178 ++++++++++++++ rsconnect/metadata.py | 54 +++++ rsconnect/oauth.py | 516 +++++++++++++++++++++++++++++++++++++++ tests/test_oauth.py | 519 ++++++++++++++++++++++++++++++++++++++++ 10 files changed, 1418 insertions(+) create mode 100644 docs/commands/login.md create mode 100644 docs/commands/logout.md create mode 100644 rsconnect/oauth.py create mode 100644 tests/test_oauth.py diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 9674c534..dab05a4e 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `pyproject.toml` can now be supplied via `--requirements-file` for deploy and write-manifest. - Perform case insensitive matching of the configured Snowflake connection authenticator. +- New `login` and `logout` subcommands for authenticating to Connect via OAuth. ## [1.29.0] - 2026-04-29 diff --git a/docs/commands/login.md b/docs/commands/login.md new file mode 100644 index 00000000..6fe26f29 --- /dev/null +++ b/docs/commands/login.md @@ -0,0 +1,3 @@ +::: mkdocs-click + :module: rsconnect.main + :command: login diff --git a/docs/commands/logout.md b/docs/commands/logout.md new file mode 100644 index 00000000..df0e72a6 --- /dev/null +++ b/docs/commands/logout.md @@ -0,0 +1,3 @@ +::: mkdocs-click + :module: rsconnect.main + :command: logout diff --git a/mkdocs.yml b/mkdocs.yml index 541da548..cc4b1692 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -46,6 +46,8 @@ nav: - details: commands/details.md - info: commands/info.md - list: commands/list.md + - login: commands/login.md + - logout: commands/logout.md - remove: commands/remove.md - system: commands/system.md - version: commands/version.md diff --git a/pyproject.toml b/pyproject.toml index d4a5dd98..e9a808d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ test = [ "types-Flask", "fastmcp==2.12.4; python_version >= '3.10'", ] +keyring = ["keyring>=23.0.0"] snowflake = ["snowflake-cli"] mcp = ["fastmcp==2.12.4; python_version >= '3.10'"] docs = [ diff --git a/rsconnect/api.py b/rsconnect/api.py index bf643c8f..1312b340 100644 --- a/rsconnect/api.py +++ b/rsconnect/api.py @@ -222,12 +222,18 @@ def __init__( insecure: bool = False, ca_data: Optional[str | bytes] = None, bootstrap_jwt: Optional[str] = None, + oauth_access_token: Optional[str] = None, + oauth_client_id: Optional[str] = None, + server_name: Optional[str] = None, ): super().__init__(url, "Posit Connect") self.api_key = api_key self.bootstrap_jwt = bootstrap_jwt self.insecure = insecure self.ca_data = ca_data + self.oauth_access_token = oauth_access_token + self.oauth_client_id = oauth_client_id + self.server_name = server_name # This is specifically not None. self.cookie_jar = CookieJar() # for compatibility with RSconnectClient @@ -422,6 +428,126 @@ def __init__(self, server: Union[RSConnectServer, SPCSConnectServer], cookies: O if server.api_key: self._headers["X-RSC-Authorization"] = server.api_key + if ( + isinstance(server, RSConnectServer) + and server.oauth_access_token + and not server.api_key + and not server.bootstrap_jwt + ): + self.authorization(f"Bearer {server.oauth_access_token}") + + def request( + self, + method: str, + path: str, + query_params: Optional[Mapping[str, "JsonData"]] = None, + body: "str | bytes | IO[bytes] | Mapping[str, Any] | list[Any] | None" = None, + maximum_redirects: int = 5, + decode_response: bool = True, + headers: Optional[Mapping[str, str]] = None, + ) -> "JsonData | HTTPResponse": + can_retry = isinstance(self._server, RSConnectServer) and bool(self._server.oauth_client_id) + start_pos: "int | None" = None + if can_retry and hasattr(body, "read"): + if getattr(body, "seekable", lambda: False)(): + start_pos = body.tell() # type: ignore[union-attr] + else: + body = body.read() # type: ignore[union-attr] + response = super().request( + method, path, query_params, body, maximum_redirects, decode_response, headers + ) # pyright: ignore[reportUnknownArgumentType] + if can_retry and isinstance(response, HTTPResponse) and response.status == 401: + if self._attempt_token_refresh(): + if start_pos is not None: + body.seek(start_pos) # type: ignore[union-attr] + return super().request( + method, path, query_params, body, maximum_redirects, decode_response, headers + ) # pyright: ignore[reportUnknownArgumentType] + return response + + def _attempt_token_refresh(self) -> bool: + from .oauth import ( + InvalidClientError, + discover_oauth_metadata, + keyring_delete_tokens, + keyring_get_tokens, + keyring_store_token, + refresh_access_token, + register_client, + ) + from .metadata import ServerStore + + server = cast(RSConnectServer, self._server) + + _, refresh_token = keyring_get_tokens(server.url) + if not refresh_token: + store = ServerStore() + entry = None + if server.server_name: + entry = store.get_by_name(server.server_name) + if not entry: + entry = store.get_by_url(server.url) + if entry: + refresh_token = entry.get("oauth_refresh_token") # type: ignore[assignment] + if not refresh_token: + return False + + try: + metadata = discover_oauth_metadata(server.url, server.insecure, server.ca_data) + token_response = refresh_access_token( + metadata, server.oauth_client_id or "", refresh_token, server.insecure, server.ca_data + ) + except InvalidClientError: + # Client was deleted server-side; clear stale tokens and re-register + keyring_delete_tokens(server.url) + store = ServerStore() + entry = None + if server.server_name: + entry = store.get_by_name(server.server_name) + if not entry: + entry = store.get_by_url(server.url) + if entry: + entry_name = str(entry.get("name", server.server_name or server.url)) + store.update_oauth_tokens(entry_name, None, None, None) + try: + metadata = discover_oauth_metadata(server.url, server.insecure, server.ca_data) + new_client_id = register_client(metadata, server.url, server.insecure, server.ca_data) + server.oauth_client_id = new_client_id + if entry: + entry["oauth_client_id"] = new_client_id # type: ignore[typeddict-unknown-key] + store._set(entry_name, entry) # type: ignore[possibly-undefined] + logger.warning("OAuth client was re-registered; please run `rsconnect login` again.") + except Exception as exc: + logger.warning(f"OAuth client re-registration failed: {exc}. Please run `rsconnect login` again.") + return False + except Exception as exc: + logger.warning(f"OAuth token refresh failed: {exc}") + return False + + new_access = token_response["access_token"] + new_refresh = token_response.get("refresh_token", refresh_token) + expires_in = token_response.get("expires_in") + import time + + new_expiry = time.time() + expires_in if expires_in else None + + self.authorization(f"Bearer {new_access}") + server.oauth_access_token = new_access + + stored = keyring_store_token(server.url, new_access, new_refresh) + if not stored: + store = ServerStore() + entry = None + if server.server_name: + entry = store.get_by_name(server.server_name) + if not entry: + entry = store.get_by_url(server.url) + if entry: + entry_name = str(entry.get("name", server.server_name or server.url)) + store.update_oauth_tokens(entry_name, new_access, new_refresh, new_expiry) + + return True + def _tweak_response(self, response: HTTPResponse) -> JsonData | HTTPResponse: return ( response.json_data @@ -974,6 +1100,21 @@ def setup_remote_server( url = cast(str, url) account_name = cast(str, account_name) self.remote_server = ShinyappsServer(url, account_name, token, secret) + elif server_data.from_store and server_data.oauth_client_id: + url = cast(str, url) + from .oauth import keyring_get_tokens + + access_token, _ = keyring_get_tokens(url) + oauth_access_token = access_token or server_data.oauth_access_token + self.remote_server = RSConnectServer( + url, + None, + insecure, + ca_data, + oauth_access_token=oauth_access_token, + oauth_client_id=server_data.oauth_client_id, + server_name=name or server_data.name, + ) else: raise RSConnectException("Unable to infer Connect server type and setup server.") diff --git a/rsconnect/main.py b/rsconnect/main.py index acb82939..e4284811 100644 --- a/rsconnect/main.py +++ b/rsconnect/main.py @@ -848,6 +848,13 @@ def list_servers(verbose: int): click.echo(" URL: %s" % server["url"]) if server.get("api_key"): click.echo(" API key is saved") + if server.get("oauth_client_id"): + click.echo(" OAuth Client ID: %s" % server["oauth_client_id"]) + from .oauth import keyring_get_tokens + + access, _ = keyring_get_tokens(server["url"]) + if access: + click.echo(" Credentials stored in system keyring") if server.get("insecure"): click.echo(" Insecure mode (TLS host/certificate validation disabled)") if server.get("ca_cert"): @@ -958,6 +965,177 @@ def remove( click.echo(message) +@cli.command( + short_help="Authenticate with a Posit Connect server using OAuth.", + help=( + "Authenticate with a Posit Connect server using OAuth 2.1. " + "This opens a browser for interactive login (or uses --use-device-code for headless environments). " + "Tokens are stored in the system keyring when available, with fallback to the local credential store." + ), + no_args_is_help=True, +) +@click.option("--server", "-s", envvar="CONNECT_SERVER", required=True, help="The URL of the Posit Connect server.") +@click.option("--name", "-n", help="Nickname for the server (defaults to server hostname).") +@click.option("--insecure", "-i", envvar="CONNECT_INSECURE", is_flag=True, help="Disable TLS certificate verification.") +@click.option( + "--cacert", + "-c", + envvar="CONNECT_CA_CERTIFICATE", + type=click.Path(exists=True, file_okay=True, dir_okay=False), + help="Path to a trusted CA certificate file for TLS.", +) +@click.option( + "--use-device-code", + is_flag=True, + default=False, + help="Use device code flow for headless/non-interactive environments.", +) +@click.option("--client-id", default=None, help="OAuth client ID (skips Dynamic Client Registration).") +@click.option("--verbose", "-v", count=True, help="Enable verbose output. Use -vv for very verbose (debug) output.") +@cli_exception_handler +def login( + server: str, + name: Optional[str], + insecure: bool, + cacert: Optional[str], + use_device_code: bool, + client_id: Optional[str], + verbose: int, +): + set_verbosity(verbose) + + if not server.startswith("http"): + raise RSConnectException("Server URL must begin with http or https.") + + ca_data = read_certificate_file(cacert) if cacert else None + + if not name: + from urllib.parse import urlparse as _urlparse + + name = _urlparse(server).hostname or server + + from .oauth import ( + InvalidClientError, + discover_oauth_metadata, + keyring_store_token, + login_with_browser, + login_with_device_code as _login_device, + register_client, + ) + + with cli_feedback("Discovering OAuth metadata"): + metadata = discover_oauth_metadata(server, insecure, ca_data) + + # Resolve client_id: flag > stored > DCR + if not client_id: + existing = server_store.get_by_name(name) or server_store.get_by_url(server) + if existing: + stored_client_id = existing.get("oauth_client_id") + if stored_client_id: + client_id = str(stored_client_id) + + if not client_id: + with cli_feedback("Registering OAuth client"): + client_id = register_client(metadata, server, insecure, ca_data) + + def _do_login(cid: str) -> dict[str, Any]: + if use_device_code: + return _login_device(server, cid, metadata, insecure, ca_data) + else: + return login_with_browser(server, cid, metadata, insecure, ca_data) + + try: + token_response = _do_login(client_id) + except InvalidClientError: + with cli_feedback("Re-registering OAuth client"): + client_id = register_client(metadata, server, insecure, ca_data) + token_response = _do_login(client_id) + + access_token = str(token_response["access_token"]) + refresh_token = str(token_response["refresh_token"]) if "refresh_token" in token_response else None + expires_in = token_response.get("expires_in") + import time + + expiry = time.time() + int(expires_in) if expires_in else None + + stored_in_keyring = keyring_store_token(server, access_token, refresh_token) + + ca_data_str = ca_data.decode("utf-8") if isinstance(ca_data, bytes) else ca_data + + if stored_in_keyring: + server_store.set(name, server, oauth_client_id=client_id, insecure=insecure, ca_data=ca_data_str) + else: + server_store.set( + name, + server, + oauth_client_id=client_id, + insecure=insecure, + ca_data=ca_data_str, + oauth_access_token=access_token, + oauth_refresh_token=refresh_token, + oauth_token_expiry=expiry, + ) + + click.echo('Logged in to "%s" (%s)' % (name, server)) + if not stored_in_keyring: + click.secho( + "Note: keyring not available; credentials stored in local file (chmod 600).", + fg="yellow", + ) + + +@cli.command( + short_help="Remove stored OAuth credentials for a Posit Connect server.", + help=( + "Remove locally-stored OAuth credentials for a Posit Connect server. " + "One of --name or --server is required. " + "The server entry is preserved (for re-login without re-registration); " + "use 'rsconnect remove' to delete the entry entirely." + ), + no_args_is_help=True, +) +@click.option("--name", "-n", help="The nickname of the Posit Connect server to log out from.") +@click.option("--server", "-s", help="The URL of the Posit Connect server to log out from.") +@click.option("--verbose", "-v", count=True, help="Enable verbose output. Use -vv for very verbose (debug) output.") +@cli_exception_handler +def logout( + name: Optional[str], + server: Optional[str], + verbose: int, +): + set_verbosity(verbose) + + if name and server: + raise RSConnectException("Specify only one of --name or --server.") + if not name and not server: + raise RSConnectException("Specify one of --name or --server.") + + entry = None + if name: + entry = server_store.get_by_name(name) + if entry is None: + raise RSConnectException('Nickname "%s" was not found.' % name) + elif server: + entry = server_store.get_by_url(server) + if entry is None: + raise RSConnectException('Server URL "%s" was not found.' % server) + + if not entry or not entry.get("oauth_client_id"): + raise RSConnectException( + "This server was not added with 'rsconnect login'. Use 'rsconnect remove' to delete it." + ) + + server_url = entry["url"] + entry_name = entry["name"] + + from .oauth import keyring_delete_tokens + + keyring_delete_tokens(server_url) + server_store.update_oauth_tokens(entry_name, None, None, None) + + click.echo('Logged out from "%s".' % (name or server)) + + def _get_names_to_check(file_or_directory: str) -> list[str]: """ A function to determine a set files to look for in getting information about a diff --git a/rsconnect/metadata.py b/rsconnect/metadata.py index 7ea6e180..409aaf75 100644 --- a/rsconnect/metadata.py +++ b/rsconnect/metadata.py @@ -259,6 +259,10 @@ class ServerDataDict(TypedDict): account_name: NotRequired[str] token: NotRequired[str] secret: NotRequired[str] + oauth_client_id: NotRequired[str] + oauth_access_token: NotRequired[str] + oauth_refresh_token: NotRequired[str] + oauth_token_expiry: NotRequired[float] class ServerData: @@ -279,6 +283,10 @@ def __init__( account_name: Optional[str] = None, token: Optional[str] = None, secret: Optional[str] = None, + oauth_client_id: Optional[str] = None, + oauth_access_token: Optional[str] = None, + oauth_refresh_token: Optional[str] = None, + oauth_token_expiry: Optional[float] = None, ): self.name = name self.url = url @@ -290,6 +298,10 @@ def __init__( self.account_name = account_name self.token = token self.secret = secret + self.oauth_client_id = oauth_client_id + self.oauth_access_token = oauth_access_token + self.oauth_refresh_token = oauth_refresh_token + self.oauth_token_expiry = oauth_token_expiry class ServerStore(DataStore[ServerDataDict]): @@ -338,6 +350,10 @@ def set( account_name: Optional[str] = None, token: Optional[str] = None, secret: Optional[str] = None, + oauth_client_id: Optional[str] = None, + oauth_access_token: Optional[str] = None, + oauth_refresh_token: Optional[str] = None, + oauth_token_expiry: Optional[float] = None, ): """ Add (or update) information about a Connect server @@ -351,6 +367,10 @@ def set( :param account_name: shinyapps.io account name. :param token: shinyapps.io token. :param secret: shinyapps.io secret. + :param oauth_client_id: OAuth client ID. + :param oauth_access_token: OAuth access token (fallback when keyring unavailable). + :param oauth_refresh_token: OAuth refresh token (fallback when keyring unavailable). + :param oauth_token_expiry: OAuth token expiry as unix timestamp. """ common_data: ServerDataDict = { "name": name, @@ -360,6 +380,14 @@ def set( target_data = dict(snowflake_connection_name=snowflake_connection_name, api_key=api_key) elif api_key: target_data = dict(api_key=api_key, insecure=insecure, ca_cert=ca_data) + elif oauth_client_id: + target_data: dict[str, object] = dict(oauth_client_id=oauth_client_id, insecure=insecure, ca_cert=ca_data) + if oauth_access_token: + target_data["oauth_access_token"] = oauth_access_token + if oauth_refresh_token: + target_data["oauth_refresh_token"] = oauth_refresh_token + if oauth_token_expiry is not None: + target_data["oauth_token_expiry"] = oauth_token_expiry elif account_name: target_data = dict(account_name=account_name, token=token, secret=secret) else: @@ -383,6 +411,28 @@ def remove_by_url(self, url: str): """ return self._remove_by_value_attr("name", "url", url) + def update_oauth_tokens( + self, + name: str, + access_token: Optional[str], + refresh_token: Optional[str], + expiry: Optional[float], + ) -> None: + """Update (or clear) stored OAuth token fields for an existing server entry.""" + entry = self._get_by_key(name) + if entry is None: + return + updated: ServerDataDict = {**entry} # type: ignore[misc] + if access_token: + updated["oauth_access_token"] = access_token # type: ignore[typeddict-unknown-key] + updated["oauth_refresh_token"] = refresh_token # type: ignore[typeddict-unknown-key] + updated["oauth_token_expiry"] = expiry # type: ignore[typeddict-unknown-key] + else: + updated.pop("oauth_access_token", None) # type: ignore[misc] + updated.pop("oauth_refresh_token", None) # type: ignore[misc] + updated.pop("oauth_token_expiry", None) # type: ignore[misc] + self._set(name, updated) + def resolve(self, name: Optional[str], url: Optional[str]) -> ServerData: """ This function will resolve the given inputs into a set of server information. @@ -429,6 +479,10 @@ def resolve(self, name: Optional[str], url: Optional[str]) -> ServerData: account_name=entry.get("account_name"), token=entry.get("token"), secret=entry.get("secret"), + oauth_client_id=entry.get("oauth_client_id"), + oauth_access_token=entry.get("oauth_access_token"), + oauth_refresh_token=entry.get("oauth_refresh_token"), + oauth_token_expiry=entry.get("oauth_token_expiry"), ) else: return ServerData( diff --git a/rsconnect/oauth.py b/rsconnect/oauth.py new file mode 100644 index 00000000..255ce996 --- /dev/null +++ b/rsconnect/oauth.py @@ -0,0 +1,516 @@ +"""OAuth 2.1 authentication support for Posit Connect. + +Implements RFC 8414 (discovery), RFC 7591 (DCR), Authorization Code + PKCE, +Device Code flow, token refresh, and keyring integration. +""" + +from __future__ import annotations + +import base64 +import hashlib +import queue +import secrets +import threading +import time +import webbrowser +from http.server import BaseHTTPRequestHandler, HTTPServer as _HTTPServer +from typing import Any, Dict, Optional, Tuple, cast +from urllib.parse import parse_qs, urlencode, urlparse + +import click + +from .exception import RSConnectException +from .http_support import HTTPResponse, HTTPServer +from .log import logger + +# pyright: reportMissingTypeStubs=false + +_KEYRING_SERVICE = "rsconnect-python" +_CLIENT_NAME = "rsconnect-python" +_CALLBACK_TIMEOUT_SECONDS = 600 + + +class InvalidClientError(RSConnectException): + """Raised when the OAuth server returns an invalid_client error.""" + + def __init__(self) -> None: + super().__init__("OAuth client_id is invalid or has been deleted on the server.") + + +def _check_oauth_error_response(response: HTTPResponse) -> None: + """Check an HTTPResponse for OAuth error codes and raise appropriately.""" + if response.json_data and isinstance(response.json_data, dict): + error = response.json_data.get("error", "") + if error == "invalid_client": + raise InvalidClientError() + description = response.json_data.get("error_description", error) + if description: + raise RSConnectException(f"OAuth error: {description}") + + +def _unwrap_json_response(response: Any) -> dict[str, Any]: + """Extract JSON dict from an HTTPResponse (raw HTTPServer doesn't auto-unwrap). + + Returns the dict if successful, raises RSConnectException on error responses. + """ + if isinstance(response, HTTPResponse): + if response.status and 200 <= response.status < 300 and isinstance(response.json_data, dict): + return cast(Dict[str, Any], response.json_data) + _check_oauth_error_response(response) + raise RSConnectException(f"OAuth request failed: HTTP {response.status}.") + if isinstance(response, dict): + return cast(Dict[str, Any], response) + raise RSConnectException("Unexpected OAuth response format.") + + +def discover_oauth_metadata( + url: str, + insecure: bool = False, + ca_data: Optional[str | bytes] = None, +) -> dict[str, Any]: + """Fetch OAuth 2.0 Authorization Server Metadata (RFC 8414). + + Returns the parsed JSON metadata dict, or raises RSConnectException if + the server does not support OAuth. + """ + server = HTTPServer(url, disable_tls_check=insecure, ca_data=ca_data) + with server: + response = server.get("/.well-known/oauth-authorization-server") + + if isinstance(response, HTTPResponse): + if response.status != 200: + raise RSConnectException( + f"Server at {url} does not support OAuth 2.1 " + f"(discovery endpoint returned HTTP {response.status}). " + f"The server may need to be upgraded, or OAuth may be intentionally disabled by an administrator." + ) + if isinstance(response.json_data, dict) and "token_endpoint" in response.json_data: + return response.json_data + raise RSConnectException(f"Server at {url} returned a non-JSON response from the OAuth discovery endpoint.") + + if not isinstance(response, dict) or "token_endpoint" not in response: + raise RSConnectException(f"Server at {url} returned invalid OAuth metadata (missing token_endpoint).") + + return response + + +def register_client( + metadata: dict[str, Any], + url: str, + insecure: bool = False, + ca_data: Optional[str | bytes] = None, +) -> str: + """Register an OAuth client via Dynamic Client Registration (RFC 7591). + + Returns the client_id. + """ + registration_endpoint = str(metadata.get("registration_endpoint", "")) + if not registration_endpoint: + raise RSConnectException("OAuth metadata does not include a registration_endpoint.") + + parsed = urlparse(registration_endpoint) + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path + + grant_types = ["authorization_code", "refresh_token"] + if metadata.get("device_authorization_endpoint"): + grant_types.append("urn:ietf:params:oauth:grant-type:device_code") + + server = HTTPServer(base, disable_tls_check=insecure, ca_data=ca_data) + with server: + response = server.post( + path, + body={ + "client_name": _CLIENT_NAME, + "redirect_uris": ["http://127.0.0.1/callback"], + "token_endpoint_auth_method": "none", + "grant_types": grant_types, + "response_types": ["code"], + }, + ) + + data = _unwrap_json_response(response) + if "client_id" not in data: + raise RSConnectException("OAuth client registration returned an unexpected response (no client_id).") + + return str(data["client_id"]) + + +def generate_pkce_pair() -> Tuple[str, str]: + """Generate a PKCE code_verifier and code_challenge (S256).""" + code_verifier = secrets.token_urlsafe(96)[:96] + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + return code_verifier, code_challenge + + +def _exchange_code_for_token( + metadata: dict[str, Any], + client_id: str, + code: str, + code_verifier: str, + redirect_uri: str, + insecure: bool = False, + ca_data: Optional[str | bytes] = None, +) -> dict[str, Any]: + """Exchange an authorization code for tokens.""" + token_endpoint = str(metadata["token_endpoint"]) + parsed = urlparse(token_endpoint) + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path + + body = urlencode( + { + "grant_type": "authorization_code", + "client_id": client_id, + "code": code, + "redirect_uri": redirect_uri, + "code_verifier": code_verifier, + } + ).encode("utf-8") + + server = HTTPServer(base, disable_tls_check=insecure, ca_data=ca_data) + with server: + response = server.request( + "POST", + path, + body=body, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + data = _unwrap_json_response(response) + if "access_token" not in data: + raise RSConnectException("Token exchange returned an unexpected response.") + + return data + + +class _CallbackHandler(BaseHTTPRequestHandler): + """HTTP request handler for the OAuth redirect callback.""" + + result_queue: queue.Queue[Tuple[str, Optional[str], Optional[str]]] + + def do_GET(self) -> None: # noqa: N802 + qs = parse_qs(urlparse(self.path).query) + code = qs.get("code", [None])[0] + state = qs.get("state", [None])[0] + error = qs.get("error", [None])[0] + + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + + if error: + self.wfile.write(b"

Authentication failed.

You may close this tab.

") + self.result_queue.put(("error", error, qs.get("error_description", [""])[0])) + elif code: + self.wfile.write( + b"

Authentication successful!

You may close this tab.

" + ) + self.result_queue.put(("success", code, state)) + else: + self.wfile.write(b"

Unexpected response.

") + self.result_queue.put(("error", "no_code", "No authorization code in callback")) + + def log_message(self, format: str, *args: object) -> None: + logger.debug(f"OAuth callback server: {format % args}") + + +def login_with_browser( + url: str, + client_id: str, + metadata: dict[str, Any], + insecure: bool = False, + ca_data: Optional[str | bytes] = None, +) -> dict[str, Any]: + """Perform OAuth Authorization Code + PKCE flow via browser. + + Opens the user's browser to the authorization URL and starts a local + HTTP server to receive the callback. Returns the token response dict. + """ + code_verifier, code_challenge = generate_pkce_pair() + state = secrets.token_urlsafe(32) + + result_queue: queue.Queue[Tuple[str, Optional[str], Optional[str]]] = queue.Queue() + + callback_server = _HTTPServer(("127.0.0.1", 0), _CallbackHandler) + port = callback_server.server_address[1] + redirect_uri = f"http://127.0.0.1:{port}/callback" + + # Attach queue to the handler class for this server instance + callback_server.RequestHandlerClass.result_queue = result_queue # type: ignore[attr-defined] + + auth_endpoint = str(metadata["authorization_endpoint"]) + auth_params = urlencode( + { + "response_type": "code", + "client_id": client_id, + "redirect_uri": redirect_uri, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": state, + } + ) + auth_url = f"{auth_endpoint}?{auth_params}" + + server_thread = threading.Thread(target=callback_server.handle_request, daemon=True) + server_thread.start() + + if not webbrowser.open(auth_url): + click.echo( + f"Could not open browser automatically. This can happen if no display is available\n" + f"or localhost is blocked by network rules. Please open this URL manually:\n\n" + f" {auth_url}\n\n" + f"Waiting for authentication callback..." + ) + else: + click.echo("Opened browser for authentication. Waiting for callback...") + + server_thread.join(timeout=_CALLBACK_TIMEOUT_SECONDS) + callback_server.server_close() + + if result_queue.empty(): + raise RSConnectException(f"OAuth browser callback timed out after {_CALLBACK_TIMEOUT_SECONDS} seconds.") + + result = result_queue.get_nowait() + if result[0] == "error": + raise RSConnectException(f"OAuth authentication failed: {result[1]} — {result[2]}") + + _, code, returned_state = result + if returned_state != state: + raise RSConnectException("OAuth state mismatch — possible CSRF attack.") + if not code: + raise RSConnectException("OAuth callback did not contain an authorization code.") + + return _exchange_code_for_token(metadata, client_id, code, code_verifier, redirect_uri, insecure, ca_data) + + +def login_with_device_code( + url: str, + client_id: str, + metadata: dict[str, Any], + insecure: bool = False, + ca_data: Optional[str | bytes] = None, +) -> dict[str, Any]: + """Perform OAuth Device Code flow. + + Displays a URL and user code for the user to enter in a browser, + then polls for token completion. + """ + device_endpoint = str(metadata.get("device_authorization_endpoint", "")) + if not device_endpoint: + raise RSConnectException( + "Server does not support the device code flow. " + "The server may need to be upgraded, or the device code flow may be " + "intentionally disabled by an administrator. Try again without --use-device-code." + ) + + parsed = urlparse(device_endpoint) + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path + + body = urlencode({"client_id": client_id}).encode("utf-8") + + server = HTTPServer(base, disable_tls_check=insecure, ca_data=ca_data) + with server: + response = server.request( + "POST", + path, + body=body, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + resp = _unwrap_json_response(response) + device_code = str(resp.get("device_code", "")) + user_code = str(resp.get("user_code", "")) + verification_uri = str(resp.get("verification_uri", "")) + interval = int(resp.get("interval", 5)) + expires_in = int(resp.get("expires_in", 600)) + + verification_uri_complete = str(resp.get("verification_uri_complete", "")) or verification_uri + + click.echo(f"\nOpen this URL in your browser:\n\n {verification_uri_complete}\n") + click.echo(f"Enter the code: {user_code}\n") + click.echo("Waiting for authorization...") + + return _poll_for_device_token(metadata, client_id, device_code, interval, expires_in, insecure, ca_data) + + +def _poll_for_device_token( + metadata: dict[str, Any], + client_id: str, + device_code: str, + interval: int, + expires_in: int, + insecure: bool = False, + ca_data: Optional[str | bytes] = None, +) -> dict[str, Any]: + """Poll the token endpoint for device code completion.""" + token_endpoint = str(metadata["token_endpoint"]) + parsed = urlparse(token_endpoint) + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path + + deadline = time.time() + expires_in + poll_interval = interval + + while time.time() < deadline: + time.sleep(poll_interval) + + body = urlencode( + { + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + "client_id": client_id, + "device_code": device_code, + } + ).encode("utf-8") + + server = HTTPServer(base, disable_tls_check=insecure, ca_data=ca_data) + with server: + response = server.request( + "POST", + path, + body=body, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + # Extract JSON from the response (raw HTTPServer always returns HTTPResponse) + json_data: Optional[dict[str, Any]] = None + if isinstance(response, HTTPResponse): + if isinstance(response.json_data, dict): + json_data = response.json_data + else: + raise RSConnectException(f"Device code token request failed: HTTP {response.status}.") + elif isinstance(response, dict): + json_data = response + + if json_data is None: + raise RSConnectException("Device code token request returned an unexpected response.") + + if "access_token" in json_data: + return json_data + + error = str(json_data.get("error", "")) + if error == "authorization_pending": + continue + elif error == "slow_down": + poll_interval += 5 + continue + elif error == "invalid_client": + raise InvalidClientError() + elif error == "expired_token": + raise RSConnectException("Device code expired. Please try again.") + elif error == "access_denied": + raise RSConnectException("Authorization was denied by the user.") + elif error: + description = str(json_data.get("error_description", error)) + raise RSConnectException(f"Device code flow failed: {description}") + else: + raise RSConnectException("Device code token request returned an unexpected response.") + + raise RSConnectException("Device code authorization timed out. Please try again.") + + +def refresh_access_token( + metadata: dict[str, Any], + client_id: str, + refresh_token: str, + insecure: bool = False, + ca_data: Optional[str | bytes] = None, +) -> dict[str, Any]: + """Refresh an OAuth access token using a refresh token. + + Returns the new token response dict. Raises InvalidClientError if the + client_id has been deleted server-side. + """ + token_endpoint = str(metadata["token_endpoint"]) + parsed = urlparse(token_endpoint) + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path + + body = urlencode( + { + "grant_type": "refresh_token", + "client_id": client_id, + "refresh_token": refresh_token, + } + ).encode("utf-8") + + server = HTTPServer(base, disable_tls_check=insecure, ca_data=ca_data) + with server: + response = server.request( + "POST", + path, + body=body, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + data = _unwrap_json_response(response) + if "access_token" not in data: + raise RSConnectException("Token refresh returned an unexpected response.") + + return data + + +# --------------------------------------------------------------------------- +# Keyring integration +# --------------------------------------------------------------------------- + + +def keyring_store_token(server_url: str, access_token: str, refresh_token: Optional[str]) -> bool: + """Store OAuth tokens in the system keyring. + + Returns True on success, False if keyring is not available. + """ + try: + import keyring # type: ignore[import-untyped] + + keyring.set_password(_KEYRING_SERVICE, f"{server_url}:access_token", access_token) + if refresh_token: + keyring.set_password(_KEYRING_SERVICE, f"{server_url}:refresh_token", refresh_token) + else: + try: + keyring.delete_password(_KEYRING_SERVICE, f"{server_url}:refresh_token") + except keyring.errors.PasswordDeleteError: + pass + return True + except ImportError: + return False + except Exception as e: + logger.warning(f"keyring storage failed: {e}") + return False + + +def keyring_get_tokens(server_url: str) -> Tuple[Optional[str], Optional[str]]: + """Retrieve OAuth tokens from the system keyring. + + Returns (access_token, refresh_token), or (None, None) if unavailable. + """ + try: + import keyring # type: ignore[import-untyped] + + access = keyring.get_password(_KEYRING_SERVICE, f"{server_url}:access_token") + refresh = keyring.get_password(_KEYRING_SERVICE, f"{server_url}:refresh_token") + return access, refresh + except ImportError: + return None, None + except Exception as e: + logger.warning(f"keyring retrieval failed: {e}") + return None, None + + +def keyring_delete_tokens(server_url: str) -> None: + """Delete OAuth tokens from the system keyring.""" + try: + import keyring # type: ignore[import-untyped] + import keyring.errors # type: ignore[import-untyped] + + for suffix in (":access_token", ":refresh_token"): + try: + keyring.delete_password(_KEYRING_SERVICE, f"{server_url}{suffix}") + except keyring.errors.PasswordDeleteError: + pass + except ImportError: + pass + except Exception as e: + logger.warning(f"keyring deletion failed: {e}") diff --git a/tests/test_oauth.py b/tests/test_oauth.py new file mode 100644 index 00000000..bb2b75a4 --- /dev/null +++ b/tests/test_oauth.py @@ -0,0 +1,519 @@ +from __future__ import annotations + +import base64 +import hashlib +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from rsconnect.exception import RSConnectException +from rsconnect.http_support import HTTPResponse +from rsconnect.metadata import ServerData +from rsconnect.oauth import ( + InvalidClientError, + _exchange_code_for_token, + _poll_for_device_token, + discover_oauth_metadata, + generate_pkce_pair, + keyring_delete_tokens, + keyring_get_tokens, + keyring_store_token, + login_with_browser, + login_with_device_code, + refresh_access_token, + register_client, +) + + +FAKE_URL = "https://connect.example.com" +FAKE_METADATA: dict[str, Any] = { + "issuer": FAKE_URL, + "authorization_endpoint": f"{FAKE_URL}/oauth/v1/authorize", + "token_endpoint": f"{FAKE_URL}/oauth/v1/token", + "registration_endpoint": f"{FAKE_URL}/oauth/v1/register", + "device_authorization_endpoint": f"{FAKE_URL}/oauth/v1/device", +} + + +def _make_response(status: int = 200, json_data: Any = None) -> HTTPResponse: + response = HTTPResponse("", body=b"") + response.status = status + response.json_data = json_data + return response + + +@pytest.fixture +def mock_http_server(): + with patch("rsconnect.oauth.HTTPServer") as mock_cls: + mock_server = MagicMock() + mock_cls.return_value = mock_server + mock_server.__enter__ = MagicMock(return_value=mock_server) + mock_server.__exit__ = MagicMock(return_value=False) + yield mock_server + + +class TestPKCE: + def test_generates_valid_pair(self): + verifier, challenge = generate_pkce_pair() + assert 43 <= len(verifier) <= 128 + digest = hashlib.sha256(verifier.encode("ascii")).digest() + expected_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + assert challenge == expected_challenge + + +class TestDiscoverOAuthMetadata: + def test_success(self, mock_http_server: MagicMock): + mock_http_server.get.return_value = _make_response(200, FAKE_METADATA) + result = discover_oauth_metadata(FAKE_URL) + assert result == FAKE_METADATA + + def test_server_not_supporting_oauth(self, mock_http_server: MagicMock): + mock_http_server.get.return_value = _make_response(404, None) + with pytest.raises(RSConnectException, match="does not support OAuth"): + discover_oauth_metadata(FAKE_URL) + + def test_missing_token_endpoint(self, mock_http_server: MagicMock): + mock_http_server.get.return_value = {"issuer": FAKE_URL} + with pytest.raises(RSConnectException, match="invalid OAuth metadata"): + discover_oauth_metadata(FAKE_URL) + + +class TestRegisterClient: + def test_success(self, mock_http_server: MagicMock): + mock_http_server.post.return_value = _make_response(200, {"client_id": "test-client-123"}) + result = register_client(FAKE_METADATA, FAKE_URL) + assert result == "test-client-123" + + def test_failure(self, mock_http_server: MagicMock): + mock_http_server.post.return_value = _make_response( + 400, {"error": "invalid_request", "error_description": "bad request"} + ) + with pytest.raises(RSConnectException, match="OAuth error"): + register_client(FAKE_METADATA, FAKE_URL) + + def test_missing_registration_endpoint(self): + metadata = {k: v for k, v in FAKE_METADATA.items() if k != "registration_endpoint"} + with pytest.raises(RSConnectException, match="registration_endpoint"): + register_client(metadata, FAKE_URL) + + +class TestTokenExchange: + def test_success(self, mock_http_server: MagicMock): + mock_http_server.request.return_value = _make_response( + 200, {"access_token": "at-123", "refresh_token": "rt-456", "expires_in": 3600} + ) + result = _exchange_code_for_token( + FAKE_METADATA, "client-1", "auth-code", "verifier", "http://127.0.0.1:8080/callback" + ) + assert result["access_token"] == "at-123" + + def test_invalid_client(self, mock_http_server: MagicMock): + mock_http_server.request.return_value = _make_response(401, {"error": "invalid_client"}) + with pytest.raises(InvalidClientError): + _exchange_code_for_token( + FAKE_METADATA, "bad-client", "auth-code", "verifier", "http://127.0.0.1:8080/callback" + ) + + +class TestDeviceCodeFlow: + def test_device_code_not_supported(self): + metadata = {k: v for k, v in FAKE_METADATA.items() if k != "device_authorization_endpoint"} + with pytest.raises(RSConnectException, match="does not support the device code flow"): + login_with_device_code(FAKE_URL, "client-1", metadata) + + @patch("rsconnect.oauth.time.sleep") + def test_poll_success(self, _, mock_http_server: MagicMock): + mock_http_server.request.side_effect = [ + _make_response(400, {"error": "authorization_pending"}), + _make_response(200, {"access_token": "at-final", "refresh_token": "rt-final"}), + ] + result = _poll_for_device_token(FAKE_METADATA, "client-1", "device-code-1", 5, 600) + assert result["access_token"] == "at-final" + + @patch("rsconnect.oauth.time.sleep") + def test_poll_expired(self, _, mock_http_server: MagicMock): + mock_http_server.request.return_value = _make_response(400, {"error": "expired_token"}) + with pytest.raises(RSConnectException, match="expired"): + _poll_for_device_token(FAKE_METADATA, "client-1", "device-code-1", 5, 600) + + @patch("rsconnect.oauth.time.sleep") + def test_poll_invalid_client(self, _, mock_http_server: MagicMock): + mock_http_server.request.return_value = _make_response(401, {"error": "invalid_client"}) + with pytest.raises(InvalidClientError): + _poll_for_device_token(FAKE_METADATA, "bad-client", "device-code-1", 5, 600) + + @patch("rsconnect.oauth.time.sleep") + def test_poll_malformed_response(self, _, mock_http_server: MagicMock): + mock_http_server.request.return_value = _make_response(200, {}) + with pytest.raises(RSConnectException, match="unexpected response"): + _poll_for_device_token(FAKE_METADATA, "client-1", "device-code-1", 5, 600) + + +class TestRefreshAccessToken: + def test_success(self, mock_http_server: MagicMock): + mock_http_server.request.return_value = _make_response( + 200, {"access_token": "new-at", "refresh_token": "new-rt", "expires_in": 7200} + ) + result = refresh_access_token(FAKE_METADATA, "client-1", "old-rt") + assert result["access_token"] == "new-at" + + def test_invalid_client(self, mock_http_server: MagicMock): + mock_http_server.request.return_value = _make_response(401, {"error": "invalid_client"}) + with pytest.raises(InvalidClientError): + refresh_access_token(FAKE_METADATA, "bad-client", "old-rt") + + +class TestKeyringIntegration: + def test_store_success(self): + mock_keyring = MagicMock() + with patch.dict("sys.modules", {"keyring": mock_keyring}): + result = keyring_store_token("https://example.com", "at-1", "rt-1") + assert result is True + assert mock_keyring.set_password.call_count == 2 + + def test_store_no_keyring(self): + with patch.dict("sys.modules", {"keyring": None}): + result = keyring_store_token("https://example.com", "at-1", "rt-1") + assert result is False + + def test_get_success(self): + mock_keyring = MagicMock() + mock_keyring.get_password.side_effect = lambda svc, key: "token-value" + with patch.dict("sys.modules", {"keyring": mock_keyring}): + access, refresh = keyring_get_tokens("https://example.com") + assert access == "token-value" + assert refresh == "token-value" + + def test_get_no_keyring(self): + with patch.dict("sys.modules", {"keyring": None}): + access, refresh = keyring_get_tokens("https://example.com") + assert access is None + assert refresh is None + + def test_delete_success(self): + mock_keyring = MagicMock() + mock_keyring_errors = MagicMock() + mock_keyring_errors.PasswordDeleteError = Exception + with patch.dict("sys.modules", {"keyring": mock_keyring, "keyring.errors": mock_keyring_errors}): + keyring_delete_tokens("https://example.com") + assert mock_keyring.delete_password.call_count == 2 + + +class TestLoginWithBrowser: + @patch("rsconnect.oauth.webbrowser.open", return_value=True) + @patch("rsconnect.oauth._exchange_code_for_token") + @patch("rsconnect.oauth._HTTPServer") + @patch("rsconnect.oauth.secrets.token_urlsafe", return_value="fixed-state") + def test_success( + self, + mock_token_urlsafe: MagicMock, + mock_httpserver_cls: MagicMock, + mock_exchange: MagicMock, + mock_browser: MagicMock, + ): + mock_exchange.return_value = {"access_token": "at-browser", "refresh_token": "rt-browser"} + + mock_server_instance = MagicMock() + mock_server_instance.server_address = ("127.0.0.1", 9999) + mock_httpserver_cls.return_value = mock_server_instance + + # When handle_request is called in the thread, simulate putting + # the auth code onto the result_queue that the function created + def fake_handle_request(): + # The function sets result_queue on RequestHandlerClass before starting the thread + rq = mock_server_instance.RequestHandlerClass.result_queue + rq.put(("success", "auth-code-123", "fixed-state")) + + mock_server_instance.handle_request.side_effect = fake_handle_request + + result = login_with_browser(FAKE_URL, "client-1", FAKE_METADATA) + + assert result == {"access_token": "at-browser", "refresh_token": "rt-browser"} + mock_browser.assert_called_once() + + +class TestExecutorOAuthSetup: + @pytest.mark.parametrize( + "keyring_token,expected_token", + [ + ("keyring-access-token", "keyring-access-token"), + (None, "stored-fallback-token"), + ], + ids=["keyring-available", "keyring-fallback"], + ) + def test_setup_remote_server_with_oauth_entry(self, keyring_token, expected_token): + from rsconnect.api import RSConnectExecutor, RSConnectServer + + with patch("rsconnect.oauth.keyring_get_tokens", return_value=(keyring_token, None)): + with patch("rsconnect.metadata.ServerStore.resolve") as mock_resolve: + mock_resolve.return_value = ServerData( + name="myserver", + url=FAKE_URL, + from_store=True, + oauth_client_id="client-123", + oauth_access_token="stored-fallback-token", + ) + + executor = RSConnectExecutor.__new__(RSConnectExecutor) + executor.logger = None + executor.ctx = None + executor.setup_remote_server(ctx=None, name="myserver") + + assert isinstance(executor.remote_server, RSConnectServer) + assert executor.remote_server.oauth_access_token == expected_token + assert executor.remote_server.oauth_client_id == "client-123" + + +class TestRefreshTokenFallback: + @pytest.mark.parametrize( + "server_name,get_by_name_rv,get_by_url_rv", + [ + ("testserver", {"oauth_refresh_token": "stored-rt"}, None), + (FAKE_URL, None, {"oauth_refresh_token": "url-rt", "name": "myserver"}), + (None, None, {"oauth_refresh_token": "url-rt", "name": "resolved-name"}), + ], + ids=["name-lookup", "url-fallback", "no-server-name"], + ) + @patch("rsconnect.oauth.HTTPServer") + @patch("rsconnect.oauth.keyring_get_tokens", return_value=(None, None)) + @patch("rsconnect.oauth.keyring_store_token", return_value=False) + def test_refresh_falls_back_to_store( + self, + mock_keyring_store: MagicMock, + mock_keyring_get: MagicMock, + mock_http_server_cls: MagicMock, + server_name: "str | None", + get_by_name_rv: Any, + get_by_url_rv: Any, + ): + from rsconnect.api import RSConnectClient, RSConnectServer + from rsconnect.metadata import ServerStore + + mock_http_server = MagicMock() + mock_http_server_cls.return_value = mock_http_server + mock_http_server.__enter__ = MagicMock(return_value=mock_http_server) + mock_http_server.__exit__ = MagicMock(return_value=False) + + mock_http_server.get.return_value = _make_response(200, FAKE_METADATA) + mock_http_server.request.return_value = _make_response( + 200, {"access_token": "new-at", "refresh_token": "new-rt", "expires_in": 3600} + ) + + server = RSConnectServer( + FAKE_URL, + None, + False, + None, + oauth_access_token="old-at", + oauth_client_id="client-1", + server_name=server_name, + ) + + with patch.object(ServerStore, "get_by_name", return_value=get_by_name_rv): + with patch.object(ServerStore, "get_by_url", return_value=get_by_url_rv): + with patch.object(ServerStore, "update_oauth_tokens"): + client = RSConnectClient(server) + result = client._attempt_token_refresh() + + assert result is True + + +class TestStreamBodyRetry: + def test_seekable_stream_rewinds_for_retry(self): + import io + + from rsconnect.api import RSConnectClient, RSConnectServer + from rsconnect.http_support import HTTPServer as _HTTPServer + + server = RSConnectServer( + FAKE_URL, + None, + False, + None, + oauth_access_token="old-at", + oauth_client_id="client-1", + server_name="testserver", + ) + + client = RSConnectClient(server) + stream_body = io.BytesIO(b"bundle-payload-data") + call_bodies: list[object] = [] + + def fake_super_request( + self_arg, method, path, query_params, body, maximum_redirects=5, decode_response=True, headers=None + ): + call_bodies.append(body.read() if hasattr(body, "read") else body) + return _make_response(401, None) if len(call_bodies) == 1 else _make_response(200, {"result": "ok"}) + + with patch.object(_HTTPServer, "request", fake_super_request): + with patch.object(client, "_attempt_token_refresh", return_value=True): + client.request("POST", "/v1/content/upload", body=stream_body) + + assert len(call_bodies) == 2 + assert call_bodies[0] == b"bundle-payload-data" + assert call_bodies[1] == b"bundle-payload-data" + + def test_non_seekable_stream_buffered_for_retry(self): + import io + + from rsconnect.api import RSConnectClient, RSConnectServer + from rsconnect.http_support import HTTPServer as _HTTPServer + + class NonSeekableStream(io.RawIOBase): + def __init__(self, data: bytes): + self._data = data + self._pos = 0 + + def read(self, size=-1): + if size == -1: + result = self._data[self._pos :] + else: + result = self._data[self._pos : self._pos + size] + self._pos += len(result) + return result + + def readable(self): + return True + + def seekable(self): + return False + + server = RSConnectServer( + FAKE_URL, + None, + False, + None, + oauth_access_token="old-at", + oauth_client_id="client-1", + server_name="testserver", + ) + + client = RSConnectClient(server) + stream_body = NonSeekableStream(b"bundle-payload-data") + call_bodies: list[object] = [] + + def fake_super_request( + self_arg, method, path, query_params, body, maximum_redirects=5, decode_response=True, headers=None + ): + call_bodies.append(body) + return _make_response(401, None) if len(call_bodies) == 1 else _make_response(200, {"result": "ok"}) + + with patch.object(_HTTPServer, "request", fake_super_request): + with patch.object(client, "_attempt_token_refresh", return_value=True): + client.request("POST", "/v1/content/upload", body=stream_body) + + assert len(call_bodies) == 2 + assert call_bodies[0] == b"bundle-payload-data" + assert call_bodies[1] == b"bundle-payload-data" + + +class TestLoginCommand: + @patch("rsconnect.oauth.keyring_store_token", return_value=True) + @patch("rsconnect.oauth.login_with_browser") + @patch("rsconnect.oauth.register_client", return_value="new-client-id") + @patch("rsconnect.oauth.discover_oauth_metadata") + def test_login_success( + self, + mock_discover: MagicMock, + mock_register: MagicMock, + mock_login: MagicMock, + mock_keyring: MagicMock, + ): + from click.testing import CliRunner + + from rsconnect.main import cli + + mock_discover.return_value = FAKE_METADATA + mock_login.return_value = {"access_token": "at-1", "refresh_token": "rt-1", "expires_in": 3600} + + runner = CliRunner() + result = runner.invoke(cli, ["login", "--server", FAKE_URL, "--name", "test-server"]) + + assert result.exit_code == 0, result.output + assert "Logged in" in result.output + + def test_login_missing_server(self): + from click.testing import CliRunner + + from rsconnect.main import cli + + runner = CliRunner() + result = runner.invoke(cli, ["login"]) + + assert "Usage:" in result.output + + @patch("rsconnect.oauth.keyring_delete_tokens") + @patch("rsconnect.main.server_store") + def test_logout_non_oauth_entry(self, mock_store: MagicMock, mock_keyring_del: MagicMock): + from click.testing import CliRunner + + from rsconnect.main import cli + + mock_store.get_by_name.return_value = {"name": "myserver", "url": FAKE_URL, "api_key": "key-123"} + + runner = CliRunner() + result = runner.invoke(cli, ["logout", "--name", "myserver"]) + + assert result.exit_code != 0 + assert "not an OAuth" in result.output or "rsconnect remove" in result.output + + @patch("rsconnect.oauth.keyring_delete_tokens") + @patch("rsconnect.main.server_store") + def test_logout_success(self, mock_store: MagicMock, mock_keyring_del: MagicMock): + from click.testing import CliRunner + + from rsconnect.main import cli + + mock_store.get_by_name.return_value = { + "name": "myserver", + "url": FAKE_URL, + "oauth_client_id": "client-123", + } + mock_store.update_oauth_tokens = MagicMock() + + runner = CliRunner() + result = runner.invoke(cli, ["logout", "--name", "myserver"]) + + assert result.exit_code == 0, result.output + mock_keyring_del.assert_called_once() + + +class TestListCommand: + @patch("rsconnect.oauth.keyring_get_tokens", return_value=("at-from-keyring", None)) + @patch("rsconnect.main.server_store") + def test_list_oauth_entry_with_keyring(self, mock_store: MagicMock, mock_keyring: MagicMock): + from click.testing import CliRunner + + from rsconnect.main import cli + + mock_store.get_all_servers.return_value = [ + {"name": "myserver", "url": FAKE_URL, "oauth_client_id": "client-abc"}, + ] + mock_store.get_path.return_value = "/tmp/servers.json" + + runner = CliRunner() + result = runner.invoke(cli, ["list"]) + + assert result.exit_code == 0, result.output + assert "OAuth Client ID: client-abc" in result.output + assert "Credentials stored in system keyring" in result.output + + @patch("rsconnect.oauth.keyring_get_tokens", return_value=(None, None)) + @patch("rsconnect.main.server_store") + def test_list_oauth_entry_without_keyring(self, mock_store: MagicMock, mock_keyring: MagicMock): + from click.testing import CliRunner + + from rsconnect.main import cli + + mock_store.get_all_servers.return_value = [ + {"name": "myserver", "url": FAKE_URL, "oauth_client_id": "client-abc"}, + ] + mock_store.get_path.return_value = "/tmp/servers.json" + + runner = CliRunner() + result = runner.invoke(cli, ["list"]) + + assert result.exit_code == 0, result.output + assert "OAuth Client ID: client-abc" in result.output + assert "Credentials stored in system keyring" not in result.output