diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md
index 9674c534..dab05a4e 100644
--- a/docs/CHANGELOG.md
+++ b/docs/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `pyproject.toml` can now be supplied via `--requirements-file` for deploy and
write-manifest.
- Perform case insensitive matching of the configured Snowflake connection authenticator.
+- New `login` and `logout` subcommands for authenticating to Connect via OAuth.
## [1.29.0] - 2026-04-29
diff --git a/docs/commands/login.md b/docs/commands/login.md
new file mode 100644
index 00000000..6fe26f29
--- /dev/null
+++ b/docs/commands/login.md
@@ -0,0 +1,3 @@
+::: mkdocs-click
+ :module: rsconnect.main
+ :command: login
diff --git a/docs/commands/logout.md b/docs/commands/logout.md
new file mode 100644
index 00000000..df0e72a6
--- /dev/null
+++ b/docs/commands/logout.md
@@ -0,0 +1,3 @@
+::: mkdocs-click
+ :module: rsconnect.main
+ :command: logout
diff --git a/mkdocs.yml b/mkdocs.yml
index 541da548..cc4b1692 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -46,6 +46,8 @@ nav:
- details: commands/details.md
- info: commands/info.md
- list: commands/list.md
+ - login: commands/login.md
+ - logout: commands/logout.md
- remove: commands/remove.md
- system: commands/system.md
- version: commands/version.md
diff --git a/pyproject.toml b/pyproject.toml
index d4a5dd98..e9a808d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ test = [
"types-Flask",
"fastmcp==2.12.4; python_version >= '3.10'",
]
+keyring = ["keyring>=23.0.0"]
snowflake = ["snowflake-cli"]
mcp = ["fastmcp==2.12.4; python_version >= '3.10'"]
docs = [
diff --git a/rsconnect/api.py b/rsconnect/api.py
index bf643c8f..1312b340 100644
--- a/rsconnect/api.py
+++ b/rsconnect/api.py
@@ -222,12 +222,18 @@ def __init__(
insecure: bool = False,
ca_data: Optional[str | bytes] = None,
bootstrap_jwt: Optional[str] = None,
+ oauth_access_token: Optional[str] = None,
+ oauth_client_id: Optional[str] = None,
+ server_name: Optional[str] = None,
):
super().__init__(url, "Posit Connect")
self.api_key = api_key
self.bootstrap_jwt = bootstrap_jwt
self.insecure = insecure
self.ca_data = ca_data
+ self.oauth_access_token = oauth_access_token
+ self.oauth_client_id = oauth_client_id
+ self.server_name = server_name
# This is specifically not None.
self.cookie_jar = CookieJar()
# for compatibility with RSconnectClient
@@ -422,6 +428,126 @@ def __init__(self, server: Union[RSConnectServer, SPCSConnectServer], cookies: O
if server.api_key:
self._headers["X-RSC-Authorization"] = server.api_key
+ if (
+ isinstance(server, RSConnectServer)
+ and server.oauth_access_token
+ and not server.api_key
+ and not server.bootstrap_jwt
+ ):
+ self.authorization(f"Bearer {server.oauth_access_token}")
+
+ def request(
+ self,
+ method: str,
+ path: str,
+ query_params: Optional[Mapping[str, "JsonData"]] = None,
+ body: "str | bytes | IO[bytes] | Mapping[str, Any] | list[Any] | None" = None,
+ maximum_redirects: int = 5,
+ decode_response: bool = True,
+ headers: Optional[Mapping[str, str]] = None,
+ ) -> "JsonData | HTTPResponse":
+ can_retry = isinstance(self._server, RSConnectServer) and bool(self._server.oauth_client_id)
+ start_pos: "int | None" = None
+ if can_retry and hasattr(body, "read"):
+ if getattr(body, "seekable", lambda: False)():
+ start_pos = body.tell() # type: ignore[union-attr]
+ else:
+ body = body.read() # type: ignore[union-attr]
+ response = super().request(
+ method, path, query_params, body, maximum_redirects, decode_response, headers
+ ) # pyright: ignore[reportUnknownArgumentType]
+ if can_retry and isinstance(response, HTTPResponse) and response.status == 401:
+ if self._attempt_token_refresh():
+ if start_pos is not None:
+ body.seek(start_pos) # type: ignore[union-attr]
+ return super().request(
+ method, path, query_params, body, maximum_redirects, decode_response, headers
+ ) # pyright: ignore[reportUnknownArgumentType]
+ return response
+
+ def _attempt_token_refresh(self) -> bool:
+ from .oauth import (
+ InvalidClientError,
+ discover_oauth_metadata,
+ keyring_delete_tokens,
+ keyring_get_tokens,
+ keyring_store_token,
+ refresh_access_token,
+ register_client,
+ )
+ from .metadata import ServerStore
+
+ server = cast(RSConnectServer, self._server)
+
+ _, refresh_token = keyring_get_tokens(server.url)
+ if not refresh_token:
+ store = ServerStore()
+ entry = None
+ if server.server_name:
+ entry = store.get_by_name(server.server_name)
+ if not entry:
+ entry = store.get_by_url(server.url)
+ if entry:
+ refresh_token = entry.get("oauth_refresh_token") # type: ignore[assignment]
+ if not refresh_token:
+ return False
+
+ try:
+ metadata = discover_oauth_metadata(server.url, server.insecure, server.ca_data)
+ token_response = refresh_access_token(
+ metadata, server.oauth_client_id or "", refresh_token, server.insecure, server.ca_data
+ )
+ except InvalidClientError:
+ # Client was deleted server-side; clear stale tokens and re-register
+ keyring_delete_tokens(server.url)
+ store = ServerStore()
+ entry = None
+ if server.server_name:
+ entry = store.get_by_name(server.server_name)
+ if not entry:
+ entry = store.get_by_url(server.url)
+ if entry:
+ entry_name = str(entry.get("name", server.server_name or server.url))
+ store.update_oauth_tokens(entry_name, None, None, None)
+ try:
+ metadata = discover_oauth_metadata(server.url, server.insecure, server.ca_data)
+ new_client_id = register_client(metadata, server.url, server.insecure, server.ca_data)
+ server.oauth_client_id = new_client_id
+ if entry:
+ entry["oauth_client_id"] = new_client_id # type: ignore[typeddict-unknown-key]
+ store._set(entry_name, entry) # type: ignore[possibly-undefined]
+ logger.warning("OAuth client was re-registered; please run `rsconnect login` again.")
+ except Exception as exc:
+ logger.warning(f"OAuth client re-registration failed: {exc}. Please run `rsconnect login` again.")
+ return False
+ except Exception as exc:
+ logger.warning(f"OAuth token refresh failed: {exc}")
+ return False
+
+ new_access = token_response["access_token"]
+ new_refresh = token_response.get("refresh_token", refresh_token)
+ expires_in = token_response.get("expires_in")
+ import time
+
+ new_expiry = time.time() + expires_in if expires_in else None
+
+ self.authorization(f"Bearer {new_access}")
+ server.oauth_access_token = new_access
+
+ stored = keyring_store_token(server.url, new_access, new_refresh)
+ if not stored:
+ store = ServerStore()
+ entry = None
+ if server.server_name:
+ entry = store.get_by_name(server.server_name)
+ if not entry:
+ entry = store.get_by_url(server.url)
+ if entry:
+ entry_name = str(entry.get("name", server.server_name or server.url))
+ store.update_oauth_tokens(entry_name, new_access, new_refresh, new_expiry)
+
+ return True
+
def _tweak_response(self, response: HTTPResponse) -> JsonData | HTTPResponse:
return (
response.json_data
@@ -974,6 +1100,21 @@ def setup_remote_server(
url = cast(str, url)
account_name = cast(str, account_name)
self.remote_server = ShinyappsServer(url, account_name, token, secret)
+ elif server_data.from_store and server_data.oauth_client_id:
+ url = cast(str, url)
+ from .oauth import keyring_get_tokens
+
+ access_token, _ = keyring_get_tokens(url)
+ oauth_access_token = access_token or server_data.oauth_access_token
+ self.remote_server = RSConnectServer(
+ url,
+ None,
+ insecure,
+ ca_data,
+ oauth_access_token=oauth_access_token,
+ oauth_client_id=server_data.oauth_client_id,
+ server_name=name or server_data.name,
+ )
else:
raise RSConnectException("Unable to infer Connect server type and setup server.")
diff --git a/rsconnect/main.py b/rsconnect/main.py
index acb82939..e4284811 100644
--- a/rsconnect/main.py
+++ b/rsconnect/main.py
@@ -848,6 +848,13 @@ def list_servers(verbose: int):
click.echo(" URL: %s" % server["url"])
if server.get("api_key"):
click.echo(" API key is saved")
+ if server.get("oauth_client_id"):
+ click.echo(" OAuth Client ID: %s" % server["oauth_client_id"])
+ from .oauth import keyring_get_tokens
+
+ access, _ = keyring_get_tokens(server["url"])
+ if access:
+ click.echo(" Credentials stored in system keyring")
if server.get("insecure"):
click.echo(" Insecure mode (TLS host/certificate validation disabled)")
if server.get("ca_cert"):
@@ -958,6 +965,177 @@ def remove(
click.echo(message)
+@cli.command(
+ short_help="Authenticate with a Posit Connect server using OAuth.",
+ help=(
+ "Authenticate with a Posit Connect server using OAuth 2.1. "
+ "This opens a browser for interactive login (or uses --use-device-code for headless environments). "
+ "Tokens are stored in the system keyring when available, with fallback to the local credential store."
+ ),
+ no_args_is_help=True,
+)
+@click.option("--server", "-s", envvar="CONNECT_SERVER", required=True, help="The URL of the Posit Connect server.")
+@click.option("--name", "-n", help="Nickname for the server (defaults to server hostname).")
+@click.option("--insecure", "-i", envvar="CONNECT_INSECURE", is_flag=True, help="Disable TLS certificate verification.")
+@click.option(
+ "--cacert",
+ "-c",
+ envvar="CONNECT_CA_CERTIFICATE",
+ type=click.Path(exists=True, file_okay=True, dir_okay=False),
+ help="Path to a trusted CA certificate file for TLS.",
+)
+@click.option(
+ "--use-device-code",
+ is_flag=True,
+ default=False,
+ help="Use device code flow for headless/non-interactive environments.",
+)
+@click.option("--client-id", default=None, help="OAuth client ID (skips Dynamic Client Registration).")
+@click.option("--verbose", "-v", count=True, help="Enable verbose output. Use -vv for very verbose (debug) output.")
+@cli_exception_handler
+def login(
+ server: str,
+ name: Optional[str],
+ insecure: bool,
+ cacert: Optional[str],
+ use_device_code: bool,
+ client_id: Optional[str],
+ verbose: int,
+):
+ set_verbosity(verbose)
+
+ if not server.startswith("http"):
+ raise RSConnectException("Server URL must begin with http or https.")
+
+ ca_data = read_certificate_file(cacert) if cacert else None
+
+ if not name:
+ from urllib.parse import urlparse as _urlparse
+
+ name = _urlparse(server).hostname or server
+
+ from .oauth import (
+ InvalidClientError,
+ discover_oauth_metadata,
+ keyring_store_token,
+ login_with_browser,
+ login_with_device_code as _login_device,
+ register_client,
+ )
+
+ with cli_feedback("Discovering OAuth metadata"):
+ metadata = discover_oauth_metadata(server, insecure, ca_data)
+
+ # Resolve client_id: flag > stored > DCR
+ if not client_id:
+ existing = server_store.get_by_name(name) or server_store.get_by_url(server)
+ if existing:
+ stored_client_id = existing.get("oauth_client_id")
+ if stored_client_id:
+ client_id = str(stored_client_id)
+
+ if not client_id:
+ with cli_feedback("Registering OAuth client"):
+ client_id = register_client(metadata, server, insecure, ca_data)
+
+ def _do_login(cid: str) -> dict[str, Any]:
+ if use_device_code:
+ return _login_device(server, cid, metadata, insecure, ca_data)
+ else:
+ return login_with_browser(server, cid, metadata, insecure, ca_data)
+
+ try:
+ token_response = _do_login(client_id)
+ except InvalidClientError:
+ with cli_feedback("Re-registering OAuth client"):
+ client_id = register_client(metadata, server, insecure, ca_data)
+ token_response = _do_login(client_id)
+
+ access_token = str(token_response["access_token"])
+ refresh_token = str(token_response["refresh_token"]) if "refresh_token" in token_response else None
+ expires_in = token_response.get("expires_in")
+ import time
+
+ expiry = time.time() + int(expires_in) if expires_in else None
+
+ stored_in_keyring = keyring_store_token(server, access_token, refresh_token)
+
+ ca_data_str = ca_data.decode("utf-8") if isinstance(ca_data, bytes) else ca_data
+
+ if stored_in_keyring:
+ server_store.set(name, server, oauth_client_id=client_id, insecure=insecure, ca_data=ca_data_str)
+ else:
+ server_store.set(
+ name,
+ server,
+ oauth_client_id=client_id,
+ insecure=insecure,
+ ca_data=ca_data_str,
+ oauth_access_token=access_token,
+ oauth_refresh_token=refresh_token,
+ oauth_token_expiry=expiry,
+ )
+
+ click.echo('Logged in to "%s" (%s)' % (name, server))
+ if not stored_in_keyring:
+ click.secho(
+ "Note: keyring not available; credentials stored in local file (chmod 600).",
+ fg="yellow",
+ )
+
+
+@cli.command(
+ short_help="Remove stored OAuth credentials for a Posit Connect server.",
+ help=(
+ "Remove locally-stored OAuth credentials for a Posit Connect server. "
+ "One of --name or --server is required. "
+ "The server entry is preserved (for re-login without re-registration); "
+ "use 'rsconnect remove' to delete the entry entirely."
+ ),
+ no_args_is_help=True,
+)
+@click.option("--name", "-n", help="The nickname of the Posit Connect server to log out from.")
+@click.option("--server", "-s", help="The URL of the Posit Connect server to log out from.")
+@click.option("--verbose", "-v", count=True, help="Enable verbose output. Use -vv for very verbose (debug) output.")
+@cli_exception_handler
+def logout(
+ name: Optional[str],
+ server: Optional[str],
+ verbose: int,
+):
+ set_verbosity(verbose)
+
+ if name and server:
+ raise RSConnectException("Specify only one of --name or --server.")
+ if not name and not server:
+ raise RSConnectException("Specify one of --name or --server.")
+
+ entry = None
+ if name:
+ entry = server_store.get_by_name(name)
+ if entry is None:
+ raise RSConnectException('Nickname "%s" was not found.' % name)
+ elif server:
+ entry = server_store.get_by_url(server)
+ if entry is None:
+ raise RSConnectException('Server URL "%s" was not found.' % server)
+
+ if not entry or not entry.get("oauth_client_id"):
+ raise RSConnectException(
+ "This server was not added with 'rsconnect login'. Use 'rsconnect remove' to delete it."
+ )
+
+ server_url = entry["url"]
+ entry_name = entry["name"]
+
+ from .oauth import keyring_delete_tokens
+
+ keyring_delete_tokens(server_url)
+ server_store.update_oauth_tokens(entry_name, None, None, None)
+
+ click.echo('Logged out from "%s".' % (name or server))
+
+
def _get_names_to_check(file_or_directory: str) -> list[str]:
"""
A function to determine a set files to look for in getting information about a
diff --git a/rsconnect/metadata.py b/rsconnect/metadata.py
index 7ea6e180..409aaf75 100644
--- a/rsconnect/metadata.py
+++ b/rsconnect/metadata.py
@@ -259,6 +259,10 @@ class ServerDataDict(TypedDict):
account_name: NotRequired[str]
token: NotRequired[str]
secret: NotRequired[str]
+ oauth_client_id: NotRequired[str]
+ oauth_access_token: NotRequired[str]
+ oauth_refresh_token: NotRequired[str]
+ oauth_token_expiry: NotRequired[float]
class ServerData:
@@ -279,6 +283,10 @@ def __init__(
account_name: Optional[str] = None,
token: Optional[str] = None,
secret: Optional[str] = None,
+ oauth_client_id: Optional[str] = None,
+ oauth_access_token: Optional[str] = None,
+ oauth_refresh_token: Optional[str] = None,
+ oauth_token_expiry: Optional[float] = None,
):
self.name = name
self.url = url
@@ -290,6 +298,10 @@ def __init__(
self.account_name = account_name
self.token = token
self.secret = secret
+ self.oauth_client_id = oauth_client_id
+ self.oauth_access_token = oauth_access_token
+ self.oauth_refresh_token = oauth_refresh_token
+ self.oauth_token_expiry = oauth_token_expiry
class ServerStore(DataStore[ServerDataDict]):
@@ -338,6 +350,10 @@ def set(
account_name: Optional[str] = None,
token: Optional[str] = None,
secret: Optional[str] = None,
+ oauth_client_id: Optional[str] = None,
+ oauth_access_token: Optional[str] = None,
+ oauth_refresh_token: Optional[str] = None,
+ oauth_token_expiry: Optional[float] = None,
):
"""
Add (or update) information about a Connect server
@@ -351,6 +367,10 @@ def set(
:param account_name: shinyapps.io account name.
:param token: shinyapps.io token.
:param secret: shinyapps.io secret.
+ :param oauth_client_id: OAuth client ID.
+ :param oauth_access_token: OAuth access token (fallback when keyring unavailable).
+ :param oauth_refresh_token: OAuth refresh token (fallback when keyring unavailable).
+ :param oauth_token_expiry: OAuth token expiry as unix timestamp.
"""
common_data: ServerDataDict = {
"name": name,
@@ -360,6 +380,14 @@ def set(
target_data = dict(snowflake_connection_name=snowflake_connection_name, api_key=api_key)
elif api_key:
target_data = dict(api_key=api_key, insecure=insecure, ca_cert=ca_data)
+ elif oauth_client_id:
+ target_data: dict[str, object] = dict(oauth_client_id=oauth_client_id, insecure=insecure, ca_cert=ca_data)
+ if oauth_access_token:
+ target_data["oauth_access_token"] = oauth_access_token
+ if oauth_refresh_token:
+ target_data["oauth_refresh_token"] = oauth_refresh_token
+ if oauth_token_expiry is not None:
+ target_data["oauth_token_expiry"] = oauth_token_expiry
elif account_name:
target_data = dict(account_name=account_name, token=token, secret=secret)
else:
@@ -383,6 +411,28 @@ def remove_by_url(self, url: str):
"""
return self._remove_by_value_attr("name", "url", url)
+ def update_oauth_tokens(
+ self,
+ name: str,
+ access_token: Optional[str],
+ refresh_token: Optional[str],
+ expiry: Optional[float],
+ ) -> None:
+ """Update (or clear) stored OAuth token fields for an existing server entry."""
+ entry = self._get_by_key(name)
+ if entry is None:
+ return
+ updated: ServerDataDict = {**entry} # type: ignore[misc]
+ if access_token:
+ updated["oauth_access_token"] = access_token # type: ignore[typeddict-unknown-key]
+ updated["oauth_refresh_token"] = refresh_token # type: ignore[typeddict-unknown-key]
+ updated["oauth_token_expiry"] = expiry # type: ignore[typeddict-unknown-key]
+ else:
+ updated.pop("oauth_access_token", None) # type: ignore[misc]
+ updated.pop("oauth_refresh_token", None) # type: ignore[misc]
+ updated.pop("oauth_token_expiry", None) # type: ignore[misc]
+ self._set(name, updated)
+
def resolve(self, name: Optional[str], url: Optional[str]) -> ServerData:
"""
This function will resolve the given inputs into a set of server information.
@@ -429,6 +479,10 @@ def resolve(self, name: Optional[str], url: Optional[str]) -> ServerData:
account_name=entry.get("account_name"),
token=entry.get("token"),
secret=entry.get("secret"),
+ oauth_client_id=entry.get("oauth_client_id"),
+ oauth_access_token=entry.get("oauth_access_token"),
+ oauth_refresh_token=entry.get("oauth_refresh_token"),
+ oauth_token_expiry=entry.get("oauth_token_expiry"),
)
else:
return ServerData(
diff --git a/rsconnect/oauth.py b/rsconnect/oauth.py
new file mode 100644
index 00000000..255ce996
--- /dev/null
+++ b/rsconnect/oauth.py
@@ -0,0 +1,516 @@
+"""OAuth 2.1 authentication support for Posit Connect.
+
+Implements RFC 8414 (discovery), RFC 7591 (DCR), Authorization Code + PKCE,
+Device Code flow, token refresh, and keyring integration.
+"""
+
+from __future__ import annotations
+
+import base64
+import hashlib
+import queue
+import secrets
+import threading
+import time
+import webbrowser
+from http.server import BaseHTTPRequestHandler, HTTPServer as _HTTPServer
+from typing import Any, Dict, Optional, Tuple, cast
+from urllib.parse import parse_qs, urlencode, urlparse
+
+import click
+
+from .exception import RSConnectException
+from .http_support import HTTPResponse, HTTPServer
+from .log import logger
+
+# pyright: reportMissingTypeStubs=false
+
+_KEYRING_SERVICE = "rsconnect-python"
+_CLIENT_NAME = "rsconnect-python"
+_CALLBACK_TIMEOUT_SECONDS = 600
+
+
+class InvalidClientError(RSConnectException):
+ """Raised when the OAuth server returns an invalid_client error."""
+
+ def __init__(self) -> None:
+ super().__init__("OAuth client_id is invalid or has been deleted on the server.")
+
+
+def _check_oauth_error_response(response: HTTPResponse) -> None:
+ """Check an HTTPResponse for OAuth error codes and raise appropriately."""
+ if response.json_data and isinstance(response.json_data, dict):
+ error = response.json_data.get("error", "")
+ if error == "invalid_client":
+ raise InvalidClientError()
+ description = response.json_data.get("error_description", error)
+ if description:
+ raise RSConnectException(f"OAuth error: {description}")
+
+
+def _unwrap_json_response(response: Any) -> dict[str, Any]:
+ """Extract JSON dict from an HTTPResponse (raw HTTPServer doesn't auto-unwrap).
+
+ Returns the dict if successful, raises RSConnectException on error responses.
+ """
+ if isinstance(response, HTTPResponse):
+ if response.status and 200 <= response.status < 300 and isinstance(response.json_data, dict):
+ return cast(Dict[str, Any], response.json_data)
+ _check_oauth_error_response(response)
+ raise RSConnectException(f"OAuth request failed: HTTP {response.status}.")
+ if isinstance(response, dict):
+ return cast(Dict[str, Any], response)
+ raise RSConnectException("Unexpected OAuth response format.")
+
+
+def discover_oauth_metadata(
+ url: str,
+ insecure: bool = False,
+ ca_data: Optional[str | bytes] = None,
+) -> dict[str, Any]:
+ """Fetch OAuth 2.0 Authorization Server Metadata (RFC 8414).
+
+ Returns the parsed JSON metadata dict, or raises RSConnectException if
+ the server does not support OAuth.
+ """
+ server = HTTPServer(url, disable_tls_check=insecure, ca_data=ca_data)
+ with server:
+ response = server.get("/.well-known/oauth-authorization-server")
+
+ if isinstance(response, HTTPResponse):
+ if response.status != 200:
+ raise RSConnectException(
+ f"Server at {url} does not support OAuth 2.1 "
+ f"(discovery endpoint returned HTTP {response.status}). "
+ f"The server may need to be upgraded, or OAuth may be intentionally disabled by an administrator."
+ )
+ if isinstance(response.json_data, dict) and "token_endpoint" in response.json_data:
+ return response.json_data
+ raise RSConnectException(f"Server at {url} returned a non-JSON response from the OAuth discovery endpoint.")
+
+ if not isinstance(response, dict) or "token_endpoint" not in response:
+ raise RSConnectException(f"Server at {url} returned invalid OAuth metadata (missing token_endpoint).")
+
+ return response
+
+
+def register_client(
+ metadata: dict[str, Any],
+ url: str,
+ insecure: bool = False,
+ ca_data: Optional[str | bytes] = None,
+) -> str:
+ """Register an OAuth client via Dynamic Client Registration (RFC 7591).
+
+ Returns the client_id.
+ """
+ registration_endpoint = str(metadata.get("registration_endpoint", ""))
+ if not registration_endpoint:
+ raise RSConnectException("OAuth metadata does not include a registration_endpoint.")
+
+ parsed = urlparse(registration_endpoint)
+ base = f"{parsed.scheme}://{parsed.netloc}"
+ path = parsed.path
+
+ grant_types = ["authorization_code", "refresh_token"]
+ if metadata.get("device_authorization_endpoint"):
+ grant_types.append("urn:ietf:params:oauth:grant-type:device_code")
+
+ server = HTTPServer(base, disable_tls_check=insecure, ca_data=ca_data)
+ with server:
+ response = server.post(
+ path,
+ body={
+ "client_name": _CLIENT_NAME,
+ "redirect_uris": ["http://127.0.0.1/callback"],
+ "token_endpoint_auth_method": "none",
+ "grant_types": grant_types,
+ "response_types": ["code"],
+ },
+ )
+
+ data = _unwrap_json_response(response)
+ if "client_id" not in data:
+ raise RSConnectException("OAuth client registration returned an unexpected response (no client_id).")
+
+ return str(data["client_id"])
+
+
+def generate_pkce_pair() -> Tuple[str, str]:
+ """Generate a PKCE code_verifier and code_challenge (S256)."""
+ code_verifier = secrets.token_urlsafe(96)[:96]
+ digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
+ code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
+ return code_verifier, code_challenge
+
+
+def _exchange_code_for_token(
+ metadata: dict[str, Any],
+ client_id: str,
+ code: str,
+ code_verifier: str,
+ redirect_uri: str,
+ insecure: bool = False,
+ ca_data: Optional[str | bytes] = None,
+) -> dict[str, Any]:
+ """Exchange an authorization code for tokens."""
+ token_endpoint = str(metadata["token_endpoint"])
+ parsed = urlparse(token_endpoint)
+ base = f"{parsed.scheme}://{parsed.netloc}"
+ path = parsed.path
+
+ body = urlencode(
+ {
+ "grant_type": "authorization_code",
+ "client_id": client_id,
+ "code": code,
+ "redirect_uri": redirect_uri,
+ "code_verifier": code_verifier,
+ }
+ ).encode("utf-8")
+
+ server = HTTPServer(base, disable_tls_check=insecure, ca_data=ca_data)
+ with server:
+ response = server.request(
+ "POST",
+ path,
+ body=body,
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ )
+
+ data = _unwrap_json_response(response)
+ if "access_token" not in data:
+ raise RSConnectException("Token exchange returned an unexpected response.")
+
+ return data
+
+
+class _CallbackHandler(BaseHTTPRequestHandler):
+ """HTTP request handler for the OAuth redirect callback."""
+
+ result_queue: queue.Queue[Tuple[str, Optional[str], Optional[str]]]
+
+ def do_GET(self) -> None: # noqa: N802
+ qs = parse_qs(urlparse(self.path).query)
+ code = qs.get("code", [None])[0]
+ state = qs.get("state", [None])[0]
+ error = qs.get("error", [None])[0]
+
+ self.send_response(200)
+ self.send_header("Content-Type", "text/html")
+ self.end_headers()
+
+ if error:
+ self.wfile.write(b"
Authentication failed.
You may close this tab.
")
+ self.result_queue.put(("error", error, qs.get("error_description", [""])[0]))
+ elif code:
+ self.wfile.write(
+ b"Authentication successful!
You may close this tab.
"
+ )
+ self.result_queue.put(("success", code, state))
+ else:
+ self.wfile.write(b"Unexpected response.
")
+ self.result_queue.put(("error", "no_code", "No authorization code in callback"))
+
+ def log_message(self, format: str, *args: object) -> None:
+ logger.debug(f"OAuth callback server: {format % args}")
+
+
+def login_with_browser(
+ url: str,
+ client_id: str,
+ metadata: dict[str, Any],
+ insecure: bool = False,
+ ca_data: Optional[str | bytes] = None,
+) -> dict[str, Any]:
+ """Perform OAuth Authorization Code + PKCE flow via browser.
+
+ Opens the user's browser to the authorization URL and starts a local
+ HTTP server to receive the callback. Returns the token response dict.
+ """
+ code_verifier, code_challenge = generate_pkce_pair()
+ state = secrets.token_urlsafe(32)
+
+ result_queue: queue.Queue[Tuple[str, Optional[str], Optional[str]]] = queue.Queue()
+
+ callback_server = _HTTPServer(("127.0.0.1", 0), _CallbackHandler)
+ port = callback_server.server_address[1]
+ redirect_uri = f"http://127.0.0.1:{port}/callback"
+
+ # Attach queue to the handler class for this server instance
+ callback_server.RequestHandlerClass.result_queue = result_queue # type: ignore[attr-defined]
+
+ auth_endpoint = str(metadata["authorization_endpoint"])
+ auth_params = urlencode(
+ {
+ "response_type": "code",
+ "client_id": client_id,
+ "redirect_uri": redirect_uri,
+ "code_challenge": code_challenge,
+ "code_challenge_method": "S256",
+ "state": state,
+ }
+ )
+ auth_url = f"{auth_endpoint}?{auth_params}"
+
+ server_thread = threading.Thread(target=callback_server.handle_request, daemon=True)
+ server_thread.start()
+
+ if not webbrowser.open(auth_url):
+ click.echo(
+ f"Could not open browser automatically. This can happen if no display is available\n"
+ f"or localhost is blocked by network rules. Please open this URL manually:\n\n"
+ f" {auth_url}\n\n"
+ f"Waiting for authentication callback..."
+ )
+ else:
+ click.echo("Opened browser for authentication. Waiting for callback...")
+
+ server_thread.join(timeout=_CALLBACK_TIMEOUT_SECONDS)
+ callback_server.server_close()
+
+ if result_queue.empty():
+ raise RSConnectException(f"OAuth browser callback timed out after {_CALLBACK_TIMEOUT_SECONDS} seconds.")
+
+ result = result_queue.get_nowait()
+ if result[0] == "error":
+ raise RSConnectException(f"OAuth authentication failed: {result[1]} — {result[2]}")
+
+ _, code, returned_state = result
+ if returned_state != state:
+ raise RSConnectException("OAuth state mismatch — possible CSRF attack.")
+ if not code:
+ raise RSConnectException("OAuth callback did not contain an authorization code.")
+
+ return _exchange_code_for_token(metadata, client_id, code, code_verifier, redirect_uri, insecure, ca_data)
+
+
+def login_with_device_code(
+ url: str,
+ client_id: str,
+ metadata: dict[str, Any],
+ insecure: bool = False,
+ ca_data: Optional[str | bytes] = None,
+) -> dict[str, Any]:
+ """Perform OAuth Device Code flow.
+
+ Displays a URL and user code for the user to enter in a browser,
+ then polls for token completion.
+ """
+ device_endpoint = str(metadata.get("device_authorization_endpoint", ""))
+ if not device_endpoint:
+ raise RSConnectException(
+ "Server does not support the device code flow. "
+ "The server may need to be upgraded, or the device code flow may be "
+ "intentionally disabled by an administrator. Try again without --use-device-code."
+ )
+
+ parsed = urlparse(device_endpoint)
+ base = f"{parsed.scheme}://{parsed.netloc}"
+ path = parsed.path
+
+ body = urlencode({"client_id": client_id}).encode("utf-8")
+
+ server = HTTPServer(base, disable_tls_check=insecure, ca_data=ca_data)
+ with server:
+ response = server.request(
+ "POST",
+ path,
+ body=body,
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ )
+
+ resp = _unwrap_json_response(response)
+ device_code = str(resp.get("device_code", ""))
+ user_code = str(resp.get("user_code", ""))
+ verification_uri = str(resp.get("verification_uri", ""))
+ interval = int(resp.get("interval", 5))
+ expires_in = int(resp.get("expires_in", 600))
+
+ verification_uri_complete = str(resp.get("verification_uri_complete", "")) or verification_uri
+
+ click.echo(f"\nOpen this URL in your browser:\n\n {verification_uri_complete}\n")
+ click.echo(f"Enter the code: {user_code}\n")
+ click.echo("Waiting for authorization...")
+
+ return _poll_for_device_token(metadata, client_id, device_code, interval, expires_in, insecure, ca_data)
+
+
+def _poll_for_device_token(
+ metadata: dict[str, Any],
+ client_id: str,
+ device_code: str,
+ interval: int,
+ expires_in: int,
+ insecure: bool = False,
+ ca_data: Optional[str | bytes] = None,
+) -> dict[str, Any]:
+ """Poll the token endpoint for device code completion."""
+ token_endpoint = str(metadata["token_endpoint"])
+ parsed = urlparse(token_endpoint)
+ base = f"{parsed.scheme}://{parsed.netloc}"
+ path = parsed.path
+
+ deadline = time.time() + expires_in
+ poll_interval = interval
+
+ while time.time() < deadline:
+ time.sleep(poll_interval)
+
+ body = urlencode(
+ {
+ "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
+ "client_id": client_id,
+ "device_code": device_code,
+ }
+ ).encode("utf-8")
+
+ server = HTTPServer(base, disable_tls_check=insecure, ca_data=ca_data)
+ with server:
+ response = server.request(
+ "POST",
+ path,
+ body=body,
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ )
+
+ # Extract JSON from the response (raw HTTPServer always returns HTTPResponse)
+ json_data: Optional[dict[str, Any]] = None
+ if isinstance(response, HTTPResponse):
+ if isinstance(response.json_data, dict):
+ json_data = response.json_data
+ else:
+ raise RSConnectException(f"Device code token request failed: HTTP {response.status}.")
+ elif isinstance(response, dict):
+ json_data = response
+
+ if json_data is None:
+ raise RSConnectException("Device code token request returned an unexpected response.")
+
+ if "access_token" in json_data:
+ return json_data
+
+ error = str(json_data.get("error", ""))
+ if error == "authorization_pending":
+ continue
+ elif error == "slow_down":
+ poll_interval += 5
+ continue
+ elif error == "invalid_client":
+ raise InvalidClientError()
+ elif error == "expired_token":
+ raise RSConnectException("Device code expired. Please try again.")
+ elif error == "access_denied":
+ raise RSConnectException("Authorization was denied by the user.")
+ elif error:
+ description = str(json_data.get("error_description", error))
+ raise RSConnectException(f"Device code flow failed: {description}")
+ else:
+ raise RSConnectException("Device code token request returned an unexpected response.")
+
+ raise RSConnectException("Device code authorization timed out. Please try again.")
+
+
+def refresh_access_token(
+ metadata: dict[str, Any],
+ client_id: str,
+ refresh_token: str,
+ insecure: bool = False,
+ ca_data: Optional[str | bytes] = None,
+) -> dict[str, Any]:
+ """Refresh an OAuth access token using a refresh token.
+
+ Returns the new token response dict. Raises InvalidClientError if the
+ client_id has been deleted server-side.
+ """
+ token_endpoint = str(metadata["token_endpoint"])
+ parsed = urlparse(token_endpoint)
+ base = f"{parsed.scheme}://{parsed.netloc}"
+ path = parsed.path
+
+ body = urlencode(
+ {
+ "grant_type": "refresh_token",
+ "client_id": client_id,
+ "refresh_token": refresh_token,
+ }
+ ).encode("utf-8")
+
+ server = HTTPServer(base, disable_tls_check=insecure, ca_data=ca_data)
+ with server:
+ response = server.request(
+ "POST",
+ path,
+ body=body,
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ )
+
+ data = _unwrap_json_response(response)
+ if "access_token" not in data:
+ raise RSConnectException("Token refresh returned an unexpected response.")
+
+ return data
+
+
+# ---------------------------------------------------------------------------
+# Keyring integration
+# ---------------------------------------------------------------------------
+
+
+def keyring_store_token(server_url: str, access_token: str, refresh_token: Optional[str]) -> bool:
+ """Store OAuth tokens in the system keyring.
+
+ Returns True on success, False if keyring is not available.
+ """
+ try:
+ import keyring # type: ignore[import-untyped]
+
+ keyring.set_password(_KEYRING_SERVICE, f"{server_url}:access_token", access_token)
+ if refresh_token:
+ keyring.set_password(_KEYRING_SERVICE, f"{server_url}:refresh_token", refresh_token)
+ else:
+ try:
+ keyring.delete_password(_KEYRING_SERVICE, f"{server_url}:refresh_token")
+ except keyring.errors.PasswordDeleteError:
+ pass
+ return True
+ except ImportError:
+ return False
+ except Exception as e:
+ logger.warning(f"keyring storage failed: {e}")
+ return False
+
+
+def keyring_get_tokens(server_url: str) -> Tuple[Optional[str], Optional[str]]:
+ """Retrieve OAuth tokens from the system keyring.
+
+ Returns (access_token, refresh_token), or (None, None) if unavailable.
+ """
+ try:
+ import keyring # type: ignore[import-untyped]
+
+ access = keyring.get_password(_KEYRING_SERVICE, f"{server_url}:access_token")
+ refresh = keyring.get_password(_KEYRING_SERVICE, f"{server_url}:refresh_token")
+ return access, refresh
+ except ImportError:
+ return None, None
+ except Exception as e:
+ logger.warning(f"keyring retrieval failed: {e}")
+ return None, None
+
+
+def keyring_delete_tokens(server_url: str) -> None:
+ """Delete OAuth tokens from the system keyring."""
+ try:
+ import keyring # type: ignore[import-untyped]
+ import keyring.errors # type: ignore[import-untyped]
+
+ for suffix in (":access_token", ":refresh_token"):
+ try:
+ keyring.delete_password(_KEYRING_SERVICE, f"{server_url}{suffix}")
+ except keyring.errors.PasswordDeleteError:
+ pass
+ except ImportError:
+ pass
+ except Exception as e:
+ logger.warning(f"keyring deletion failed: {e}")
diff --git a/tests/test_oauth.py b/tests/test_oauth.py
new file mode 100644
index 00000000..bb2b75a4
--- /dev/null
+++ b/tests/test_oauth.py
@@ -0,0 +1,519 @@
+from __future__ import annotations
+
+import base64
+import hashlib
+from typing import Any
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from rsconnect.exception import RSConnectException
+from rsconnect.http_support import HTTPResponse
+from rsconnect.metadata import ServerData
+from rsconnect.oauth import (
+ InvalidClientError,
+ _exchange_code_for_token,
+ _poll_for_device_token,
+ discover_oauth_metadata,
+ generate_pkce_pair,
+ keyring_delete_tokens,
+ keyring_get_tokens,
+ keyring_store_token,
+ login_with_browser,
+ login_with_device_code,
+ refresh_access_token,
+ register_client,
+)
+
+
+FAKE_URL = "https://connect.example.com"
+FAKE_METADATA: dict[str, Any] = {
+ "issuer": FAKE_URL,
+ "authorization_endpoint": f"{FAKE_URL}/oauth/v1/authorize",
+ "token_endpoint": f"{FAKE_URL}/oauth/v1/token",
+ "registration_endpoint": f"{FAKE_URL}/oauth/v1/register",
+ "device_authorization_endpoint": f"{FAKE_URL}/oauth/v1/device",
+}
+
+
+def _make_response(status: int = 200, json_data: Any = None) -> HTTPResponse:
+ response = HTTPResponse("", body=b"")
+ response.status = status
+ response.json_data = json_data
+ return response
+
+
+@pytest.fixture
+def mock_http_server():
+ with patch("rsconnect.oauth.HTTPServer") as mock_cls:
+ mock_server = MagicMock()
+ mock_cls.return_value = mock_server
+ mock_server.__enter__ = MagicMock(return_value=mock_server)
+ mock_server.__exit__ = MagicMock(return_value=False)
+ yield mock_server
+
+
+class TestPKCE:
+ def test_generates_valid_pair(self):
+ verifier, challenge = generate_pkce_pair()
+ assert 43 <= len(verifier) <= 128
+ digest = hashlib.sha256(verifier.encode("ascii")).digest()
+ expected_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
+ assert challenge == expected_challenge
+
+
+class TestDiscoverOAuthMetadata:
+ def test_success(self, mock_http_server: MagicMock):
+ mock_http_server.get.return_value = _make_response(200, FAKE_METADATA)
+ result = discover_oauth_metadata(FAKE_URL)
+ assert result == FAKE_METADATA
+
+ def test_server_not_supporting_oauth(self, mock_http_server: MagicMock):
+ mock_http_server.get.return_value = _make_response(404, None)
+ with pytest.raises(RSConnectException, match="does not support OAuth"):
+ discover_oauth_metadata(FAKE_URL)
+
+ def test_missing_token_endpoint(self, mock_http_server: MagicMock):
+ mock_http_server.get.return_value = {"issuer": FAKE_URL}
+ with pytest.raises(RSConnectException, match="invalid OAuth metadata"):
+ discover_oauth_metadata(FAKE_URL)
+
+
+class TestRegisterClient:
+ def test_success(self, mock_http_server: MagicMock):
+ mock_http_server.post.return_value = _make_response(200, {"client_id": "test-client-123"})
+ result = register_client(FAKE_METADATA, FAKE_URL)
+ assert result == "test-client-123"
+
+ def test_failure(self, mock_http_server: MagicMock):
+ mock_http_server.post.return_value = _make_response(
+ 400, {"error": "invalid_request", "error_description": "bad request"}
+ )
+ with pytest.raises(RSConnectException, match="OAuth error"):
+ register_client(FAKE_METADATA, FAKE_URL)
+
+ def test_missing_registration_endpoint(self):
+ metadata = {k: v for k, v in FAKE_METADATA.items() if k != "registration_endpoint"}
+ with pytest.raises(RSConnectException, match="registration_endpoint"):
+ register_client(metadata, FAKE_URL)
+
+
+class TestTokenExchange:
+ def test_success(self, mock_http_server: MagicMock):
+ mock_http_server.request.return_value = _make_response(
+ 200, {"access_token": "at-123", "refresh_token": "rt-456", "expires_in": 3600}
+ )
+ result = _exchange_code_for_token(
+ FAKE_METADATA, "client-1", "auth-code", "verifier", "http://127.0.0.1:8080/callback"
+ )
+ assert result["access_token"] == "at-123"
+
+ def test_invalid_client(self, mock_http_server: MagicMock):
+ mock_http_server.request.return_value = _make_response(401, {"error": "invalid_client"})
+ with pytest.raises(InvalidClientError):
+ _exchange_code_for_token(
+ FAKE_METADATA, "bad-client", "auth-code", "verifier", "http://127.0.0.1:8080/callback"
+ )
+
+
+class TestDeviceCodeFlow:
+ def test_device_code_not_supported(self):
+ metadata = {k: v for k, v in FAKE_METADATA.items() if k != "device_authorization_endpoint"}
+ with pytest.raises(RSConnectException, match="does not support the device code flow"):
+ login_with_device_code(FAKE_URL, "client-1", metadata)
+
+ @patch("rsconnect.oauth.time.sleep")
+ def test_poll_success(self, _, mock_http_server: MagicMock):
+ mock_http_server.request.side_effect = [
+ _make_response(400, {"error": "authorization_pending"}),
+ _make_response(200, {"access_token": "at-final", "refresh_token": "rt-final"}),
+ ]
+ result = _poll_for_device_token(FAKE_METADATA, "client-1", "device-code-1", 5, 600)
+ assert result["access_token"] == "at-final"
+
+ @patch("rsconnect.oauth.time.sleep")
+ def test_poll_expired(self, _, mock_http_server: MagicMock):
+ mock_http_server.request.return_value = _make_response(400, {"error": "expired_token"})
+ with pytest.raises(RSConnectException, match="expired"):
+ _poll_for_device_token(FAKE_METADATA, "client-1", "device-code-1", 5, 600)
+
+ @patch("rsconnect.oauth.time.sleep")
+ def test_poll_invalid_client(self, _, mock_http_server: MagicMock):
+ mock_http_server.request.return_value = _make_response(401, {"error": "invalid_client"})
+ with pytest.raises(InvalidClientError):
+ _poll_for_device_token(FAKE_METADATA, "bad-client", "device-code-1", 5, 600)
+
+ @patch("rsconnect.oauth.time.sleep")
+ def test_poll_malformed_response(self, _, mock_http_server: MagicMock):
+ mock_http_server.request.return_value = _make_response(200, {})
+ with pytest.raises(RSConnectException, match="unexpected response"):
+ _poll_for_device_token(FAKE_METADATA, "client-1", "device-code-1", 5, 600)
+
+
+class TestRefreshAccessToken:
+ def test_success(self, mock_http_server: MagicMock):
+ mock_http_server.request.return_value = _make_response(
+ 200, {"access_token": "new-at", "refresh_token": "new-rt", "expires_in": 7200}
+ )
+ result = refresh_access_token(FAKE_METADATA, "client-1", "old-rt")
+ assert result["access_token"] == "new-at"
+
+ def test_invalid_client(self, mock_http_server: MagicMock):
+ mock_http_server.request.return_value = _make_response(401, {"error": "invalid_client"})
+ with pytest.raises(InvalidClientError):
+ refresh_access_token(FAKE_METADATA, "bad-client", "old-rt")
+
+
+class TestKeyringIntegration:
+ def test_store_success(self):
+ mock_keyring = MagicMock()
+ with patch.dict("sys.modules", {"keyring": mock_keyring}):
+ result = keyring_store_token("https://example.com", "at-1", "rt-1")
+ assert result is True
+ assert mock_keyring.set_password.call_count == 2
+
+ def test_store_no_keyring(self):
+ with patch.dict("sys.modules", {"keyring": None}):
+ result = keyring_store_token("https://example.com", "at-1", "rt-1")
+ assert result is False
+
+ def test_get_success(self):
+ mock_keyring = MagicMock()
+ mock_keyring.get_password.side_effect = lambda svc, key: "token-value"
+ with patch.dict("sys.modules", {"keyring": mock_keyring}):
+ access, refresh = keyring_get_tokens("https://example.com")
+ assert access == "token-value"
+ assert refresh == "token-value"
+
+ def test_get_no_keyring(self):
+ with patch.dict("sys.modules", {"keyring": None}):
+ access, refresh = keyring_get_tokens("https://example.com")
+ assert access is None
+ assert refresh is None
+
+ def test_delete_success(self):
+ mock_keyring = MagicMock()
+ mock_keyring_errors = MagicMock()
+ mock_keyring_errors.PasswordDeleteError = Exception
+ with patch.dict("sys.modules", {"keyring": mock_keyring, "keyring.errors": mock_keyring_errors}):
+ keyring_delete_tokens("https://example.com")
+ assert mock_keyring.delete_password.call_count == 2
+
+
+class TestLoginWithBrowser:
+ @patch("rsconnect.oauth.webbrowser.open", return_value=True)
+ @patch("rsconnect.oauth._exchange_code_for_token")
+ @patch("rsconnect.oauth._HTTPServer")
+ @patch("rsconnect.oauth.secrets.token_urlsafe", return_value="fixed-state")
+ def test_success(
+ self,
+ mock_token_urlsafe: MagicMock,
+ mock_httpserver_cls: MagicMock,
+ mock_exchange: MagicMock,
+ mock_browser: MagicMock,
+ ):
+ mock_exchange.return_value = {"access_token": "at-browser", "refresh_token": "rt-browser"}
+
+ mock_server_instance = MagicMock()
+ mock_server_instance.server_address = ("127.0.0.1", 9999)
+ mock_httpserver_cls.return_value = mock_server_instance
+
+ # When handle_request is called in the thread, simulate putting
+ # the auth code onto the result_queue that the function created
+ def fake_handle_request():
+ # The function sets result_queue on RequestHandlerClass before starting the thread
+ rq = mock_server_instance.RequestHandlerClass.result_queue
+ rq.put(("success", "auth-code-123", "fixed-state"))
+
+ mock_server_instance.handle_request.side_effect = fake_handle_request
+
+ result = login_with_browser(FAKE_URL, "client-1", FAKE_METADATA)
+
+ assert result == {"access_token": "at-browser", "refresh_token": "rt-browser"}
+ mock_browser.assert_called_once()
+
+
+class TestExecutorOAuthSetup:
+ @pytest.mark.parametrize(
+ "keyring_token,expected_token",
+ [
+ ("keyring-access-token", "keyring-access-token"),
+ (None, "stored-fallback-token"),
+ ],
+ ids=["keyring-available", "keyring-fallback"],
+ )
+ def test_setup_remote_server_with_oauth_entry(self, keyring_token, expected_token):
+ from rsconnect.api import RSConnectExecutor, RSConnectServer
+
+ with patch("rsconnect.oauth.keyring_get_tokens", return_value=(keyring_token, None)):
+ with patch("rsconnect.metadata.ServerStore.resolve") as mock_resolve:
+ mock_resolve.return_value = ServerData(
+ name="myserver",
+ url=FAKE_URL,
+ from_store=True,
+ oauth_client_id="client-123",
+ oauth_access_token="stored-fallback-token",
+ )
+
+ executor = RSConnectExecutor.__new__(RSConnectExecutor)
+ executor.logger = None
+ executor.ctx = None
+ executor.setup_remote_server(ctx=None, name="myserver")
+
+ assert isinstance(executor.remote_server, RSConnectServer)
+ assert executor.remote_server.oauth_access_token == expected_token
+ assert executor.remote_server.oauth_client_id == "client-123"
+
+
+class TestRefreshTokenFallback:
+ @pytest.mark.parametrize(
+ "server_name,get_by_name_rv,get_by_url_rv",
+ [
+ ("testserver", {"oauth_refresh_token": "stored-rt"}, None),
+ (FAKE_URL, None, {"oauth_refresh_token": "url-rt", "name": "myserver"}),
+ (None, None, {"oauth_refresh_token": "url-rt", "name": "resolved-name"}),
+ ],
+ ids=["name-lookup", "url-fallback", "no-server-name"],
+ )
+ @patch("rsconnect.oauth.HTTPServer")
+ @patch("rsconnect.oauth.keyring_get_tokens", return_value=(None, None))
+ @patch("rsconnect.oauth.keyring_store_token", return_value=False)
+ def test_refresh_falls_back_to_store(
+ self,
+ mock_keyring_store: MagicMock,
+ mock_keyring_get: MagicMock,
+ mock_http_server_cls: MagicMock,
+ server_name: "str | None",
+ get_by_name_rv: Any,
+ get_by_url_rv: Any,
+ ):
+ from rsconnect.api import RSConnectClient, RSConnectServer
+ from rsconnect.metadata import ServerStore
+
+ mock_http_server = MagicMock()
+ mock_http_server_cls.return_value = mock_http_server
+ mock_http_server.__enter__ = MagicMock(return_value=mock_http_server)
+ mock_http_server.__exit__ = MagicMock(return_value=False)
+
+ mock_http_server.get.return_value = _make_response(200, FAKE_METADATA)
+ mock_http_server.request.return_value = _make_response(
+ 200, {"access_token": "new-at", "refresh_token": "new-rt", "expires_in": 3600}
+ )
+
+ server = RSConnectServer(
+ FAKE_URL,
+ None,
+ False,
+ None,
+ oauth_access_token="old-at",
+ oauth_client_id="client-1",
+ server_name=server_name,
+ )
+
+ with patch.object(ServerStore, "get_by_name", return_value=get_by_name_rv):
+ with patch.object(ServerStore, "get_by_url", return_value=get_by_url_rv):
+ with patch.object(ServerStore, "update_oauth_tokens"):
+ client = RSConnectClient(server)
+ result = client._attempt_token_refresh()
+
+ assert result is True
+
+
+class TestStreamBodyRetry:
+ def test_seekable_stream_rewinds_for_retry(self):
+ import io
+
+ from rsconnect.api import RSConnectClient, RSConnectServer
+ from rsconnect.http_support import HTTPServer as _HTTPServer
+
+ server = RSConnectServer(
+ FAKE_URL,
+ None,
+ False,
+ None,
+ oauth_access_token="old-at",
+ oauth_client_id="client-1",
+ server_name="testserver",
+ )
+
+ client = RSConnectClient(server)
+ stream_body = io.BytesIO(b"bundle-payload-data")
+ call_bodies: list[object] = []
+
+ def fake_super_request(
+ self_arg, method, path, query_params, body, maximum_redirects=5, decode_response=True, headers=None
+ ):
+ call_bodies.append(body.read() if hasattr(body, "read") else body)
+ return _make_response(401, None) if len(call_bodies) == 1 else _make_response(200, {"result": "ok"})
+
+ with patch.object(_HTTPServer, "request", fake_super_request):
+ with patch.object(client, "_attempt_token_refresh", return_value=True):
+ client.request("POST", "/v1/content/upload", body=stream_body)
+
+ assert len(call_bodies) == 2
+ assert call_bodies[0] == b"bundle-payload-data"
+ assert call_bodies[1] == b"bundle-payload-data"
+
+ def test_non_seekable_stream_buffered_for_retry(self):
+ import io
+
+ from rsconnect.api import RSConnectClient, RSConnectServer
+ from rsconnect.http_support import HTTPServer as _HTTPServer
+
+ class NonSeekableStream(io.RawIOBase):
+ def __init__(self, data: bytes):
+ self._data = data
+ self._pos = 0
+
+ def read(self, size=-1):
+ if size == -1:
+ result = self._data[self._pos :]
+ else:
+ result = self._data[self._pos : self._pos + size]
+ self._pos += len(result)
+ return result
+
+ def readable(self):
+ return True
+
+ def seekable(self):
+ return False
+
+ server = RSConnectServer(
+ FAKE_URL,
+ None,
+ False,
+ None,
+ oauth_access_token="old-at",
+ oauth_client_id="client-1",
+ server_name="testserver",
+ )
+
+ client = RSConnectClient(server)
+ stream_body = NonSeekableStream(b"bundle-payload-data")
+ call_bodies: list[object] = []
+
+ def fake_super_request(
+ self_arg, method, path, query_params, body, maximum_redirects=5, decode_response=True, headers=None
+ ):
+ call_bodies.append(body)
+ return _make_response(401, None) if len(call_bodies) == 1 else _make_response(200, {"result": "ok"})
+
+ with patch.object(_HTTPServer, "request", fake_super_request):
+ with patch.object(client, "_attempt_token_refresh", return_value=True):
+ client.request("POST", "/v1/content/upload", body=stream_body)
+
+ assert len(call_bodies) == 2
+ assert call_bodies[0] == b"bundle-payload-data"
+ assert call_bodies[1] == b"bundle-payload-data"
+
+
+class TestLoginCommand:
+ @patch("rsconnect.oauth.keyring_store_token", return_value=True)
+ @patch("rsconnect.oauth.login_with_browser")
+ @patch("rsconnect.oauth.register_client", return_value="new-client-id")
+ @patch("rsconnect.oauth.discover_oauth_metadata")
+ def test_login_success(
+ self,
+ mock_discover: MagicMock,
+ mock_register: MagicMock,
+ mock_login: MagicMock,
+ mock_keyring: MagicMock,
+ ):
+ from click.testing import CliRunner
+
+ from rsconnect.main import cli
+
+ mock_discover.return_value = FAKE_METADATA
+ mock_login.return_value = {"access_token": "at-1", "refresh_token": "rt-1", "expires_in": 3600}
+
+ runner = CliRunner()
+ result = runner.invoke(cli, ["login", "--server", FAKE_URL, "--name", "test-server"])
+
+ assert result.exit_code == 0, result.output
+ assert "Logged in" in result.output
+
+ def test_login_missing_server(self):
+ from click.testing import CliRunner
+
+ from rsconnect.main import cli
+
+ runner = CliRunner()
+ result = runner.invoke(cli, ["login"])
+
+ assert "Usage:" in result.output
+
+ @patch("rsconnect.oauth.keyring_delete_tokens")
+ @patch("rsconnect.main.server_store")
+ def test_logout_non_oauth_entry(self, mock_store: MagicMock, mock_keyring_del: MagicMock):
+ from click.testing import CliRunner
+
+ from rsconnect.main import cli
+
+ mock_store.get_by_name.return_value = {"name": "myserver", "url": FAKE_URL, "api_key": "key-123"}
+
+ runner = CliRunner()
+ result = runner.invoke(cli, ["logout", "--name", "myserver"])
+
+ assert result.exit_code != 0
+ assert "not an OAuth" in result.output or "rsconnect remove" in result.output
+
+ @patch("rsconnect.oauth.keyring_delete_tokens")
+ @patch("rsconnect.main.server_store")
+ def test_logout_success(self, mock_store: MagicMock, mock_keyring_del: MagicMock):
+ from click.testing import CliRunner
+
+ from rsconnect.main import cli
+
+ mock_store.get_by_name.return_value = {
+ "name": "myserver",
+ "url": FAKE_URL,
+ "oauth_client_id": "client-123",
+ }
+ mock_store.update_oauth_tokens = MagicMock()
+
+ runner = CliRunner()
+ result = runner.invoke(cli, ["logout", "--name", "myserver"])
+
+ assert result.exit_code == 0, result.output
+ mock_keyring_del.assert_called_once()
+
+
+class TestListCommand:
+ @patch("rsconnect.oauth.keyring_get_tokens", return_value=("at-from-keyring", None))
+ @patch("rsconnect.main.server_store")
+ def test_list_oauth_entry_with_keyring(self, mock_store: MagicMock, mock_keyring: MagicMock):
+ from click.testing import CliRunner
+
+ from rsconnect.main import cli
+
+ mock_store.get_all_servers.return_value = [
+ {"name": "myserver", "url": FAKE_URL, "oauth_client_id": "client-abc"},
+ ]
+ mock_store.get_path.return_value = "/tmp/servers.json"
+
+ runner = CliRunner()
+ result = runner.invoke(cli, ["list"])
+
+ assert result.exit_code == 0, result.output
+ assert "OAuth Client ID: client-abc" in result.output
+ assert "Credentials stored in system keyring" in result.output
+
+ @patch("rsconnect.oauth.keyring_get_tokens", return_value=(None, None))
+ @patch("rsconnect.main.server_store")
+ def test_list_oauth_entry_without_keyring(self, mock_store: MagicMock, mock_keyring: MagicMock):
+ from click.testing import CliRunner
+
+ from rsconnect.main import cli
+
+ mock_store.get_all_servers.return_value = [
+ {"name": "myserver", "url": FAKE_URL, "oauth_client_id": "client-abc"},
+ ]
+ mock_store.get_path.return_value = "/tmp/servers.json"
+
+ runner = CliRunner()
+ result = runner.invoke(cli, ["list"])
+
+ assert result.exit_code == 0, result.output
+ assert "OAuth Client ID: client-abc" in result.output
+ assert "Credentials stored in system keyring" not in result.output