diff --git a/docs/ref/extensions/sandbox/sprites/sandbox.md b/docs/ref/extensions/sandbox/sprites/sandbox.md new file mode 100644 index 0000000000..02e07a8dd1 --- /dev/null +++ b/docs/ref/extensions/sandbox/sprites/sandbox.md @@ -0,0 +1,3 @@ +# `Sandbox` + +::: agents.extensions.sandbox.sprites.sandbox diff --git a/docs/sandbox/clients.md b/docs/sandbox/clients.md index bd21da63d3..991660ac72 100644 --- a/docs/sandbox/clients.md +++ b/docs/sandbox/clients.md @@ -96,6 +96,7 @@ For provider-specific setup notes and links for the checked-in extension example | `E2BSandboxClient` | `openai-agents[e2b]` | [E2B runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/e2b_runner.py) | | `ModalSandboxClient` | `openai-agents[modal]` | [Modal runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/modal_runner.py) | | `RunloopSandboxClient` | `openai-agents[runloop]` | [Runloop runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/runloop/runner.py) | +| `SpritesSandboxClient` | `openai-agents[sprites]` | [Sprites runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/sprites_runner.py) | | `VercelSandboxClient` | `openai-agents[vercel]` | [Vercel runner](https://github.com/openai/openai-agents-python/blob/main/examples/sandbox/extensions/vercel_runner.py) | @@ -113,6 +114,7 @@ Hosted sandbox clients expose provider-specific mount strategies. Choose the bac | `DaytonaSandboxClient` | Supports rclone-backed cloud storage mounts with `DaytonaCloudBucketMountStrategy`; use it with `S3Mount`, `GCSMount`, `R2Mount`, `AzureBlobMount`, and `BoxMount`. | | `E2BSandboxClient` | Supports rclone-backed cloud storage mounts with `E2BCloudBucketMountStrategy`; use it with `S3Mount`, `GCSMount`, `R2Mount`, `AzureBlobMount`, and `BoxMount`. 
| | `RunloopSandboxClient` | Supports rclone-backed cloud storage mounts with `RunloopCloudBucketMountStrategy`; use it with `S3Mount`, `GCSMount`, `R2Mount`, `AzureBlobMount`, and `BoxMount`. | +| `SpritesSandboxClient` | Supports rclone-backed cloud storage mounts with `SpritesCloudBucketMountStrategy`; use it with `S3Mount`, `GCSMount`, `R2Mount`, `AzureBlobMount`, and `BoxMount`. The strategy lazy-installs `rclone` and `fuse` via `sudo apt-get` if the sprite image does not preinstall them. Sprites exposes at most one external HTTP port per sprite (declared as a service in the sprite image); other ports must be reverse-proxied inside the VM. | | `VercelSandboxClient` | No hosted-specific mount strategy is currently exposed. Use manifest files, repos, or other workspace inputs instead. | @@ -130,6 +132,7 @@ The table below summarizes which remote storage entries each backend can mount d | `DaytonaSandboxClient` | ✓ | ✓ | ✓ | ✓ | ✓ | - | | `E2BSandboxClient` | ✓ | ✓ | ✓ | ✓ | ✓ | - | | `RunloopSandboxClient` | ✓ | ✓ | ✓ | ✓ | ✓ | - | +| `SpritesSandboxClient` | ✓ | ✓ | ✓ | ✓ | ✓ | - | | `VercelSandboxClient` | - | - | - | - | - | - | diff --git a/examples/sandbox/extensions/README.md b/examples/sandbox/extensions/README.md index 837d9dfa28..0109b2bb73 100644 --- a/examples/sandbox/extensions/README.md +++ b/examples/sandbox/extensions/README.md @@ -7,7 +7,7 @@ They intentionally keep the flow simple: 1. Build a tiny manifest in memory. 2. Create a `SandboxAgent` that inspects that workspace through one shell tool. -3. Run the agent against E2B, Modal, Daytona, Cloudflare, Runloop, Blaxel, or Vercel. +3. Run the agent against E2B, Modal, Daytona, Cloudflare, Runloop, Blaxel, Sprites, or Vercel. All of these examples require `OPENAI_API_KEY`, because they call the model through the normal `Runner` path. Each cloud backend also needs its own provider credentials. 
@@ -328,6 +328,52 @@ the default home and working directory become `/root`, so the example also uses `/root` as its manifest workspace root. If you configure root launch in your own code, either rely on that root-mode default or explicitly choose a `manifest.root` under `/root`. +## Sprites + +### Setup + +Install the repo extra: + +```bash +uv sync --extra sprites +``` + +Create a Sprites organization and API token at [sprites.dev](https://sprites.dev/), +and export the required environment variables: + +```bash +export OPENAI_API_KEY=... +export SPRITES_API_TOKEN=... +# Optional, defaults to https://api.sprites.dev: +# export SPRITES_API_URL=https://api.sprites.dev +``` + +### Run + +```bash +uv run python examples/sandbox/extensions/sprites_runner.py --stream +``` + +Useful flags: + +- `--sprite-name <name>` — attach to an existing sprite instead of creating an + ephemeral one. The example skips delete-on-exit when this is set. +- `--skip-snapshot-check` — skip the tar workspace persistence verification. +- `--question "..."` — override the default prompt. + +The Sprites client resolves the API token from `SPRITES_API_TOKEN` (override via +`SpritesSandboxClient(token=...)`) and supports exec, filesystem read/write, +PTY-mode interactive exec, and tar-based workspace snapshots. Sprites exposes +at most one external HTTP port per sprite — declare it as a service with +`--http-port` in the sprite image, then reference it via +`SpritesSandboxClientOptions(exposed_ports=(<port>,))`. + +For cloud-bucket mounts, attach `SpritesCloudBucketMountStrategy` from +`agents.extensions.sandbox.sprites` to any rclone-compatible mount type +(`S3Mount`, `R2Mount`, `GCSMount`, `AzureBlobMount`, `BoxMount`). The strategy +lazy-installs `rclone` and the `fuse` package via `sudo apt-get` on first use +if the sprite image does not preinstall them.
+ ## Blaxel ### Setup diff --git a/examples/sandbox/extensions/sprites_runner.py b/examples/sandbox/extensions/sprites_runner.py new file mode 100644 index 0000000000..118f98c347 --- /dev/null +++ b/examples/sandbox/extensions/sprites_runner.py @@ -0,0 +1,225 @@ +""" +Minimal Sprites-backed sandbox example for manual validation. + +This example creates a small in-memory workspace, lets the agent inspect it +through one shell tool, and prints a short answer. By default an ephemeral +sprite is created and deleted at the end; pass ``--sprite-name <name>`` to +attach to an existing sprite instead. +""" + +from __future__ import annotations + +import argparse +import asyncio +import io +import os +import sys +import tempfile +from pathlib import Path +from typing import cast + +from openai.types.responses import ResponseTextDeltaEvent + +from agents import ModelSettings, Runner +from agents.run import RunConfig +from agents.sandbox import LocalSnapshotSpec, Manifest, SandboxAgent, SandboxRunConfig +from agents.sandbox.session import BaseSandboxSession + +if __package__ is None or __package__ == "": + sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +from examples.sandbox.misc.example_support import text_manifest  # noqa: E402 +from examples.sandbox.misc.workspace_shell import WorkspaceShellCapability  # noqa: E402 + +try: + from agents.extensions.sandbox import ( + SpritesSandboxClient, + SpritesSandboxClientOptions, + ) +except Exception as exc:  # pragma: no cover - import path depends on optional extras + raise SystemExit( + "Sprites sandbox examples require the optional repo extra.\n" + "Install it with: uv sync --extra sprites" + ) from exc + + +DEFAULT_QUESTION = "Summarize this sandbox workspace in 2 sentences."
+SNAPSHOT_CHECK_PATH = Path("snapshot-check.txt") +SNAPSHOT_CHECK_CONTENT = "sprites snapshot round-trip ok\n" + + +def _build_manifest() -> Manifest: + return text_manifest( + { + "README.md": ( + "# Sprites Demo Workspace\n\n" + "This workspace exists to validate the Sprites sandbox backend manually.\n" + ), + "handoff.md": ( + "# Handoff\n\n" + "- Customer: Northwind Traders.\n" + "- Goal: validate Sprites sandbox exec and persistence flows.\n" + "- Current status: v1 backend slice (exec + fs + PTY) is wired and under test.\n" + ), + "todo.md": ( + "# Todo\n\n" + "1. Inspect the workspace files.\n" + "2. Summarize the current status in two sentences.\n" + ), + } + ) + + +def _require_env(name: str) -> None: + if os.environ.get(name): + return + raise SystemExit(f"{name} must be set before running this example.") + + +async def _read_text(session: BaseSandboxSession, path: Path) -> str: + data = await session.read(path) + text = cast(str | bytes, data.read()) + if isinstance(text, bytes): + return text.decode("utf-8") + return text + + +async def _verify_stop_resume(*, sprite_name: str | None) -> None: + """Round-trip a workspace through tar persistence and reattach. + + With ``sprite_name=None`` an ephemeral sprite is created, persisted, and + then resumed against itself. With a named sprite the same flow runs + against the existing sprite (no create/delete on the API). 
+ """ + + client = SpritesSandboxClient() + options = SpritesSandboxClientOptions(sprite_name=sprite_name) + + with tempfile.TemporaryDirectory(prefix="sprites-snapshot-example-") as snapshot_dir: + sandbox = await client.create( + manifest=_build_manifest(), + snapshot=LocalSnapshotSpec(base_path=Path(snapshot_dir)), + options=options, + ) + + try: + await sandbox.start() + await sandbox.write( + SNAPSHOT_CHECK_PATH, + io.BytesIO(SNAPSHOT_CHECK_CONTENT.encode("utf-8")), + ) + await sandbox.stop() + finally: + await sandbox.shutdown() + + resumed = await client.resume(sandbox.state) + try: + await resumed.start() + restored = await _read_text(resumed, SNAPSHOT_CHECK_PATH) + if restored != SNAPSHOT_CHECK_CONTENT: + raise RuntimeError( + f"Snapshot resume verification failed: expected " + f"{SNAPSHOT_CHECK_CONTENT!r}, got {restored!r}" + ) + finally: + await resumed.aclose() + if sprite_name is None: + # Ephemeral sandbox should clean up the sprite created by ``resume``. + await client.delete(resumed) + + print("snapshot round-trip ok") + + +async def main( + *, + model: str, + question: str, + sprite_name: str | None, + skip_snapshot_check: bool, + stream: bool, +) -> None: + _require_env("OPENAI_API_KEY") + _require_env("SPRITES_API_TOKEN") + + if not skip_snapshot_check: + await _verify_stop_resume(sprite_name=sprite_name) + + manifest = _build_manifest() + agent = SandboxAgent( + name="Sprites Sandbox Assistant", + model=model, + instructions=( + "Answer questions about the sandbox workspace. Inspect the files before answering " + "and keep the response concise. Cite the file names you inspected." 
+ ), + default_manifest=manifest, + capabilities=[WorkspaceShellCapability()], + model_settings=ModelSettings(tool_choice="required"), + ) + + client = SpritesSandboxClient() + sandbox = await client.create( + manifest=manifest, + options=SpritesSandboxClientOptions(sprite_name=sprite_name), + ) + + run_config = RunConfig( + sandbox=SandboxRunConfig(session=sandbox), + tracing_disabled=True, + workflow_name="Sprites sandbox example", + ) + + try: + async with sandbox: + if not stream: + result = await Runner.run(agent, question, run_config=run_config) + print(result.final_output) + return + + stream_result = Runner.run_streamed(agent, question, run_config=run_config) + saw_text_delta = False + async for event in stream_result.stream_events(): + if event.type == "raw_response_event" and isinstance( + event.data, ResponseTextDeltaEvent + ): + if not saw_text_delta: + print("assistant> ", end="", flush=True) + saw_text_delta = True + print(event.data.delta, end="", flush=True) + + if saw_text_delta: + print() + finally: + await client.delete(sandbox) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="gpt-5.5", help="Model name to use.") + parser.add_argument("--question", default=DEFAULT_QUESTION, help="Prompt to send to the agent.") + parser.add_argument( + "--sprite-name", + default=None, + help=( + "Existing sprite to attach to. When omitted, an ephemeral sprite is " + "created and deleted automatically." 
+ ), + ) + parser.add_argument( + "--skip-snapshot-check", + action="store_true", + default=False, + help="Skip the tar workspace persistence verification before the agent run.", + ) + parser.add_argument("--stream", action="store_true", default=False, help="Stream the response.") + args = parser.parse_args() + + asyncio.run( + main( + model=args.model, + question=args.question, + sprite_name=args.sprite_name, + skip_snapshot_check=args.skip_snapshot_check, + stream=args.stream, + ) + ) diff --git a/pyproject.toml b/pyproject.toml index 2fd8547ee9..808f0ccaa6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ cloudflare = ["aiohttp>=3.12,<4"] e2b = ["e2b==2.20.0", "e2b-code-interpreter==2.4.1"] modal = ["modal==1.3.5"] runloop = ["runloop_api_client>=1.16.0,<2.0.0"] +sprites = ["sprites-py>=0.0.1rc37,<0.2"] vercel = ["vercel>=0.5.6,<0.6"] s3 = ["boto3>=1.34"] temporal = [ @@ -164,6 +165,10 @@ ignore_missing_imports = true module = ["vercel", "vercel.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["sprites", "sprites.*"] +ignore_missing_imports = true + [tool.coverage.run] source = ["src/agents"] omit = [ diff --git a/src/agents/extensions/sandbox/__init__.py b/src/agents/extensions/sandbox/__init__.py index d7b082ba1f..53bb0c88f6 100644 --- a/src/agents/extensions/sandbox/__init__.py +++ b/src/agents/extensions/sandbox/__init__.py @@ -109,6 +109,26 @@ except Exception: # pragma: no cover _HAS_VERCEL = False +try: + from .sprites import ( + DEFAULT_SPRITES_API_URL as DEFAULT_SPRITES_API_URL, + DEFAULT_SPRITES_CONTEXT_PATH as DEFAULT_SPRITES_CONTEXT_PATH, + DEFAULT_SPRITES_WAIT_FOR_RUNNING_TIMEOUT_S as DEFAULT_SPRITES_WAIT_FOR_RUNNING_TIMEOUT_S, # noqa: E501 + DEFAULT_SPRITES_WORKSPACE_ROOT as DEFAULT_SPRITES_WORKSPACE_ROOT, + SpritesCheckpoints as SpritesCheckpoints, + SpritesCloudBucketMountStrategy as SpritesCloudBucketMountStrategy, + SpritesPlatformContext as SpritesPlatformContext, + SpritesSandboxClient as 
SpritesSandboxClient, + SpritesSandboxClientOptions as SpritesSandboxClientOptions, + SpritesSandboxSession as SpritesSandboxSession, + SpritesSandboxSessionState as SpritesSandboxSessionState, + SpritesUrlAccess as SpritesUrlAccess, + ) + + _HAS_SPRITES = True +except Exception: # pragma: no cover + _HAS_SPRITES = False + __all__: list[str] = [] if _HAS_E2B: @@ -207,3 +227,21 @@ "RunloopUserParameters", ] ) + +if _HAS_SPRITES: + __all__.extend( + [ + "DEFAULT_SPRITES_API_URL", + "DEFAULT_SPRITES_CONTEXT_PATH", + "DEFAULT_SPRITES_WAIT_FOR_RUNNING_TIMEOUT_S", + "DEFAULT_SPRITES_WORKSPACE_ROOT", + "SpritesCheckpoints", + "SpritesCloudBucketMountStrategy", + "SpritesPlatformContext", + "SpritesSandboxClient", + "SpritesSandboxClientOptions", + "SpritesSandboxSession", + "SpritesSandboxSessionState", + "SpritesUrlAccess", + ] + ) diff --git a/src/agents/extensions/sandbox/sprites/__init__.py b/src/agents/extensions/sandbox/sprites/__init__.py new file mode 100644 index 0000000000..d91669cacf --- /dev/null +++ b/src/agents/extensions/sandbox/sprites/__init__.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from .capabilities import ( + DEFAULT_SPRITES_CONTEXT_PATH, + SpritesCheckpoints, + SpritesPlatformContext, + SpritesUrlAccess, + UrlVisibility, + clear_platform_context_cache, +) +from .mounts import SpritesCloudBucketMountStrategy +from .sandbox import ( + DEFAULT_SPRITES_API_URL, + DEFAULT_SPRITES_WAIT_FOR_RUNNING_TIMEOUT_S, + DEFAULT_SPRITES_WORKSPACE_ROOT, + SpritesSandboxClient, + SpritesSandboxClientOptions, + SpritesSandboxSession, + SpritesSandboxSessionState, +) + +__all__ = [ + "DEFAULT_SPRITES_API_URL", + "DEFAULT_SPRITES_CONTEXT_PATH", + "DEFAULT_SPRITES_WAIT_FOR_RUNNING_TIMEOUT_S", + "DEFAULT_SPRITES_WORKSPACE_ROOT", + "SpritesCheckpoints", + "SpritesCloudBucketMountStrategy", + "SpritesPlatformContext", + "SpritesSandboxClient", + "SpritesSandboxClientOptions", + "SpritesSandboxSession", + "SpritesSandboxSessionState", + "SpritesUrlAccess", + 
"UrlVisibility", + "clear_platform_context_cache", +] diff --git a/src/agents/extensions/sandbox/sprites/capabilities.py b/src/agents/extensions/sandbox/sprites/capabilities.py new file mode 100644 index 0000000000..6b6fe7a6f3 --- /dev/null +++ b/src/agents/extensions/sandbox/sprites/capabilities.py @@ -0,0 +1,409 @@ +"""Sprites-specific agent capabilities.""" + +from __future__ import annotations + +import asyncio +from typing import Any, Literal + +from ....run_context import RunContextWrapper +from ....sandbox.capabilities.capability import Capability +from ....sandbox.manifest import Manifest +from ....sandbox.session.base_sandbox_session import BaseSandboxSession +from ....tool import Tool, function_tool + +DEFAULT_SPRITES_CONTEXT_PATH = "/.sprite/llm.txt" + +UrlVisibility = Literal["public", "sprite"] +"""Sprites URL visibility values. ``"sprite"`` restricts the URL to organization +members (the platform's default); ``"public"`` opens it to the internet.""" + +# Module-level cache of the framed platform-context text keyed by sprite name. +# ``Capability.clone`` runs every agent turn and resets per-instance attribute +# state, so a per-instance cache would re-exec ``cat /.sprite/llm.txt`` every +# turn — waking a paused sprite for nothing on turns where the model never +# calls a tool. Caching at module scope by sprite name lets the file land +# exactly once per sprite for the life of the process. ``clear_platform_context_cache`` +# below is exposed for applications that want to force a re-fetch (e.g. +# after a sprite image upgrade). +_PLATFORM_CONTEXT_CACHE: dict[tuple[str, str, str], str] = {} + + +def clear_platform_context_cache(sprite_name: str | None = None, path: str | None = None) -> None: + """Forget cached platform-context text. + + With no arguments, clears every entry. Pass ``sprite_name`` (and optionally + ``path``) to evict a specific entry. 
+ """ + + if sprite_name is None: + _PLATFORM_CONTEXT_CACHE.clear() + return + for key in list(_PLATFORM_CONTEXT_CACHE.keys()): + if key[0] != sprite_name: + continue + if path is not None and key[1] != path: + continue + del _PLATFORM_CONTEXT_CACHE[key] + + +class SpritesPlatformContext(Capability): + """Inject the sprite's ``/.sprite/llm.txt`` platform-context file into the agent's instructions. + + Sprites bundle an LLM-facing document at ``/.sprite/llm.txt`` describing + available CLI commands (``sprite-env services``, ``sprite-env checkpoints``), + platform behavior (URL routing, idle pause, network policy), and security + rules (e.g. "HTTP services may become PUBLIC — never expose secrets"). + + Adding this capability to a ``SandboxAgent`` reads the file once per + session and appends its contents to the system prompt so the model can + use the platform's own primitives correctly without the application + embedding sprite-specific guidance into its instructions. + + Example: + + agent = SandboxAgent( + ..., + capabilities=[ + WorkspaceShellCapability(), + Filesystem(), + SpritesPlatformContext(), + ], + ) + + The file is read via ``session.exec("cat", path)``, which bypasses the + workspace path-validation that would otherwise reject paths outside the + manifest root. + """ + + type: Literal["sprites_platform_context"] = "sprites_platform_context" + path: str = DEFAULT_SPRITES_CONTEXT_PATH + """Sprite-side path of the context file. Defaults to ``/.sprite/llm.txt``.""" + + timeout_s: float = 5.0 + """Timeout for the ``cat`` exec call.""" + + async def instructions(self, manifest: Manifest) -> str | None: + session = self.session + if session is None: + return None + + sprite_name = _resolve_sprite_name(session) + workspace_root = manifest.root + # Cache key includes workspace root because the framing references + # manifest.root verbatim — different roots produce different text. 
+ cache_key = (sprite_name or "", self.path, workspace_root) + cached = _PLATFORM_CONTEXT_CACHE.get(cache_key) + if cached is not None: + return cached + + try: + result = await session.exec("cat", "--", self.path, shell=False, timeout=self.timeout_s) + except Exception: + return None + if not result.ok(): + return None + + text = result.stdout.decode("utf-8", errors="replace").strip() + if not text: + return None + + framed = ( + "The following is platform context for the Sprites sandbox you are running " + "in. It describes available CLI commands (e.g. `sprite-env services`, " + "`sprite-env checkpoints`), platform behavior, and security rules. Treat " + "it as authoritative when choosing how to interact with the sandbox.\n\n" + "\n" + f"{text}\n" + "\n\n" + f"Important: this agent's workspace root is `{workspace_root}`. Sprites " + f"services created via `sprite-env services create` run with their own " + f"working directory (typically the user's home directory) — NOT in the " + f"workspace. ALWAYS pass `--dir {workspace_root}` (or a workspace " + f"subdirectory) to `sprite-env services create` so the service starts " + f"in the right place. Example:\n\n" + f" sprite-env services create web \\\n" + f" --cmd python3 --args -m,http.server,8080 \\\n" + f" --dir {workspace_root} \\\n" + f" --http-port 8080\n\n" + f"Without `--dir`, an HTTP server will list the home directory and any " + f"file-reading service will look in the wrong place." + ) + if sprite_name: + _PLATFORM_CONTEXT_CACHE[cache_key] = framed + return framed + + +def _resolve_sprite_handle(session: BaseSandboxSession | None) -> Any | None: + """Return the underlying ``sprites.Sprite`` from a SpritesSandboxSession, or None. + + Capabilities are bound to the runtime ``SandboxSession`` wrapper, not the + inner backend session — so we dig through ``_inner`` to reach the + SpritesSandboxSession's ``_sprite`` attribute. 
+ """ + + if session is None: + return None + inner = getattr(session, "_inner", session) + sprite = getattr(inner, "_sprite", None) + return sprite + + +def _resolve_sprite_name(session: BaseSandboxSession | None) -> str | None: + """Return the underlying sprite's name, or None if not yet known.""" + + if session is None: + return None + inner = getattr(session, "_inner", session) + state = getattr(inner, "state", None) + name = getattr(state, "sprite_name", None) if state is not None else None + return name if isinstance(name, str) and name else None + + +class SpritesUrlAccess(Capability): + """Expose a tool that lets the agent toggle the sprite's public URL visibility. + + Sprite URL access is a *host-platform* setting, not something the in-VM + ``sprite-env`` CLI can change — the in-VM API socket only exposes + services/checkpoints. Without this capability, an agent asked to "make the + URL public" tends to thrash between unauthenticated commands. This + capability wraps ``Sprite.update_url_settings`` (which already has the + application's API token via ``SpritesSandboxClient``) so the model can + flip visibility in one call. + + Going ``public`` is gated by ``allow_public`` (default ``False``). The + application must explicitly opt in to expose that option to the agent; + otherwise the tool only accepts ``"sprite"`` (org-members-only). + + Example: + + agent = SandboxAgent( + ..., + capabilities=[ + WorkspaceShellCapability(), + Filesystem(), + SpritesPlatformContext(), + SpritesUrlAccess(allow_public=True), + ], + ) + """ + + type: Literal["sprites_url_access"] = "sprites_url_access" + allow_public: bool = False + """When ``False`` (default), the tool refuses ``visibility="public"``.""" + + def tools(self) -> list[Tool]: + capability = self + allow_public = self.allow_public + if allow_public: + allowed_doc = ( + "Pass 'public' to make the sprite reachable from the open internet, " + "or 'sprite' to restrict it to organization members." 
+ ) + else: + allowed_doc = ( + "Pass 'sprite' to restrict the sprite URL to organization members. " + "(The 'public' option has been disabled by application policy.)" + ) + + @function_tool(name_override="set_sprite_url_visibility") + async def set_sprite_url_visibility( + ctx: RunContextWrapper[Any], + visibility: UrlVisibility, + ) -> str: + """Change the sprite's public URL access mode.""" + + _ = ctx + return await capability._apply_visibility(visibility) + + # Stash a docstring fragment for tools that introspect descriptions. + setattr(set_sprite_url_visibility, "_allowed_doc", allowed_doc) # noqa: B010 + return [set_sprite_url_visibility] + + async def _apply_visibility(self, visibility: str) -> str: + if visibility not in ("public", "sprite"): + return f"error: visibility must be 'public' or 'sprite', got {visibility!r}" + if visibility == "public" and not self.allow_public: + return ( + "error: setting URL to 'public' is disabled by application policy. " + "Use visibility='sprite' to keep it private to org members." + ) + + sprite = _resolve_sprite_handle(self.session) + if sprite is None: + return "error: sprite handle not available (session not started?)" + try: + from sprites.types import URLSettings + + await asyncio.to_thread(sprite.update_url_settings, URLSettings(auth=visibility)) + except Exception as exc: # noqa: BLE001 + return f"error updating URL settings: {exc!r}" + return f"sprite URL visibility is now {visibility!r}" + + +class SpritesCheckpoints(Capability): + """Expose tools to create, list, and (optionally) restore native sprite checkpoints. + + Sprite checkpoints are point-in-time snapshots of the writable filesystem + overlay. They're a Sprites-specific feature — most other sandbox providers + don't have anything equivalent at this granularity. This capability lets + the agent take a checkpoint before risky multi-file work and (when + explicitly enabled) roll back to it. + + Restore is destructive — it replaces the entire workspace. 
Gate it + deliberately with ``allow_restore``. Default ``False``: the agent can save + checkpoints freely but cannot roll back without application opt-in. + + Example: + + agent = SandboxAgent( + ..., + capabilities=[ + ..., + SpritesCheckpoints(allow_restore=True), + ], + ) + """ + + type: Literal["sprites_checkpoints"] = "sprites_checkpoints" + allow_restore: bool = False + """When ``False`` (default), the restore tool is omitted entirely.""" + + def tools(self) -> list[Tool]: + capability = self + + @function_tool(name_override="create_sprite_checkpoint") + async def create_sprite_checkpoint( + ctx: RunContextWrapper[Any], + comment: str = "", + ) -> str: + """Create a sprite filesystem checkpoint and return its id and metadata.""" + + _ = ctx + return await capability._create(comment) + + @function_tool(name_override="list_sprite_checkpoints") + async def list_sprite_checkpoints( + ctx: RunContextWrapper[Any], + ) -> str: + """List all sprite checkpoints (most recent first).""" + + _ = ctx + return await capability._list() + + tools_list: list[Tool] = [create_sprite_checkpoint, list_sprite_checkpoints] + + if self.allow_restore: + + @function_tool(name_override="restore_sprite_checkpoint") + async def restore_sprite_checkpoint( + ctx: RunContextWrapper[Any], + checkpoint_id: str, + ) -> str: + """Restore the sprite filesystem to a previously-created checkpoint. + + DESTRUCTIVE: replaces the entire workspace with the checkpoint state. + Any uncommitted changes since the checkpoint are lost. 
+ """ + + _ = ctx + return await capability._restore(checkpoint_id) + + tools_list.append(restore_sprite_checkpoint) + + return tools_list + + async def _create(self, comment: str) -> str: + sprite = _resolve_sprite_handle(self.session) + if sprite is None: + return "error: sprite handle not available (session not started?)" + + def _do_create() -> dict[str, Any]: + # ``Sprite.create_checkpoint`` returns an iterator of ``StreamMessage`` + # (no checkpoint id in the stream itself), so consume it and then + # pull the most-recent saved checkpoint from ``list_checkpoints``. + stream = sprite.create_checkpoint(comment) + errors: list[str] = [] + for msg in stream: + if getattr(msg, "type", "") == "error": + err = getattr(msg, "error", None) or getattr(msg, "data", None) + if err: + errors.append(str(err)) + if errors: + raise RuntimeError("; ".join(errors)) + existing = sprite.list_checkpoints() + # ``Current`` is the platform's live-state pointer that always + # appears at the top of the list; skip it so we report the actual + # saved snapshot we just made. 
+ saved = [c for c in existing if str(getattr(c, "id", "")).lower() != "current"] + if not saved: + return {} + saved.sort(key=lambda c: c.create_time, reverse=True) + latest = saved[0] + return { + "id": latest.id, + "comment": latest.comment or "", + "created_at": latest.create_time.isoformat(), + } + + try: + result = await asyncio.to_thread(_do_create) + except Exception as exc: # noqa: BLE001 + return f"error creating checkpoint: {exc!r}" + if not result: + return "checkpoint creation completed but no checkpoint was found" + return ( + f"checkpoint created: id={result['id']!r}, " + f"comment={result['comment']!r}, created_at={result['created_at']!r}" + ) + + async def _list(self) -> str: + sprite = _resolve_sprite_handle(self.session) + if sprite is None: + return "error: sprite handle not available (session not started?)" + try: + checkpoints = await asyncio.to_thread(sprite.list_checkpoints) + except Exception as exc: # noqa: BLE001 + return f"error listing checkpoints: {exc!r}" + if not checkpoints: + return "no checkpoints" + rows = [ + f"- {c.id} (created {c.create_time.isoformat()})" + + (f": {c.comment}" if c.comment else "") + for c in checkpoints + ] + return "\n".join(rows) + + async def _restore(self, checkpoint_id: str) -> str: + if not self.allow_restore: + return "error: restore is disabled by application policy" + sprite = _resolve_sprite_handle(self.session) + if sprite is None: + return "error: sprite handle not available (session not started?)" + + def _do_restore() -> list[str]: + stream = sprite.restore_checkpoint(checkpoint_id) + errors: list[str] = [] + for msg in stream: + if getattr(msg, "type", "") == "error": + err = getattr(msg, "error", None) or getattr(msg, "data", None) + if err: + errors.append(str(err)) + return errors + + try: + errors = await asyncio.to_thread(_do_restore) + except Exception as exc: # noqa: BLE001 + return f"error restoring checkpoint: {exc!r}" + if errors: + return f"restore completed with errors: {'; 
'.join(errors)}" + return f"restored checkpoint {checkpoint_id!r}" + + +__all__ = [ + "DEFAULT_SPRITES_CONTEXT_PATH", + "SpritesCheckpoints", + "SpritesPlatformContext", + "SpritesUrlAccess", + "UrlVisibility", + "clear_platform_context_cache", +] diff --git a/src/agents/extensions/sandbox/sprites/mounts.py b/src/agents/extensions/sandbox/sprites/mounts.py new file mode 100644 index 0000000000..508d0ac7ee --- /dev/null +++ b/src/agents/extensions/sandbox/sprites/mounts.py @@ -0,0 +1,279 @@ +"""Mount strategy for Sprites sandboxes.""" + +from __future__ import annotations + +import shlex +from pathlib import Path +from typing import Literal + +from ....sandbox.entries.mounts.base import InContainerMountStrategy, Mount, MountStrategyBase +from ....sandbox.entries.mounts.patterns import RcloneMountPattern +from ....sandbox.errors import MountConfigError +from ....sandbox.materialization import MaterializedFile +from ....sandbox.session.base_sandbox_session import BaseSandboxSession + +# Sprite VMs run as the unprivileged ``sprite`` user with passwordless sudo. +# ``SpritesSandboxSession.exec`` rejects ``user=`` kwargs, so we prefix privileged +# commands with ``sudo -n`` instead of escalating through the framework. +_SUDO = "sudo -n" +_APT = ( + f"{_SUDO} env DEBIAN_FRONTEND=noninteractive DEBCONF_NOWARNINGS=yes apt-get -o Dpkg::Use-Pty=0" +) + +# Detection commands echo a sentinel into stdout based on the *local* shell's +# evaluation of the conditional. We rely on stdout instead of ``ExecResult.ok()`` +# because the sprite-env WS control protocol currently drops exec exit codes +# (the OP_COMPLETE envelope ships ``{"ok": true}`` with no exit-code field, so +# the Python client defaults to 0 for every command). Stdout sentinels are +# also more robust against tools that exit non-zero on benign warnings. 
+_PRESENT = "__SPRITES_PRESENT__" +_MISSING = "__SPRITES_MISSING__" +_MOUNTED = "__SPRITES_MOUNTED__" +_NOT_MOUNTED = "__SPRITES_NOT_MOUNTED__" + + +def _detect_cmd(condition: str) -> str: + """Return a shell snippet that prints _PRESENT or _MISSING based on `condition`.""" + + return f"if {condition}; then echo {_PRESENT}; else echo {_MISSING}; fi" + + +_RCLONE_CHECK = _detect_cmd("command -v rclone >/dev/null 2>&1 || test -x /usr/local/bin/rclone") +_FUSERMOUNT_CHECK = _detect_cmd( + "command -v fusermount3 >/dev/null 2>&1 || command -v fusermount >/dev/null 2>&1" +) +_FUSE_KERNEL_CHECK = _detect_cmd("test -c /dev/fuse && grep -qw fuse /proc/filesystems") +_APT_CHECK = _detect_cmd("command -v apt-get >/dev/null 2>&1") +_INSTALL_RCLONE_COMMANDS = ( + f"{_APT} update -qq", + f"{_APT} install -y -qq curl unzip ca-certificates fuse", + f"curl -fsSL https://rclone.org/install.sh | {_SUDO} bash", +) +# fuse package brings ``fusermount`` along — install it together with rclone +# so the FUSE-mode mount path works out-of-the-box on stock sprite images. 
+_INSTALL_FUSE_COMMANDS = ( + f"{_APT} update -qq", + f"{_APT} install -y -qq fuse", +) +_FUSE_ALLOW_OTHER = ( + f"{_SUDO} chmod a+rw /dev/fuse && " + f"{_SUDO} touch /etc/fuse.conf && " + "(grep -qxF user_allow_other /etc/fuse.conf || " + f"printf '\\nuser_allow_other\\n' | {_SUDO} tee -a /etc/fuse.conf >/dev/null)" +) + + +def _stdout_says(result: object, sentinel: str) -> bool: + stdout = getattr(result, "stdout", b"") or b"" + return sentinel.encode("ascii") in stdout + + +async def _ensure_fuse_support(session: BaseSandboxSession) -> None: + kernel = await session.exec("sh", "-lc", _FUSE_KERNEL_CHECK, shell=False) + if not _stdout_says(kernel, _PRESENT): + raise MountConfigError( + message="Sprites cloud bucket mounts require FUSE support in the kernel", + context={"missing": "fuse"}, + ) + + fusermount = await session.exec("sh", "-lc", _FUSERMOUNT_CHECK, shell=False) + if not _stdout_says(fusermount, _PRESENT): + apt = await session.exec("sh", "-lc", _APT_CHECK, shell=False) + if not _stdout_says(apt, _PRESENT): + raise MountConfigError( + message="fusermount is not installed and apt-get is unavailable; " + "preinstall the fuse package", + context={"package": "fuse"}, + ) + for command in _INSTALL_FUSE_COMMANDS: + await session.exec("sh", "-lc", command, shell=False, timeout=300) + recheck = await session.exec("sh", "-lc", _FUSERMOUNT_CHECK, shell=False) + if not _stdout_says(recheck, _PRESENT): + raise MountConfigError( + message="fuse install attempt completed but fusermount is still not on PATH", + context={"package": "fuse"}, + ) + + # /dev/fuse must be accessible to the unprivileged user and ``user_allow_other`` + # has to be enabled for ``--allow-other``. Failures here would be surfaced by + # the rclone mount itself; we don't gate on this exec's exit code because the + # control-WS protocol drops it. 
+ await session.exec("sh", "-lc", _FUSE_ALLOW_OTHER, shell=False, timeout=30) + + +async def _ensure_rclone(session: BaseSandboxSession) -> None: + rclone = await session.exec("sh", "-lc", _RCLONE_CHECK, shell=False) + if _stdout_says(rclone, _PRESENT): + return + + apt = await session.exec("sh", "-lc", _APT_CHECK, shell=False) + if not _stdout_says(apt, _PRESENT): + raise MountConfigError( + message="rclone is not installed and apt-get is unavailable; preinstall rclone", + context={"package": "rclone"}, + ) + + for command in _INSTALL_RCLONE_COMMANDS: + await session.exec("sh", "-lc", command, shell=False, timeout=300) + + rclone = await session.exec("sh", "-lc", _RCLONE_CHECK, shell=False) + if not _stdout_says(rclone, _PRESENT): + raise MountConfigError( + message="rclone install attempt completed but rclone is still not on PATH", + context={"package": "rclone"}, + ) + + +async def _verify_mount_active(session: BaseSandboxSession, mount_path: Path) -> None: + """Confirm ``mount_path`` is a live mountpoint after activation. + + Without reliable exit codes from the platform we can't detect a failed + rclone mount via ``rclone mount``'s return value. Probe the kernel's view + of the path instead: ``mountpoint -q`` returns 0 iff the path is a mount + boundary. The shell wraps the conditional and emits a stdout sentinel so + the verification is transport-independent. ``rclone mount --daemon`` forks + and the parent returns immediately, so we poll briefly to give the daemon + time to bind. 
+ """ + + quoted = shlex.quote(str(mount_path)) + probe_cmd = ( + f"for _ in 1 2 3 4 5 6 7 8 9 10; do " + f"if mountpoint -q {quoted}; then echo {_MOUNTED}; exit 0; fi; " + "sleep 0.5; " + f"done; echo {_NOT_MOUNTED}" + ) + probe = await session.exec("sh", "-lc", probe_cmd, shell=False, timeout=30) + if not _stdout_says(probe, _MOUNTED): + raise MountConfigError( + message="rclone mount completed but the path is not a live mountpoint", + context={"path": str(mount_path)}, + ) + + # Force rclone to materialize the root directory listing before we hand + # control back to the caller. Without this, the next ``readdir`` from the + # agent races the daemon's first listing fetch and can briefly observe an + # empty directory. The exit code is irrelevant here — we just want the + # side effect of priming rclone's dir cache. + await session.exec("sh", "-lc", f"ls {quoted} >/dev/null 2>&1", shell=False, timeout=15) + + +async def _default_user_ids(session: BaseSandboxSession) -> tuple[str, str] | None: + result = await session.exec("sh", "-lc", "id -u; id -g", shell=False, timeout=30) + if not result.ok(): + return None + + lines = result.stdout.decode("utf-8", errors="replace").splitlines() + if len(lines) < 2 or not lines[0].isdigit() or not lines[1].isdigit(): + return None + return lines[0], lines[1] + + +def _append_option(args: list[str], option: str, *values: str) -> None: + if option not in args: + args.extend([option, *values]) + + +async def _rclone_pattern_for_session( + session: BaseSandboxSession, + pattern: RcloneMountPattern, +) -> RcloneMountPattern: + if pattern.mode != "fuse": + return pattern + + extra_args = list(pattern.extra_args) + _append_option(extra_args, "--allow-other") + user_ids = await _default_user_ids(session) + if user_ids is not None: + uid, gid = user_ids + _append_option(extra_args, "--uid", uid) + _append_option(extra_args, "--gid", gid) + + return pattern.model_copy(update={"extra_args": extra_args}) + + +def 
_assert_sprites_session(session: BaseSandboxSession) -> None: + if type(session).__name__ != "SpritesSandboxSession": + raise MountConfigError( + message="sprites cloud bucket mounts require a SpritesSandboxSession", + context={"session_type": type(session).__name__}, + ) + + +class SpritesCloudBucketMountStrategy(MountStrategyBase): + """Mount rclone-backed cloud storage in Sprites sandboxes.""" + + type: Literal["sprites_cloud_bucket"] = "sprites_cloud_bucket" + pattern: RcloneMountPattern = RcloneMountPattern(mode="fuse") + + def _delegate(self) -> InContainerMountStrategy: + return InContainerMountStrategy(pattern=self.pattern) + + async def _delegate_for_session(self, session: BaseSandboxSession) -> InContainerMountStrategy: + return InContainerMountStrategy( + pattern=await _rclone_pattern_for_session(session, self.pattern) + ) + + def validate_mount(self, mount: Mount) -> None: + self._delegate().validate_mount(mount) + + async def activate( + self, + mount: Mount, + session: BaseSandboxSession, + dest: Path, + base_dir: Path, + ) -> list[MaterializedFile]: + _assert_sprites_session(session) + if self.pattern.mode == "fuse": + await _ensure_fuse_support(session) + await _ensure_rclone(session) + delegate = await self._delegate_for_session(session) + files = await delegate.activate(mount, session, dest, base_dir) + if self.pattern.mode == "fuse": + mount_path = mount._resolve_mount_path(session, dest) + await _verify_mount_active(session, mount_path) + return files + + async def deactivate( + self, + mount: Mount, + session: BaseSandboxSession, + dest: Path, + base_dir: Path, + ) -> None: + _assert_sprites_session(session) + await self._delegate().deactivate(mount, session, dest, base_dir) + + async def teardown_for_snapshot( + self, + mount: Mount, + session: BaseSandboxSession, + path: Path, + ) -> None: + _assert_sprites_session(session) + await self._delegate().teardown_for_snapshot(mount, session, path) + + async def restore_after_snapshot( + self, + 
mount: Mount, + session: BaseSandboxSession, + path: Path, + ) -> None: + _assert_sprites_session(session) + if self.pattern.mode == "fuse": + await _ensure_fuse_support(session) + await _ensure_rclone(session) + delegate = await self._delegate_for_session(session) + await delegate.restore_after_snapshot(mount, session, path) + + def build_docker_volume_driver_config( + self, + mount: Mount, + ) -> tuple[str, dict[str, str], bool] | None: + return None + + +__all__ = [ + "SpritesCloudBucketMountStrategy", +] diff --git a/src/agents/extensions/sandbox/sprites/sandbox.py b/src/agents/extensions/sandbox/sprites/sandbox.py new file mode 100644 index 0000000000..7c4a781beb --- /dev/null +++ b/src/agents/extensions/sandbox/sprites/sandbox.py @@ -0,0 +1,1353 @@ +"""Sprites sandbox (https://sprites.dev) implementation. + +Create a Sprites organization, set ``SPRITES_API_TOKEN``, and optionally +``SPRITES_API_URL`` (defaults to ``https://api.sprites.dev``). + +This module provides a Sprites-backed sandbox client/session that delegates to +the ``sprites-py`` SDK. Exec runs over the multiplexed control-plane WebSocket +(``ControlConnection`` / ``OpConn``) directly so cancellation, timeout, and +streaming work cleanly with the agents-python event loop. Short, non-streaming +lifecycle calls (``create_sprite``, ``get_sprite``, ``delete_sprite``, +filesystem read/write) are wrapped in ``asyncio.to_thread`` because the +upstream SDK exposes them synchronously. + +The ``sprites-py`` dependency is intended to be optional (installed via the +``[sprites]`` extra), so package-level exports guard imports of this module. +Within this module the upstream SDK is imported normally so IDEs can resolve +and navigate types. 
+""" + +from __future__ import annotations + +import asyncio +import io +import logging +import os +import posixpath +import tarfile +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path, PurePosixPath +from typing import Any, Literal, cast +from urllib.parse import urlsplit + +import sprites +from sprites import Sprite, SpritesClient +from sprites.control import ( + ControlConnection, + OpConn, + get_control_connection, + release_control_connection, +) +from sprites.exceptions import ( + AuthenticationError, + FileNotFoundError_, + NetworkError, + NotFoundError, + SpriteError, +) +from sprites.types import URLSettings + +from ....sandbox.errors import ( + ConfigurationError, + ErrorCode, + ExecTimeoutError, + ExecTransportError, + ExposedPortUnavailableError, + WorkspaceArchiveReadError, + WorkspaceArchiveWriteError, + WorkspaceReadNotFoundError, + WorkspaceStartError, + WorkspaceWriteTypeError, +) +from ....sandbox.manifest import Manifest +from ....sandbox.session import SandboxSession, SandboxSessionState +from ....sandbox.session.base_sandbox_session import BaseSandboxSession +from ....sandbox.session.dependencies import Dependencies +from ....sandbox.session.manager import Instrumentation +from ....sandbox.session.pty_types import ( + PTY_PROCESSES_MAX, + PTY_PROCESSES_WARNING, + PtyExecUpdate, + allocate_pty_process_id, + clamp_pty_yield_time_ms, + process_id_to_prune_from_meta, + resolve_pty_write_yield_time_ms, + truncate_text_by_tokens, +) +from ....sandbox.session.runtime_helpers import ( + RESOLVE_WORKSPACE_PATH_HELPER, + RuntimeHelperScript, +) +from ....sandbox.session.sandbox_client import BaseSandboxClient, BaseSandboxClientOptions +from ....sandbox.snapshot import SnapshotBase, SnapshotSpec, resolve_snapshot +from ....sandbox.types import ExecResult, ExposedPortEndpoint, User +from ....sandbox.util.tar_utils import UnsafeTarMemberError, validate_tarfile +from 
....sandbox.workspace_paths import coerce_posix_path, posix_path_as_path, sandbox_path_str + +WorkspacePersistenceMode = Literal["tar"] +"""Workspace persistence modes supported by the Sprites sandbox. + +Only ``"tar"`` is supported in v1; native sprite checkpoints are tracked as a +follow-up because their iterator-based streaming API needs a separate async +wrapper. +""" + +UrlAuth = Literal["sprite", "public"] + +DEFAULT_SPRITES_API_URL = "https://api.sprites.dev" +DEFAULT_SPRITES_WORKSPACE_ROOT = "/workspace" +DEFAULT_SPRITES_WAIT_FOR_RUNNING_TIMEOUT_S = 45.0 +DEFAULT_SPRITES_IDLE_CLOSE_SECONDS = 60.0 +"""Default idle threshold after which control connections are closed so the +sprite can drop back to ``warm`` and stop accruing running-state cost. The +next I/O reopens a control connection; the platform auto-wakes the sprite on +traffic arrival, so the cost is just the WS reconnect (~1s).""" + +# The upstream sprite status enum is not exported from sprites-py; values are +# defined by the API. A sprite that has finished provisioning reports either +# ``"warm"`` (VM is up, idle, ready to accept requests) or ``"running"`` +# (actively handling HTTP traffic). Both are valid for our purposes — exec and +# filesystem operations succeed as soon as the sprite is warm; the platform +# transitions warm → running automatically when traffic arrives. 
+_SPRITE_READY_STATUSES = frozenset({"warm", "running"}) +_SPRITE_READY_POLL_INTERVAL_S = 1.0 +_DEFAULT_MANIFEST_ROOT = cast(str, Manifest.model_fields["root"].default) + +logger = logging.getLogger(__name__) + + +def _resolve_manifest_root(manifest: Manifest | None) -> Manifest: + """Pin a Sprites-specific workspace root when the manifest uses the framework default.""" + + if manifest is None: + return Manifest(root=DEFAULT_SPRITES_WORKSPACE_ROOT) + if manifest.root == _DEFAULT_MANIFEST_ROOT: + return manifest.model_copy(update={"root": DEFAULT_SPRITES_WORKSPACE_ROOT}) + return manifest + + +@dataclass +class _SpritePtyProcessEntry: + """Tracks an in-flight PTY operation for ``SpritesSandboxSession``.""" + + op_conn: OpConn + control: ControlConnection + tty: bool + output_chunks: deque[bytes] = field(default_factory=deque) + output_notify: asyncio.Event = field(default_factory=asyncio.Event) + last_used: float = field(default_factory=time.monotonic) + + +def _validate_tar_bytes(raw: bytes) -> None: + """Validate that ``raw`` is a safe tar archive before extraction.""" + + try: + with tarfile.open(fileobj=io.BytesIO(raw), mode="r:*") as tar: + validate_tarfile(tar) + except UnsafeTarMemberError as exc: + raise ValueError(str(exc)) from exc + except (tarfile.TarError, OSError) as exc: + raise ValueError("invalid tar stream") from exc + + +class SpritesSandboxClientOptions(BaseSandboxClientOptions): + """Client options for the Sprites sandbox backend. + + Field order is part of the v1 public API (pinned by + ``tests/sandbox/test_compatibility_guards.py``); future fields must be + appended. + """ + + type: Literal["sprites"] = "sprites" + sprite_name: str | None = None + """Existing sprite to attach to. When ``None`` (default), a fresh sprite is + created and deleted at session shutdown.""" + + url_auth: UrlAuth = "sprite" + """URL auth mode for the sprite. 
``"sprite"`` restricts access to + organization members (default); ``"public"`` exposes the sprite URL to the + public internet.""" + + ram_mb: int | None = None + cpus: int | None = None + region: str | None = None + storage_gb: int | None = None + """Optional sprite ``SpriteConfig`` knobs. Ignored when attaching to an + existing sprite.""" + + exposed_ports: tuple[int, ...] = () + """Ports expected to be exposed by services declared in the sprite image. + Sprites supports at most one externally routable port per sprite, so this + tuple may have at most one entry.""" + + env: dict[str, str] | None = None + """Reserved for future per-session environment overrides; not yet wired + through to the sprite create call by ``sprites-py``.""" + + timeout_ms: int | None = None + """Reserved for future sprite-side idle timeout configuration.""" + + wait_for_running_timeout_s: float = DEFAULT_SPRITES_WAIT_FOR_RUNNING_TIMEOUT_S + """How long to poll ``get_sprite`` waiting for the sprite to reach + ``running`` status before raising ``WorkspaceStartError``.""" + + workspace_persistence: WorkspacePersistenceMode = "tar" + """Workspace persistence mode. v1 supports only ``"tar"``.""" + + idle_close_seconds: float = DEFAULT_SPRITES_IDLE_CLOSE_SECONDS + """Seconds of inactivity after which the session closes its control + connections so the sprite can drop back to ``warm``. Set to ``0`` (or + any negative value) to disable — connections stay open until shutdown. + Default ``60.0`` matches Sprites' running-state idle billing window.""" + + def __init__( + self, + sprite_name: str | None = None, + url_auth: UrlAuth = "sprite", + ram_mb: int | None = None, + cpus: int | None = None, + region: str | None = None, + storage_gb: int | None = None, + exposed_ports: tuple[int, ...] 
= (), + env: dict[str, str] | None = None, + timeout_ms: int | None = None, + wait_for_running_timeout_s: float = DEFAULT_SPRITES_WAIT_FOR_RUNNING_TIMEOUT_S, + workspace_persistence: WorkspacePersistenceMode = "tar", + idle_close_seconds: float = DEFAULT_SPRITES_IDLE_CLOSE_SECONDS, + *, + type: Literal["sprites"] = "sprites", + ) -> None: + super().__init__( + type=type, + sprite_name=sprite_name, + url_auth=url_auth, + ram_mb=ram_mb, + cpus=cpus, + region=region, + storage_gb=storage_gb, + exposed_ports=exposed_ports, + env=env, + timeout_ms=timeout_ms, + wait_for_running_timeout_s=wait_for_running_timeout_s, + workspace_persistence=workspace_persistence, + idle_close_seconds=idle_close_seconds, + ) + + +class SpritesSandboxSessionState(SandboxSessionState): + """Serializable state for a Sprites-backed session. + + ``token`` and ``base_url`` are intentionally absent — ``resume()`` reads + them from the live ``SpritesSandboxClient`` instead, matching the + token-non-leakage contract documented for the Vercel provider. 
+ """ + + type: Literal["sprites"] = "sprites" + sprite_name: str + created_by_us: bool = True + url_auth: UrlAuth = "sprite" + ram_mb: int | None = None + cpus: int | None = None + region: str | None = None + storage_gb: int | None = None + env: dict[str, str] | None = None + timeout_ms: int | None = None + workspace_persistence: WorkspacePersistenceMode = "tar" + idle_close_seconds: float = DEFAULT_SPRITES_IDLE_CLOSE_SECONDS + + +class SpritesSandboxSession(BaseSandboxSession): + """SandboxSession implementation backed by a Sprites sprite.""" + + state: SpritesSandboxSessionState + _client: SpritesClient | None + _sprite: Sprite | None + _control: ControlConnection | None + _token: str | None + _base_url: str + _pty_lock: asyncio.Lock + _pty_processes: dict[int, _SpritePtyProcessEntry] + _reserved_pty_process_ids: set[int] + _warmth_verified: bool + _last_activity_at: float + _idle_close_seconds: float + _idle_watch_task: asyncio.Task[None] | None + + def __init__( + self, + *, + state: SpritesSandboxSessionState, + token: str | None = None, + base_url: str = DEFAULT_SPRITES_API_URL, + client: SpritesClient | None = None, + sprite: Sprite | None = None, + ) -> None: + self.state = state + self._token = token + self._base_url = base_url + self._client = client + self._sprite = sprite + self._control = None + self._pty_lock = asyncio.Lock() + self._pty_processes = {} + self._reserved_pty_process_ids = set() + self._warmth_verified = False + # Idle-close: when an I/O operation hasn't run for ``idle_close_seconds``, + # the watcher closes the control-connection pool so the sprite can drop + # to ``warm`` and stop accruing running-state cost. The next I/O + # operation reopens a connection; the platform auto-wakes the sprite on + # traffic arrival. 
+ self._last_activity_at = time.monotonic() + self._idle_close_seconds = float(state.idle_close_seconds) + self._idle_watch_task = None + + @classmethod + def from_state( + cls, + state: SpritesSandboxSessionState, + *, + token: str | None = None, + base_url: str = DEFAULT_SPRITES_API_URL, + client: SpritesClient | None = None, + sprite: Sprite | None = None, + ) -> SpritesSandboxSession: + return cls(state=state, token=token, base_url=base_url, client=client, sprite=sprite) + + def supports_pty(self) -> bool: + return True + + # ----- internal helpers ----- + + def _ensure_client_sync(self) -> SpritesClient: + client = self._client + if client is not None: + return client + if not self._token: + raise ConfigurationError( + message=( + "SpritesSandboxSession requires a Sprites API token " + "(set SPRITES_API_TOKEN or pass token=... to SpritesSandboxClient)" + ), + error_code=ErrorCode.SANDBOX_CONFIG_INVALID, + op="start", + context={"backend": "sprites"}, + ) + client = SpritesClient(token=self._token, base_url=self._base_url, control_mode=True) + self._client = client + return client + + async def _ensure_sprite(self) -> Sprite: + existing = self._sprite + if existing is not None: + return existing + + client = self._ensure_client_sync() + sprite: Sprite + if self.state.created_by_us: + # Provision a fresh sprite. ``create_sprite`` raises eagerly if the + # platform rejects the request, so we still surface creation + # failures synchronously here. 
+ config = self._build_sprite_config() + try: + sprite = await asyncio.to_thread( + client.create_sprite, self.state.sprite_name, config + ) + except (NetworkError, AuthenticationError, NotFoundError, SpriteError) as exc: + raise WorkspaceStartError( + path=posix_path_as_path(coerce_posix_path(self.state.manifest.root)), + context={ + "backend": "sprites", + "sprite_name": self.state.sprite_name, + "reason": "create_failed", + }, + cause=exc, + ) from exc + await self._maybe_update_url_settings(sprite) + self._sprite = sprite + return sprite + + # Named-attach: just construct the handle. + sprite = await asyncio.to_thread(client.sprite, self.state.sprite_name) + self._sprite = sprite + return sprite + + # Both ephemeral and named-attach paths now defer the wait-for-running poll + # (and the URL/org-info refresh that comes with it) until the first I/O + # operation runs ``_ensure_warm``. The platform auto-wakes paused sprites + # on traffic arrival and the create POST raises eagerly on rejection, so + # this purely shifts the warm-up cost from session creation to first use + # without losing any safety. Callers that need ``Sprite.url`` (e.g. + # ``_resolve_exposed_port``) call ``_ensure_warm`` themselves. + + async def _ensure_warm(self) -> None: + """Block until the sprite is ready to accept I/O, but only on first use. + + ``_warmth_verified`` is sticky for the life of the session; cached + until a transport error invalidates it (e.g., the sprite was deleted + out from under us and we have to re-attach in a recovery flow). + """ + + self._touch_activity() + if self._warmth_verified: + return + await self._wait_for_sprite_running() + self._warmth_verified = True + + def _invalidate_warmth(self) -> None: + """Force the next I/O operation to re-poll the sprite's status.""" + + self._warmth_verified = False + + def _touch_activity(self) -> None: + """Mark this moment as the most recent I/O. 
Starts the idle watcher + if it isn't already running.""" + + self._last_activity_at = time.monotonic() + self._maybe_start_idle_watch() + + def _maybe_start_idle_watch(self) -> None: + if self._idle_close_seconds <= 0: + return + task = self._idle_watch_task + if task is not None and not task.done(): + return + try: + self._idle_watch_task = asyncio.create_task(self._idle_watch_loop()) + except RuntimeError: + # No running event loop (e.g. unit-test fixture creating a session + # outside an asyncio context). The watcher will start on the next + # I/O call from inside an active loop. + self._idle_watch_task = None + + async def _idle_watch_loop(self) -> None: + try: + while True: + # Sleep until the configured idle window elapses since the + # most-recent activity, re-checking each loop because activity + # may have happened during the sleep and reset the deadline. + elapsed = time.monotonic() - self._last_activity_at + remaining = self._idle_close_seconds - elapsed + if remaining > 0: + await asyncio.sleep(remaining) + continue + await self._close_idle_control_connections() + # Watcher exits; the next I/O calls ``_touch_activity`` which + # will respawn it. + return + except asyncio.CancelledError: + pass + + async def _close_idle_control_connections(self) -> None: + """Close pooled control connections so the sprite can drop to ``warm``. + + Skipped when there are active PTY operations — those need their + connections kept alive. 
+ """ + + if self._pty_processes: + return + sprite = self._sprite + if sprite is None: + return + try: + await sprite.close_control_connection() + except Exception: + pass + + def _build_sprite_config(self) -> sprites.SpriteConfig | None: + if ( + self.state.ram_mb is None + and self.state.cpus is None + and self.state.region is None + and self.state.storage_gb is None + ): + return None + from sprites.types import SpriteConfig + + return SpriteConfig( + ram_mb=self.state.ram_mb, + cpus=self.state.cpus, + region=self.state.region, + storage_gb=self.state.storage_gb, + ) + + async def _maybe_update_url_settings(self, sprite: Sprite) -> None: + # The default URL auth mode set by the API is "sprite"; only call the + # update endpoint when the user asked for something different to avoid + # unnecessary round-trips. + if self.state.url_auth == "sprite": + return + try: + await asyncio.to_thread( + sprite.update_url_settings, URLSettings(auth=self.state.url_auth) + ) + except SpriteError: + # URL auth is best-effort during create; if the platform does not + # accept the value the user can update it later via the dashboard + # without breaking the session. 
+ return + + async def _wait_for_sprite_running(self) -> None: + client = self._ensure_client_sync() + deadline_s = max(0.0, float(self.state.timeout_ms or 0) / 1000.0) or float( + DEFAULT_SPRITES_WAIT_FOR_RUNNING_TIMEOUT_S + ) + loop = asyncio.get_event_loop() + start = loop.time() + last_status: str | None = None + while True: + try: + refreshed = await asyncio.to_thread(client.get_sprite, self.state.sprite_name) + except NotFoundError as exc: + raise WorkspaceStartError( + path=posix_path_as_path(coerce_posix_path(self.state.manifest.root)), + context={ + "backend": "sprites", + "sprite_name": self.state.sprite_name, + "reason": "sprite_not_found", + }, + cause=exc, + ) from exc + except SpriteError as exc: + raise WorkspaceStartError( + path=posix_path_as_path(coerce_posix_path(self.state.manifest.root)), + context={ + "backend": "sprites", + "sprite_name": self.state.sprite_name, + "reason": "wait_for_running_failed", + }, + cause=exc, + ) from exc + + last_status = refreshed.status + if last_status in _SPRITE_READY_STATUSES: + self._sprite = refreshed + return + if loop.time() - start >= deadline_s: + raise WorkspaceStartError( + path=posix_path_as_path(coerce_posix_path(self.state.manifest.root)), + context={ + "backend": "sprites", + "sprite_name": self.state.sprite_name, + "reason": "wait_for_running_timeout", + "last_status": last_status or "unknown", + "timeout_s": deadline_s, + }, + ) + await asyncio.sleep(_SPRITE_READY_POLL_INTERVAL_S) + + async def _ensure_control(self) -> ControlConnection: + sprite = await self._ensure_sprite() + try: + return await get_control_connection(sprite) + except Exception as exc: + raise ExecTransportError( + command=("",), + context={"backend": "sprites", "sprite_name": self.state.sprite_name}, + cause=exc, + ) from exc + + def _release_control(self, control: ControlConnection) -> None: + sprite = self._sprite + if sprite is None: + return + try: + release_control_connection(sprite, control) + except Exception: + pass + + 
def _validate_exposed_ports(self) -> None: + if len(self.state.exposed_ports) > 1: + raise ConfigurationError( + message=( + "Sprites supports at most one external exposed port per sprite; " + "additional ports must be proxied inside the VM" + ), + error_code=ErrorCode.SANDBOX_CONFIG_INVALID, + op="start", + context={ + "backend": "sprites", + "exposed_ports": list(self.state.exposed_ports), + }, + ) + + # ----- BaseSandboxSession overrides ----- + + def _runtime_helpers(self) -> tuple[RuntimeHelperScript, ...]: + return (RESOLVE_WORKSPACE_PATH_HELPER,) + + async def _validate_path_access(self, path: Path | str, *, for_write: bool = False) -> Path: + return await self._validate_remote_path_access(path, for_write=for_write) + + def _reject_user_arg(self, *, op: Literal["exec", "read", "write"], user: str | User) -> None: + user_name = user.name if isinstance(user, User) else user + raise ConfigurationError( + message=( + "SpritesSandboxSession does not support sandbox-local users; " + f"`{op}` must be called without `user`" + ), + error_code=ErrorCode.SANDBOX_CONFIG_INVALID, + op=op, + context={"backend": "sprites", "user": user_name}, + ) + + def _prepare_exec_command( + self, + *command: str | Path, + shell: bool | list[str], + user: str | User | None, + ) -> list[str]: + if user is not None: + self._reject_user_arg(op="exec", user=user) + return super()._prepare_exec_command(*command, shell=shell, user=user) + + async def _prepare_backend_workspace(self) -> None: + # Bootstrap: create the workspace root from ``/`` because the workspace + # directory does not yet exist, and ``_exec_internal`` would otherwise + # try to ``chdir`` into it. 
+ root = PurePosixPath(posixpath.normpath(self.state.manifest.root)) + result = await self._exec_with_cwd( + ["mkdir", "-p", "--", root.as_posix()], cwd=None, timeout=30.0 + ) + if not result.ok(): + raise WorkspaceStartError( + path=posix_path_as_path(root), + context={ + "backend": "sprites", + "sprite_name": self.state.sprite_name, + "exit_code": result.exit_code, + "stdout": result.stdout.decode("utf-8", errors="replace"), + "stderr": result.stderr.decode("utf-8", errors="replace"), + }, + ) + + async def running(self) -> bool: + if self._client is None: + return False + try: + refreshed: Sprite = await asyncio.to_thread( + self._client.get_sprite, self.state.sprite_name + ) + except Exception: + return False + return bool(refreshed.status in _SPRITE_READY_STATUSES) + + async def shutdown(self) -> None: + # Stop the idle watcher first so it doesn't race with our cleanup. + watcher = self._idle_watch_task + if watcher is not None and not watcher.done(): + watcher.cancel() + try: + await watcher + except (asyncio.CancelledError, Exception): + pass + self._idle_watch_task = None + + # Tear down any in-flight PTY operations first so their control connections + # are released back to the pool before the sprite is deleted. + try: + await asyncio.wait_for(self.pty_terminate_all(), timeout=2.0) + except Exception: + pass + + # Order matters for fast cleanup: delete the sprite FIRST (which kills + # server-side WebSockets immediately), then close local client state. + # Otherwise we wait up to ~2s per still-open control connection on the + # WS close handshake + read-task drain. + if self.state.created_by_us and self._client is not None: + try: + await asyncio.wait_for( + asyncio.to_thread(self._client.delete_sprite, self.state.sprite_name), + timeout=5.0, + ) + except NotFoundError: + pass + except Exception: + pass + + # Now close local control connections. They'll see ConnectionClosed + # from the now-deleted sprite and exit fast; cap at 2s as a guardrail. 
+ if self._sprite is not None: + try: + await asyncio.wait_for(self._sprite.close_control_connection(), timeout=2.0) + except Exception: + pass + self._sprite = None + + if self._client is not None: + try: + self._client.close() + except Exception: + pass + self._client = None + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + normalized = [str(part) for part in command] + return await self._exec_with_cwd(normalized, cwd=self.state.manifest.root, timeout=timeout) + + async def _exec_with_cwd( + self, + command: list[str], + *, + cwd: str | None, + timeout: float | None, + ) -> ExecResult: + normalized = [str(part) for part in command] + if not normalized: + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + await self._ensure_warm() + + control: ControlConnection | None = None + op_conn: OpConn | None = None + try: + control = await self._ensure_control() + try: + op_conn = await control.start_op( + "exec", + cmd=list(normalized), + dir=cwd, + stdin=False, + ) + except Exception as exc: + raise ExecTransportError( + command=normalized, + context={"backend": "sprites", "sprite_name": self.state.sprite_name}, + cause=exc, + ) from exc + + try: + exit_code = await asyncio.wait_for(op_conn.wait(), timeout=timeout) + except asyncio.TimeoutError as exc: + # Best-effort: signal the remote process before propagating the timeout. 
+ try: + await op_conn.signal("KILL") + except Exception: + pass + raise ExecTimeoutError(command=normalized, timeout_s=timeout, cause=exc) from exc + + return ExecResult( + stdout=op_conn.get_stdout(), + stderr=op_conn.get_stderr(), + exit_code=exit_code, + ) + except (ExecTimeoutError, ExecTransportError): + raise + except Exception as exc: + raise ExecTransportError( + command=normalized, + context={"backend": "sprites", "sprite_name": self.state.sprite_name}, + cause=exc, + ) from exc + finally: + if control is not None: + self._release_control(control) + + # ----- PTY ----- + + def _make_pty_callback(self, entry: _SpritePtyProcessEntry) -> Any: + # ``OpConn.handle_data`` invokes callbacks synchronously from the read + # loop running on this event loop, so a sync callback is correct. + def _callback(payload: bytes) -> None: + if not payload: + return + entry.output_chunks.append(bytes(payload)) + entry.output_notify.set() + + return _callback + + async def pty_exec_start( + self, + *command: str | Path, + timeout: float | None = None, + shell: bool | list[str] = True, + user: str | User | None = None, + tty: bool = False, + yield_time_s: float | None = None, + max_output_tokens: int | None = None, + ) -> PtyExecUpdate: + sanitized_command = self._prepare_exec_command(*command, shell=shell, user=user) + # ``_ensure_control`` will lazily call ``_ensure_sprite``; no extra await here. 
+ await self._ensure_warm() + + cc: ControlConnection | None = None + op: OpConn | None = None + entry: _SpritePtyProcessEntry | None = None + registered = False + pruned_entry: _SpritePtyProcessEntry | None = None + process_id = 0 + process_count = 0 + + try: + cc = await self._ensure_control() + try: + op = await cc.start_op( + "exec", + cmd=list(sanitized_command), + dir=self.state.manifest.root, + tty=tty, + rows=24, + cols=80, + stdin=True, + ) + except Exception as exc: + raise ExecTransportError( + command=sanitized_command, + context={"backend": "sprites", "sprite_name": self.state.sprite_name}, + cause=exc, + ) from exc + + entry = _SpritePtyProcessEntry(op_conn=op, control=cc, tty=tty) + # Register callbacks before any ``await`` to minimize the start-time + # race; pre-drain whatever already landed in the OpConn's internal + # buffers between ``start_op`` returning and this point. + callback = self._make_pty_callback(entry) + op.on_stdout = callback + op.on_stderr = callback + pre_stdout = op.get_stdout() + pre_stderr = op.get_stderr() + if pre_stdout: + entry.output_chunks.append(pre_stdout) + if pre_stderr: + entry.output_chunks.append(pre_stderr) + if pre_stdout or pre_stderr: + entry.output_notify.set() + + async with self._pty_lock: + process_id = allocate_pty_process_id(self._reserved_pty_process_ids) + self._reserved_pty_process_ids.add(process_id) + pruned_entry = self._prune_pty_processes_if_needed() + self._pty_processes[process_id] = entry + process_count = len(self._pty_processes) + registered = True + except asyncio.CancelledError: + if not registered and entry is not None: + await self._terminate_pty_entry(entry) + elif cc is not None: + self._release_control(cc) + raise + except ExecTransportError: + if cc is not None and (entry is None or not registered): + self._release_control(cc) + raise + except Exception as exc: + if not registered and entry is not None: + await self._terminate_pty_entry(entry) + elif cc is not None: + 
self._release_control(cc) + raise ExecTransportError( + command=sanitized_command, + context={"backend": "sprites", "sprite_name": self.state.sprite_name}, + cause=exc, + ) from exc + + if pruned_entry is not None: + await self._terminate_pty_entry(pruned_entry) + + if process_count >= PTY_PROCESSES_WARNING: + logger.warning( + "Sprites PTY process count reached warning threshold: %s active sessions", + process_count, + ) + + yield_time_ms = 10_000 if yield_time_s is None else int(yield_time_s * 1000) + output, original_token_count = await self._collect_pty_output( + entry=entry, + yield_time_ms=clamp_pty_yield_time_ms(yield_time_ms), + max_output_tokens=max_output_tokens, + ) + return await self._finalize_pty_update( + process_id=process_id, + entry=entry, + output=output, + original_token_count=original_token_count, + ) + + async def pty_write_stdin( + self, + *, + session_id: int, + chars: str, + yield_time_s: float | None = None, + max_output_tokens: int | None = None, + ) -> PtyExecUpdate: + async with self._pty_lock: + entry = self._resolve_pty_session_entry( + pty_processes=self._pty_processes, + session_id=session_id, + ) + + if chars: + payload = chars.encode("utf-8") + try: + await entry.op_conn.write(payload) + except Exception as exc: + raise ExecTransportError( + command=("",), + context={ + "backend": "sprites", + "sprite_name": self.state.sprite_name, + "session_id": session_id, + }, + cause=exc, + ) from exc + + yield_time_ms = 250 if yield_time_s is None else int(yield_time_s * 1000) + output, original_token_count = await self._collect_pty_output( + entry=entry, + yield_time_ms=resolve_pty_write_yield_time_ms( + yield_time_ms=yield_time_ms, input_empty=chars == "" + ), + max_output_tokens=max_output_tokens, + ) + entry.last_used = time.monotonic() + return await self._finalize_pty_update( + process_id=session_id, + entry=entry, + output=output, + original_token_count=original_token_count, + ) + + async def pty_terminate_all(self) -> None: + async 
with self._pty_lock: + entries = list(self._pty_processes.values()) + self._pty_processes.clear() + self._reserved_pty_process_ids.clear() + for entry in entries: + await self._terminate_pty_entry(entry) + + async def _collect_pty_output( + self, + *, + entry: _SpritePtyProcessEntry, + yield_time_ms: int, + max_output_tokens: int | None, + ) -> tuple[bytes, int | None]: + deadline = time.monotonic() + (yield_time_ms / 1000) + output = bytearray() + + while True: + while entry.output_chunks: + output.extend(entry.output_chunks.popleft()) + + if time.monotonic() >= deadline: + break + + if self._entry_exit_code(entry) is not None: + while entry.output_chunks: + output.extend(entry.output_chunks.popleft()) + break + + remaining_s = deadline - time.monotonic() + if remaining_s <= 0: + break + + entry.output_notify.clear() + try: + await asyncio.wait_for(entry.output_notify.wait(), timeout=remaining_s) + except asyncio.TimeoutError: + break + + text = output.decode("utf-8", errors="replace") + truncated_text, original_token_count = truncate_text_by_tokens(text, max_output_tokens) + return truncated_text.encode("utf-8", errors="replace"), original_token_count + + async def _finalize_pty_update( + self, + *, + process_id: int, + entry: _SpritePtyProcessEntry, + output: bytes, + original_token_count: int | None, + ) -> PtyExecUpdate: + exit_code = self._entry_exit_code(entry) + live_process_id: int | None = process_id + if exit_code is not None: + async with self._pty_lock: + removed = self._pty_processes.pop(process_id, None) + self._reserved_pty_process_ids.discard(process_id) + if removed is not None: + await self._terminate_pty_entry(removed) + live_process_id = None + return PtyExecUpdate( + process_id=live_process_id, + output=output, + exit_code=exit_code, + original_token_count=original_token_count, + ) + + def _prune_pty_processes_if_needed(self) -> _SpritePtyProcessEntry | None: + if len(self._pty_processes) < PTY_PROCESSES_MAX: + return None + meta: 
list[tuple[int, float, bool]] = [ + (pid, entry.last_used, self._entry_exit_code(entry) is not None) + for pid, entry in self._pty_processes.items() + ] + target = process_id_to_prune_from_meta(meta) + if target is None: + return None + self._reserved_pty_process_ids.discard(target) + return self._pty_processes.pop(target, None) + + def _entry_exit_code(self, entry: _SpritePtyProcessEntry) -> int | None: + op = entry.op_conn + if not op.is_closed(): + return None + code = op.get_exit_code() + # ``OpConn`` initializes ``exit_code`` to -1 and only sets a real value + # on ``op.complete``. Treat -1 as "not yet known" even if closed (e.g. + # transport dropped before exit signal arrived). + if code < 0: + return None + return code + + async def _terminate_pty_entry(self, entry: _SpritePtyProcessEntry) -> None: + op = entry.op_conn + try: + if not op.is_closed(): + try: + await op.signal("TERM") + except Exception: + pass + # Brief grace period before forcing. + for _ in range(5): + if op.is_closed(): + break + await asyncio.sleep(0.05) + if not op.is_closed(): + try: + await op.signal("KILL") + except Exception: + pass + finally: + try: + op.close() + except Exception: + pass + self._release_control(entry.control) + + async def _resolve_exposed_port(self, port: int) -> ExposedPortEndpoint: + await self._ensure_sprite() + # Make sure the sprite is reachable AND that ``Sprite.url`` / + # ``organization_name`` are populated — these come from the post-poll + # ``get_sprite`` refresh. 
+ await self._ensure_warm() + sprite = self._sprite + if sprite is None: + raise ExposedPortUnavailableError( + port=port, + exposed_ports=self.state.exposed_ports, + reason="backend_unavailable", + context={"backend": "sprites", "sprite_name": self.state.sprite_name}, + ) + url = sprite.url + if not url: + raise ExposedPortUnavailableError( + port=port, + exposed_ports=self.state.exposed_ports, + reason="backend_unavailable", + context={"backend": "sprites", "sprite_name": self.state.sprite_name}, + ) + + # Confirm the requested port is exposed by a service on the sprite. Sprites + # exposes only one external HTTP port per sprite, so any extra port is a + # configuration error caught earlier in `_validate_exposed_ports`. + try: + services = await asyncio.to_thread(sprite.list_services) + except SpriteError as exc: + raise ExposedPortUnavailableError( + port=port, + exposed_ports=self.state.exposed_ports, + reason="backend_unavailable", + context={"backend": "sprites", "sprite_name": self.state.sprite_name}, + cause=exc, + ) from exc + + if not any(getattr(service, "http_port", None) == port for service in services): + raise ExposedPortUnavailableError( + port=port, + exposed_ports=self.state.exposed_ports, + reason="not_configured", + context={ + "backend": "sprites", + "sprite_name": self.state.sprite_name, + "hint": ("declare a service with --http-port= in the sprite image"), + }, + ) + + parsed = urlsplit(url) + host = parsed.hostname + if not host: + raise ExposedPortUnavailableError( + port=port, + exposed_ports=self.state.exposed_ports, + reason="backend_unavailable", + context={"backend": "sprites", "sprite_name": self.state.sprite_name, "url": url}, + ) + tls = parsed.scheme == "https" + return ExposedPortEndpoint( + host=host, + port=parsed.port or (443 if tls else 80), + tls=tls, + ) + + async def read(self, path: Path, *, user: str | User | None = None) -> io.IOBase: + if user is not None: + self._reject_user_arg(op="read", user=user) + + 
normalized_path = await self._validate_path_access(path) + sprite = await self._ensure_sprite() + await self._ensure_warm() + try: + payload = await asyncio.to_thread( + lambda: (sprite.filesystem("/") / sandbox_path_str(normalized_path)).read_bytes() + ) + except FileNotFoundError_ as exc: + raise WorkspaceReadNotFoundError(path=normalized_path, cause=exc) from exc + except Exception as exc: + raise WorkspaceArchiveReadError(path=normalized_path, cause=exc) from exc + return io.BytesIO(payload) + + async def write( + self, + path: Path, + data: io.IOBase, + *, + user: str | User | None = None, + ) -> None: + if user is not None: + self._reject_user_arg(op="write", user=user) + + normalized_path = await self._validate_path_access(path, for_write=True) + payload = data.read() + if isinstance(payload, str): + payload = payload.encode("utf-8") + if not isinstance(payload, bytes | bytearray): + raise WorkspaceWriteTypeError( + path=normalized_path, + actual_type=type(payload).__name__, + ) + + sprite = await self._ensure_sprite() + await self._ensure_warm() + try: + await asyncio.to_thread( + lambda: (sprite.filesystem("/") / sandbox_path_str(normalized_path)).write_bytes( + bytes(payload) + ) + ) + except Exception as exc: + raise WorkspaceArchiveWriteError(path=normalized_path, cause=exc) from exc + + async def persist_workspace(self) -> io.IOBase: + root = self._workspace_root_path() + sprite = await self._ensure_sprite() + archive_path = posix_path_as_path( + coerce_posix_path(f"/tmp/openai-agents-{self.state.session_id.hex}.tar") + ) + excludes = [ + f"--exclude=./{rel_path.as_posix()}" + for rel_path in sorted( + self._persist_workspace_skip_relpaths(), + key=lambda item: item.as_posix(), + ) + ] + tar_command = ("tar", "cf", archive_path.as_posix(), *excludes, ".") + try: + result = await self.exec(*tar_command, shell=False) + if not result.ok(): + raise WorkspaceArchiveReadError( + path=root, + context={ + "backend": "sprites", + "sprite_name": 
self.state.sprite_name, + "exit_code": result.exit_code, + "stdout": result.stdout.decode("utf-8", errors="replace"), + "stderr": result.stderr.decode("utf-8", errors="replace"), + }, + ) + archive = await asyncio.to_thread( + lambda: (sprite.filesystem("/") / archive_path.as_posix()).read_bytes() + ) + return io.BytesIO(archive) + except WorkspaceArchiveReadError: + raise + except Exception as exc: + raise WorkspaceArchiveReadError(path=root, cause=exc) from exc + finally: + try: + await self.exec("rm", "-f", "--", archive_path.as_posix(), shell=False) + except Exception: + pass + + async def hydrate_workspace(self, data: io.IOBase) -> None: + raw = data.read() + if isinstance(raw, str): + raw = raw.encode("utf-8") + if not isinstance(raw, bytes | bytearray): + raise WorkspaceWriteTypeError( + path=self._workspace_root_path(), + actual_type=type(raw).__name__, + ) + + root = self._workspace_root_path() + sprite = await self._ensure_sprite() + archive_path = posix_path_as_path( + coerce_posix_path(f"/tmp/openai-agents-{self.state.session_id.hex}.tar") + ) + + try: + _validate_tar_bytes(bytes(raw)) + except ValueError as exc: + raise WorkspaceArchiveWriteError(path=root, cause=exc) from exc + + try: + await self.mkdir(root, parents=True) + await asyncio.to_thread( + lambda: (sprite.filesystem("/") / archive_path.as_posix()).write_bytes(bytes(raw)) + ) + extract_cmd = ("tar", "xf", archive_path.as_posix(), "-C", root.as_posix()) + result = await self.exec(*extract_cmd, shell=False) + if not result.ok(): + raise WorkspaceArchiveWriteError( + path=root, + context={ + "backend": "sprites", + "sprite_name": self.state.sprite_name, + "exit_code": result.exit_code, + "stdout": result.stdout.decode("utf-8", errors="replace"), + "stderr": result.stderr.decode("utf-8", errors="replace"), + }, + ) + except WorkspaceArchiveWriteError: + raise + except Exception as exc: + raise WorkspaceArchiveWriteError(path=root, cause=exc) from exc + finally: + try: + await self.exec("rm", 
"-f", "--", archive_path.as_posix(), shell=False) + except Exception: + pass + + +class SpritesSandboxClient(BaseSandboxClient[SpritesSandboxClientOptions]): + """Sprites-backed sandbox client.""" + + backend_id = "sprites" + _instrumentation: Instrumentation + _token: str | None + _base_url: str + + def __init__( + self, + *, + token: str | None = None, + base_url: str | None = None, + instrumentation: Instrumentation | None = None, + dependencies: Dependencies | None = None, + ) -> None: + super().__init__() + resolved_token = token if token is not None else os.environ.get("SPRITES_API_TOKEN") + if not resolved_token: + raise ConfigurationError( + message=( + "Sprites API token is required. Pass token=... to " + "SpritesSandboxClient or set the SPRITES_API_TOKEN environment " + "variable." + ), + error_code=ErrorCode.SANDBOX_CONFIG_INVALID, + op="start", + context={"backend": "sprites"}, + ) + self._token = resolved_token + self._base_url = ( + base_url + if base_url is not None + else os.environ.get("SPRITES_API_URL", DEFAULT_SPRITES_API_URL) + ) + self._instrumentation = instrumentation or Instrumentation() + self._dependencies = dependencies + + async def create( + self, + *, + snapshot: SnapshotSpec | SnapshotBase | None = None, + manifest: Manifest | None = None, + options: SpritesSandboxClientOptions, + ) -> SandboxSession: + resolved_manifest = _resolve_manifest_root(manifest) + session_id = uuid.uuid4() + snapshot_instance = resolve_snapshot(snapshot, str(session_id)) + + sprite_name = options.sprite_name or f"openai-agents-{session_id.hex[:12]}" + created_by_us = options.sprite_name is None + + state = SpritesSandboxSessionState( + session_id=session_id, + manifest=resolved_manifest, + snapshot=snapshot_instance, + sprite_name=sprite_name, + created_by_us=created_by_us, + url_auth=options.url_auth, + ram_mb=options.ram_mb, + cpus=options.cpus, + region=options.region, + storage_gb=options.storage_gb, + exposed_ports=options.exposed_ports, + 
env=dict(options.env or {}) or None, + timeout_ms=options.timeout_ms, + workspace_persistence=options.workspace_persistence, + ) + + inner = SpritesSandboxSession.from_state(state, token=self._token, base_url=self._base_url) + inner._validate_exposed_ports() + await inner._ensure_sprite() + return self._wrap_session(inner, instrumentation=self._instrumentation) + + async def delete(self, session: SandboxSession) -> SandboxSession: + inner = session._inner + if not isinstance(inner, SpritesSandboxSession): + raise TypeError("SpritesSandboxClient.delete expects a SpritesSandboxSession") + try: + await inner.shutdown() + except Exception: + pass + return session + + async def resume(self, state: SandboxSessionState) -> SandboxSession: + if not isinstance(state, SpritesSandboxSessionState): + raise TypeError("SpritesSandboxClient.resume expects a SpritesSandboxSessionState") + + inner = SpritesSandboxSession.from_state(state, token=self._token, base_url=self._base_url) + try: + await inner._ensure_sprite() + inner._set_start_state_preserved(True) + except WorkspaceStartError: + if not state.created_by_us: + raise + # Fall through to fresh start; ``_ensure_sprite`` will be retried by the + # session's own ``start()`` lifecycle and will recreate the sprite. 
+ inner._sprite = None + state.workspace_root_ready = False + return self._wrap_session(inner, instrumentation=self._instrumentation) + + def deserialize_session_state(self, payload: dict[str, object]) -> SandboxSessionState: + return SpritesSandboxSessionState.model_validate(payload) + + +__all__ = [ + "DEFAULT_SPRITES_API_URL", + "DEFAULT_SPRITES_WAIT_FOR_RUNNING_TIMEOUT_S", + "DEFAULT_SPRITES_WORKSPACE_ROOT", + "SpritesSandboxClient", + "SpritesSandboxClientOptions", + "SpritesSandboxSession", + "SpritesSandboxSessionState", +] diff --git a/tests/extensions/test_sandbox_sprites.py b/tests/extensions/test_sandbox_sprites.py new file mode 100644 index 0000000000..ca0b735135 --- /dev/null +++ b/tests/extensions/test_sandbox_sprites.py @@ -0,0 +1,1782 @@ +from __future__ import annotations + +import asyncio +import io +import uuid +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import pytest +import pytest_asyncio + +pytest.importorskip("agents.extensions.sandbox.sprites.sandbox") + +from sprites.exceptions import NotFoundError as _SpritesNotFoundError # noqa: E402 + +from agents.extensions.sandbox.sprites import ( # noqa: E402 + SpritesPlatformContext, + SpritesSandboxClient, + SpritesSandboxClientOptions, + SpritesSandboxSession, + SpritesSandboxSessionState, + sandbox as sprites_sandbox, # noqa: E402 +) +from agents.sandbox.errors import ( # noqa: E402 + ConfigurationError, + ExecTimeoutError, + ExecTransportError, + ExposedPortUnavailableError, + WorkspaceArchiveWriteError, + WorkspaceReadNotFoundError, + WorkspaceStartError, + WorkspaceWriteTypeError, +) +from agents.sandbox.manifest import Manifest # noqa: E402 +from agents.sandbox.session.sandbox_client import BaseSandboxClientOptions # noqa: E402 +from agents.sandbox.snapshot import NoopSnapshot # noqa: E402 +from agents.sandbox.types import ExecResult, ExposedPortEndpoint, User # noqa: E402 + +SPRITE_NAME = "sprite-test-1" +SESSION_UUID = 
uuid.UUID("11111111-1111-1111-1111-111111111111") + + +@pytest.fixture(autouse=True) +def _clear_platform_context_cache() -> Any: + """Make sure cached platform-context text from one test doesn't leak.""" + + from agents.extensions.sandbox.sprites import clear_platform_context_cache + + clear_platform_context_cache() + yield + clear_platform_context_cache() + + +def _attach(inner: SpritesSandboxSession, *, client: Any, sprite: Any = None) -> None: + """Inject fake client/sprite into a SpritesSandboxSession. + + ``setattr`` is used to sidestep mypy's invariant attribute typing — the fakes + duck-type the real ``SpritesClient``/``Sprite`` interface only as far as the + tests exercise. + + Also marks ``_warmth_verified=True`` so I/O paths skip the lazy + wake-up poll — the test has already set up the fake sprite directly, + so we can trust it's "warm enough" for the assertions under test. + Tests that specifically exercise the wait-for-running poll override + this back to False. + """ + + # ``setattr`` (instead of plain assignment) silences mypy's invariant attribute + # check; the fakes only duck-type the parts we exercise. 
+ setattr(inner, "_client", client) # noqa: B010 + if sprite is not None: + setattr(inner, "_sprite", sprite) # noqa: B010 + setattr(inner, "_warmth_verified", True) # noqa: B010 + + +# ---------- Fakes ---------- + + +class _FakeFileNotFound(Exception): + """Stands in for ``sprites.exceptions.FileNotFoundError_`` in fake fs ops.""" + + +class _FakeOpConn: + def __init__( + self, + *, + stdout: bytes = b"", + stderr: bytes = b"", + exit_code: int = 0, + wait_event: asyncio.Event | None = None, + start_failure: BaseException | None = None, + ) -> None: + self._stdout = stdout + self._stderr = stderr + self.exit_code = exit_code + self._wait_event = wait_event + self._start_failure = start_failure + self.signals: list[str] = [] + self.write_calls: list[bytes] = [] + self.closed = False + self.on_stdout: Any = None + self.on_stderr: Any = None + self.on_message: Any = None + + async def wait(self) -> int: + if self._wait_event is not None: + await self._wait_event.wait() + return self.exit_code + + def get_stdout(self) -> bytes: + return self._stdout + + def get_stderr(self) -> bytes: + return self._stderr + + def get_exit_code(self) -> int: + return self.exit_code + + def is_closed(self) -> bool: + return self.closed + + def close(self) -> None: + self.closed = True + + async def signal(self, sig: str) -> None: + self.signals.append(sig) + + async def write(self, data: bytes) -> None: + self.write_calls.append(data) + + +class _FakeControlConnection: + def __init__(self) -> None: + self.start_op_calls: list[dict[str, Any]] = [] + # Each entry is consumed in FIFO order; if empty, a default zero-exit op is returned. 
+ self.next_ops: list[_FakeOpConn] = [] + self.start_op_failures: list[BaseException] = [] + + async def start_op( + self, + op: str, + cmd: list[str] | None = None, + env: dict[str, str] | None = None, + dir: str | None = None, + tty: bool = False, + rows: int = 24, + cols: int = 80, + stdin: bool = True, + ) -> _FakeOpConn: + self.start_op_calls.append( + {"op": op, "cmd": list(cmd or []), "dir": dir, "tty": tty, "stdin": stdin} + ) + if self.start_op_failures: + raise self.start_op_failures.pop(0) + if self.next_ops: + return self.next_ops.pop(0) + return _FakeOpConn() + + +class _FakeSpritePath: + def __init__(self, fs: _FakeSpriteFilesystem, path: str) -> None: + self._fs = fs + self._path = path + + def read_bytes(self) -> bytes: + if self._fs.read_failure is not None: + raise self._fs.read_failure + if self._path not in self._fs.files: + raise _FakeFileNotFound(self._path) + return self._fs.files[self._path] + + def write_bytes(self, data: bytes) -> None: + if self._fs.write_failure is not None: + raise self._fs.write_failure + self._fs.files[self._path] = bytes(data) + + +class _FakeSpriteFilesystem: + def __init__(self, files: dict[str, bytes]) -> None: + self.files = files + self.read_failure: BaseException | None = None + self.write_failure: BaseException | None = None + + def __truediv__(self, path: str) -> _FakeSpritePath: + return _FakeSpritePath(self, path) + + +class _FakeService: + def __init__(self, *, http_port: int | None) -> None: + self.http_port = http_port + + +class _FakeSprite: + def __init__( + self, + *, + name: str, + url: str | None = "https://example-sprite-org.sprites.dev", + status: str = "running", + services: list[_FakeService] | None = None, + files: dict[str, bytes] | None = None, + list_services_failure: BaseException | None = None, + ) -> None: + self.name = name + self.url = url + self.status = status + self.organization_name = "example-org" + self.update_url_settings_calls: list[Any] = [] + 
self.close_control_connection_calls = 0 + self.list_services_failure = list_services_failure + self._services = services or [] + self._fs = _FakeSpriteFilesystem(files or {}) + + def filesystem(self, working_dir: str = "/") -> _FakeSpriteFilesystem: + return self._fs + + def list_services(self) -> list[_FakeService]: + if self.list_services_failure is not None: + raise self.list_services_failure + return list(self._services) + + def update_url_settings(self, settings: Any) -> None: + self.update_url_settings_calls.append(settings) + + async def close_control_connection(self) -> None: + self.close_control_connection_calls += 1 + + +class _FakeSpritesClient: + def __init__( + self, + *, + token: str = "tok", + base_url: str = "https://api.sprites.dev", + control_mode: bool = False, + sprites_by_name: dict[str, _FakeSprite] | None = None, + ) -> None: + self.token = token + self.base_url = base_url + self.control_mode = control_mode + self.create_sprite_calls: list[tuple[str, Any]] = [] + self.delete_sprite_calls: list[str] = [] + self.get_sprite_calls: list[str] = [] + self.sprite_handle_calls: list[str] = [] + self.closed = False + self._sprites_by_name = sprites_by_name or {} + self.create_failures: list[BaseException] = [] + self.get_failures: list[BaseException] = [] + + def create_sprite(self, name: str, config: Any | None = None) -> _FakeSprite: + self.create_sprite_calls.append((name, config)) + if self.create_failures: + raise self.create_failures.pop(0) + sprite = _FakeSprite(name=name) + self._sprites_by_name[name] = sprite + return sprite + + def sprite(self, name: str) -> _FakeSprite: + self.sprite_handle_calls.append(name) + return self._sprites_by_name.get(name) or _FakeSprite(name=name) + + def get_sprite(self, name: str) -> _FakeSprite: + self.get_sprite_calls.append(name) + if self.get_failures: + raise self.get_failures.pop(0) + sprite = self._sprites_by_name.get(name) + if sprite is None: + raise _SpritesNotFoundError(f"sprite not found: {name}") + 
return sprite + + def delete_sprite(self, name: str) -> None: + self.delete_sprite_calls.append(name) + + def close(self) -> None: + self.closed = True + + +# ---------- Helpers ---------- + + +def _make_state(**overrides: object) -> SpritesSandboxSessionState: + base = { + "session_id": SESSION_UUID, + "snapshot": NoopSnapshot(id="snapshot-1"), + "manifest": Manifest(root="/workspace"), + "sprite_name": SPRITE_NAME, + "created_by_us": True, + } + base.update(overrides) + return SpritesSandboxSessionState.model_validate(base) + + +@pytest_asyncio.fixture +async def patched_sprites(monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]: + fake_client = _FakeSpritesClient() + fake_control = _FakeControlConnection() + + monkeypatch.setattr(sprites_sandbox, "SpritesClient", lambda **kw: fake_client) + monkeypatch.setattr(sprites_sandbox, "FileNotFoundError_", _FakeFileNotFound) + + async def _get_control(_sprite: Any) -> _FakeControlConnection: + return fake_control + + def _release_control(_sprite: Any, _cc: Any) -> None: + return None + + monkeypatch.setattr(sprites_sandbox, "get_control_connection", _get_control) + monkeypatch.setattr(sprites_sandbox, "release_control_connection", _release_control) + return {"client": fake_client, "control": fake_control} + + +# ---------- 1. 
Options & state roundtrip ---------- + + +def test_options_roundtrip_through_polymorphic_registry() -> None: + options = SpritesSandboxClientOptions( + sprite_name="my-sprite", + url_auth="public", + ram_mb=512, + cpus=2, + region="iad", + storage_gb=8, + exposed_ports=(8080,), + env={"FOO": "BAR"}, + timeout_ms=120_000, + ) + payload = options.model_dump(mode="json") + assert payload["type"] == "sprites" + restored = BaseSandboxClientOptions.parse(payload) + assert isinstance(restored, SpritesSandboxClientOptions) + assert restored.model_dump(mode="json") == payload + + +def test_state_roundtrip_does_not_leak_token() -> None: + state = _make_state(sprite_name="x", created_by_us=False, url_auth="public") + payload = state.model_dump(mode="json") + assert "token" not in payload and "base_url" not in payload + client = SpritesSandboxClient(token="tok-1", base_url="https://example") + restored = client.deserialize_session_state(payload) + assert isinstance(restored, SpritesSandboxSessionState) + assert restored.model_dump(mode="json") == payload + + +# ---------- 2. 
Auth resolution ---------- + + +def test_client_resolves_token_from_env(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SPRITES_API_TOKEN", "from-env") + client = SpritesSandboxClient() + assert client._token == "from-env" + assert client._base_url == "https://api.sprites.dev" + + +def test_client_kwarg_overrides_env(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SPRITES_API_TOKEN", "from-env") + monkeypatch.setenv("SPRITES_API_URL", "https://env.example") + client = SpritesSandboxClient(token="kwarg", base_url="https://kwarg.example") + assert client._token == "kwarg" + assert client._base_url == "https://kwarg.example" + + +def test_client_missing_token_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("SPRITES_API_TOKEN", raising=False) + with pytest.raises(ConfigurationError): + SpritesSandboxClient() + + +def test_resume_uses_live_client_token_after_env_cleared( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("SPRITES_API_TOKEN", "from-env") + client = SpritesSandboxClient() + monkeypatch.delenv("SPRITES_API_TOKEN", raising=False) + assert client._token == "from-env" + + +# ---------- 3 & 4. Lifecycle ephemeral & named-attach ---------- + + +@pytest.mark.asyncio +async def test_create_ephemeral_sprite(patched_sprites: dict[str, Any]) -> None: + fake_client = patched_sprites["client"] + client = SpritesSandboxClient(token="tok") + options = SpritesSandboxClientOptions() + session = await client.create(options=options) + inner = session._inner + assert isinstance(inner, SpritesSandboxSession) + assert inner.state.created_by_us is True + assert len(fake_client.create_sprite_calls) == 1 + assert fake_client.create_sprite_calls[0][0].startswith("openai-agents-") + # No eager get_sprite poll — ephemeral path is lazy too. The first I/O + # operation drives the wait-for-running via ``_ensure_warm``. 
+ assert fake_client.get_sprite_calls == [] + assert inner._warmth_verified is False + # delete via client.delete deletes the ephemeral sprite + await client.delete(session) + assert fake_client.delete_sprite_calls == [fake_client.create_sprite_calls[0][0]] + + +@pytest.mark.asyncio +async def test_create_attaches_to_named_sprite(patched_sprites: dict[str, Any]) -> None: + fake_client = patched_sprites["client"] + fake_client._sprites_by_name["existing"] = _FakeSprite(name="existing") + client = SpritesSandboxClient(token="tok") + options = SpritesSandboxClientOptions(sprite_name="existing") + session = await client.create(options=options) + inner = session._inner + assert isinstance(inner, SpritesSandboxSession) + assert inner.state.created_by_us is False + assert fake_client.create_sprite_calls == [] + assert fake_client.sprite_handle_calls == ["existing"] + await client.delete(session) + assert fake_client.delete_sprite_calls == [] + + +# ---------- 5. Wait-for-running timeout ---------- + + +@pytest.mark.asyncio +async def test_wait_for_running_raises_workspace_start_error( + patched_sprites: dict[str, Any], monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(sprites_sandbox, "_SPRITE_READY_POLL_INTERVAL_S", 0.0) + fake_client = patched_sprites["client"] + fake_client._sprites_by_name[SPRITE_NAME] = _FakeSprite(name=SPRITE_NAME, status="starting") + state = _make_state(timeout_ms=1) # 1ms deadline + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client) + with pytest.raises(WorkspaceStartError) as excinfo: + await inner._wait_for_sprite_running() + assert excinfo.value.context.get("reason") == "wait_for_running_timeout" + + +# ---------- 6. 
Exec mapping ---------- + + +@pytest.mark.asyncio +async def test_exec_internal_returns_buffered_streams( + patched_sprites: dict[str, Any], +) -> None: + fake_control = patched_sprites["control"] + fake_control.next_ops.append(_FakeOpConn(stdout=b"hi\n", stderr=b"warn\n", exit_code=0)) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + fake_sprite = _FakeSprite(name=SPRITE_NAME) + _attach(inner, client=patched_sprites["client"], sprite=fake_sprite) + result = await inner._exec_internal("echo", "hi", timeout=5.0) + assert isinstance(result, ExecResult) + assert result.stdout == b"hi\n" + assert result.stderr == b"warn\n" + assert result.exit_code == 0 + assert fake_control.start_op_calls == [ + {"op": "exec", "cmd": ["echo", "hi"], "dir": "/workspace", "tty": False, "stdin": False} + ] + + +@pytest.mark.asyncio +async def test_exec_internal_timeout_raises_and_signals_kill( + patched_sprites: dict[str, Any], +) -> None: + fake_control = patched_sprites["control"] + never = asyncio.Event() + op = _FakeOpConn(wait_event=never) + fake_control.next_ops.append(op) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"], sprite=_FakeSprite(name=SPRITE_NAME)) + with pytest.raises(ExecTimeoutError): + await inner._exec_internal("sleep", "1000", timeout=0.05) + assert "KILL" in op.signals + + +@pytest.mark.asyncio +async def test_exec_internal_start_op_failure_raises_transport_error( + patched_sprites: dict[str, Any], +) -> None: + fake_control = patched_sprites["control"] + fake_control.start_op_failures.append(RuntimeError("ws closed")) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"], sprite=_FakeSprite(name=SPRITE_NAME)) + with pytest.raises(ExecTransportError): + await inner._exec_internal("echo", "x", timeout=1.0) + + +# ---------- 7. 
PTY ---------- + + +def test_supports_pty_is_true() -> None: + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + assert inner.supports_pty() is True + + +@pytest.mark.asyncio +async def test_pty_exec_start_registers_callbacks_and_pre_drains( + patched_sprites: dict[str, Any], +) -> None: + fake_control = patched_sprites["control"] + op = _FakeOpConn(stdout=b"pre-drain\n", exit_code=0) + fake_control.next_ops.append(op) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + sprite = _FakeSprite(name=SPRITE_NAME) + _attach(inner, client=patched_sprites["client"], sprite=sprite) + + update = await inner.pty_exec_start( + "bash", + shell=False, + tty=True, + yield_time_s=0.001, + ) + # Either pre-drain (via get_stdout) or callback drain ran; the chunk should + # appear in the returned output. + assert b"pre-drain" in update.output + assert fake_control.start_op_calls[0]["tty"] is True + + +@pytest.mark.asyncio +async def test_pty_write_stdin_writes_and_returns_buffered_output( + patched_sprites: dict[str, Any], +) -> None: + fake_control = patched_sprites["control"] + op = _FakeOpConn() + fake_control.next_ops.append(op) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + sprite = _FakeSprite(name=SPRITE_NAME) + _attach(inner, client=patched_sprites["client"], sprite=sprite) + + started = await inner.pty_exec_start("bash", shell=False, tty=True, yield_time_s=0.001) + session_id = started.process_id + assert session_id is not None + + # Simulate the server pushing output between writes by delivering a chunk + # synchronously through the registered on_stdout callback. 
+ assert op.on_stdout is not None + op.on_stdout(b"hello\n") + update = await inner.pty_write_stdin(session_id=session_id, chars="ls\n", yield_time_s=0.001) + assert op.write_calls == [b"ls\n"] + assert b"hello" in update.output + + +@pytest.mark.asyncio +async def test_pty_terminate_all_signals_term_and_kill( + patched_sprites: dict[str, Any], +) -> None: + fake_control = patched_sprites["control"] + op = _FakeOpConn() + fake_control.next_ops.append(op) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + sprite = _FakeSprite(name=SPRITE_NAME) + _attach(inner, client=patched_sprites["client"], sprite=sprite) + + await inner.pty_exec_start("bash", shell=False, tty=True, yield_time_s=0.001) + await inner.pty_terminate_all() + # Live op never closed, so terminate sequences TERM then KILL. + assert "TERM" in op.signals + assert "KILL" in op.signals + assert inner._pty_processes == {} + + +@pytest.mark.asyncio +async def test_pty_finalize_drops_session_when_op_closed( + patched_sprites: dict[str, Any], +) -> None: + fake_control = patched_sprites["control"] + op = _FakeOpConn(exit_code=0) + # Mark closed so _entry_exit_code returns 0 immediately. + op.closed = True + op.exit_code = 0 + fake_control.next_ops.append(op) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + sprite = _FakeSprite(name=SPRITE_NAME) + _attach(inner, client=patched_sprites["client"], sprite=sprite) + + update = await inner.pty_exec_start("bash", shell=False, tty=True, yield_time_s=0.001) + assert update.process_id is None + assert update.exit_code == 0 + assert inner._pty_processes == {} + + +# ---------- 8. 
Exposed ports ---------- + + +def test_options_exposed_ports_can_be_empty_or_single() -> None: + SpritesSandboxClientOptions(exposed_ports=()) + SpritesSandboxClientOptions(exposed_ports=(8080,)) + + +@pytest.mark.asyncio +async def test_validate_exposed_ports_rejects_more_than_one( + patched_sprites: dict[str, Any], +) -> None: + client = SpritesSandboxClient(token="tok") + options = SpritesSandboxClientOptions(exposed_ports=(8080, 9090)) + with pytest.raises(ConfigurationError): + await client.create(options=options) + + +@pytest.mark.asyncio +async def test_resolve_exposed_port_happy_path( + patched_sprites: dict[str, Any], +) -> None: + fake_client = patched_sprites["client"] + sprite = _FakeSprite( + name=SPRITE_NAME, + url="https://example-sprite-example-org.sprites.dev", + services=[_FakeService(http_port=8080)], + ) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state(exposed_ports=(8080,)) + inner = SpritesSandboxSession.from_state(state, token="tok") + # Inject fakes; bypass _ensure_sprite (which would re-create on the fake). 
+ _attach(inner, client=fake_client, sprite=sprite) + endpoint = await inner._resolve_exposed_port(8080) + assert isinstance(endpoint, ExposedPortEndpoint) + assert endpoint.tls is True + assert endpoint.host == "example-sprite-example-org.sprites.dev" + assert endpoint.port == 443 + + +@pytest.mark.asyncio +async def test_resolve_exposed_port_not_configured_when_no_matching_service( + patched_sprites: dict[str, Any], +) -> None: + fake_client = patched_sprites["client"] + sprite = _FakeSprite( + name=SPRITE_NAME, + url="https://example-sprite-example-org.sprites.dev", + services=[_FakeService(http_port=3000)], + ) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state(exposed_ports=(8080,)) + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + with pytest.raises(ExposedPortUnavailableError) as excinfo: + await inner._resolve_exposed_port(8080) + assert excinfo.value.context.get("backend") == "sprites" + + +# ---------- 9. Read / write ---------- + + +@pytest.mark.asyncio +async def test_read_returns_bytesio(patched_sprites: dict[str, Any]) -> None: + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME, files={"/workspace/hi.txt": b"hello"}) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + + # Bypass _validate_path_access (uses runtime helper exec) for unit-test isolation. 
+ async def _validate_passthrough(path: Path | str, *, for_write: bool = False) -> Path: + return Path(str(path)) + + with patch.object(inner, "_validate_path_access", _validate_passthrough): + stream = await inner.read(Path("/workspace/hi.txt")) + assert isinstance(stream, io.IOBase) + assert stream.read() == b"hello" + + +@pytest.mark.asyncio +async def test_read_missing_file_raises_workspace_read_not_found( + patched_sprites: dict[str, Any], +) -> None: + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME, files={}) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + + async def _validate_passthrough(path: Path | str, *, for_write: bool = False) -> Path: + return Path(str(path)) + + with patch.object(inner, "_validate_path_access", _validate_passthrough): + with pytest.raises(WorkspaceReadNotFoundError): + await inner.read(Path("/workspace/missing.txt")) + + +@pytest.mark.asyncio +async def test_write_rejects_string_payload_with_workspace_write_type_error( + patched_sprites: dict[str, Any], +) -> None: + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + + async def _validate_passthrough(path: Path | str, *, for_write: bool = False) -> Path: + return Path(str(path)) + + class _BadStream(io.IOBase): + def read(self, *_args: Any) -> Any: + return 42 # not bytes / str + + with patch.object(inner, "_validate_path_access", _validate_passthrough): + with pytest.raises(WorkspaceWriteTypeError): + await inner.write(Path("/workspace/x"), _BadStream()) + + +@pytest.mark.asyncio +async def test_write_propagates_filesystem_failure_as_archive_write_error( + patched_sprites: dict[str, 
Any], +) -> None: + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME) + sprite._fs.write_failure = RuntimeError("disk full") + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + + async def _validate_passthrough(path: Path | str, *, for_write: bool = False) -> Path: + return Path(str(path)) + + with patch.object(inner, "_validate_path_access", _validate_passthrough): + with pytest.raises(WorkspaceArchiveWriteError): + await inner.write(Path("/workspace/x"), io.BytesIO(b"data")) + + +@pytest.mark.asyncio +async def test_read_rejects_user_arg(patched_sprites: dict[str, Any]) -> None: + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"]) + with pytest.raises(ConfigurationError): + await inner.read(Path("/workspace/x"), user="root") + + +@pytest.mark.asyncio +async def test_write_rejects_user_arg(patched_sprites: dict[str, Any]) -> None: + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"]) + with pytest.raises(ConfigurationError): + await inner.write(Path("/workspace/x"), io.BytesIO(b"x"), user=User(name="r")) + + +@pytest.mark.asyncio +async def test_exec_rejects_user_arg(patched_sprites: dict[str, Any]) -> None: + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"]) + with pytest.raises(ConfigurationError): + await inner.exec("echo", "x", user="root") + + +# ---------- 11 (subset). 
Tar-based persistence sanity ---------- + + +@pytest.mark.asyncio +async def test_persist_workspace_uses_tar_via_exec_and_filesystem_read( + patched_sprites: dict[str, Any], +) -> None: + import tarfile + + fake_control = patched_sprites["control"] + # tar cf, rm cleanup + fake_control.next_ops.append(_FakeOpConn(exit_code=0)) + fake_control.next_ops.append(_FakeOpConn(exit_code=0)) + + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME) + archive_path = f"/tmp/openai-agents-{SESSION_UUID.hex}.tar" + # Build a minimal valid tar so hydrate could read it back. + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w") as tar: + info = tarfile.TarInfo(name="./hello.txt") + info.size = 5 + tar.addfile(info, io.BytesIO(b"hello")) + sprite._fs.files[archive_path] = buf.getvalue() + fake_client._sprites_by_name[SPRITE_NAME] = sprite + + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + stream = await inner.persist_workspace() + assert isinstance(stream, io.IOBase) + archive_bytes = stream.read() + # Round-trip: validate the tar produced is parseable + with tarfile.open(fileobj=io.BytesIO(archive_bytes), mode="r:*") as tar: + names = tar.getnames() + assert "./hello.txt" in names + # First start_op was the tar create (passed shell=False, so verbatim cmd). + first_cmd = fake_control.start_op_calls[0]["cmd"] + assert first_cmd[0:3] == ["tar", "cf", archive_path] + assert "." in first_cmd + + +# ---------- 14. 
SpritesPlatformContext capability ----------
+
+
+@pytest.mark.asyncio
+async def test_sprites_platform_context_reads_llm_txt(
+    patched_sprites: dict[str, Any],
+) -> None:
+    fake_control = patched_sprites["control"]
+    fake_control.next_ops.append(
+        _FakeOpConn(stdout=b"# Sprite Environment\nbe nice\n", exit_code=0)
+    )
+    fake_client = patched_sprites["client"]
+    sprite = _FakeSprite(name=SPRITE_NAME)
+    fake_client._sprites_by_name[SPRITE_NAME] = sprite
+    state = _make_state()
+    inner = SpritesSandboxSession.from_state(state, token="tok")
+    _attach(inner, client=fake_client, sprite=sprite)
+
+    capability = SpritesPlatformContext()
+    capability.bind(inner)
+    text = await capability.instructions(state.manifest)
+    assert text is not None
+    assert "" in text  # FIXME(review): vacuous — "" is a substring of every string; the intended marker substring appears to have been lost, restore it
+    assert "be nice" in text
+    # First exec call should be the cat — verbatim (shell=False) and absolute path.
+    first_cmd = fake_control.start_op_calls[0]["cmd"]
+    assert first_cmd[0:3] == ["cat", "--", "/.sprite/llm.txt"]
+
+
+@pytest.mark.asyncio
+async def test_sprites_platform_context_caches_after_first_read(
+    patched_sprites: dict[str, Any],
+) -> None:
+    fake_control = patched_sprites["control"]
+    fake_control.next_ops.append(_FakeOpConn(stdout=b"ctx\n", exit_code=0))
+    fake_client = patched_sprites["client"]
+    sprite = _FakeSprite(name=SPRITE_NAME)
+    fake_client._sprites_by_name[SPRITE_NAME] = sprite
+    state = _make_state()
+    inner = SpritesSandboxSession.from_state(state, token="tok")
+    _attach(inner, client=fake_client, sprite=sprite)
+
+    capability = SpritesPlatformContext()
+    capability.bind(inner)
+    a = await capability.instructions(state.manifest)
+    b = await capability.instructions(state.manifest)
+    assert a == b
+    # Only one start_op call total (cached on the second invocation). 
+ assert len(fake_control.start_op_calls) == 1 + + +@pytest.mark.asyncio +async def test_sprites_platform_context_returns_none_when_file_missing( + patched_sprites: dict[str, Any], +) -> None: + fake_control = patched_sprites["control"] + fake_control.next_ops.append( + _FakeOpConn(stdout=b"", stderr=b"cat: /.sprite/llm.txt: No such file\n", exit_code=1) + ) + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + + capability = SpritesPlatformContext() + capability.bind(inner) + assert await capability.instructions(state.manifest) is None + + +# ---------- 15. SpritesUrlAccess capability ---------- + + +@pytest.mark.asyncio +async def test_sprites_url_access_default_blocks_public( + patched_sprites: dict[str, Any], monkeypatch: pytest.MonkeyPatch +) -> None: + from agents.extensions.sandbox.sprites import SpritesUrlAccess + + sprite = _FakeSprite(name=SPRITE_NAME) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"], sprite=sprite) + + capability = SpritesUrlAccess() + capability.bind(inner) + result = await capability._apply_visibility("public") + assert "disabled by application policy" in result + # URL setting was NOT touched. 
+ assert sprite.update_url_settings_calls == [] + + +@pytest.mark.asyncio +async def test_sprites_url_access_allow_public_calls_update( + patched_sprites: dict[str, Any], +) -> None: + from agents.extensions.sandbox.sprites import SpritesUrlAccess + + sprite = _FakeSprite(name=SPRITE_NAME) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"], sprite=sprite) + + capability = SpritesUrlAccess(allow_public=True) + capability.bind(inner) + result = await capability._apply_visibility("public") + assert "public" in result + assert len(sprite.update_url_settings_calls) == 1 + settings = sprite.update_url_settings_calls[0] + assert getattr(settings, "auth", None) == "public" + + +@pytest.mark.asyncio +async def test_sprites_url_access_sprite_value_works_without_allow_public( + patched_sprites: dict[str, Any], +) -> None: + from agents.extensions.sandbox.sprites import SpritesUrlAccess + + sprite = _FakeSprite(name=SPRITE_NAME) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"], sprite=sprite) + + capability = SpritesUrlAccess(allow_public=False) + capability.bind(inner) + result = await capability._apply_visibility("sprite") + assert "sprite" in result + assert len(sprite.update_url_settings_calls) == 1 + + +@pytest.mark.asyncio +async def test_sprites_url_access_invalid_value(patched_sprites: dict[str, Any]) -> None: + from agents.extensions.sandbox.sprites import SpritesUrlAccess + + sprite = _FakeSprite(name=SPRITE_NAME) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"], sprite=sprite) + capability = SpritesUrlAccess(allow_public=True) + capability.bind(inner) + result = await capability._apply_visibility("nonsense") + assert "must be" in result + assert sprite.update_url_settings_calls == [] + + +def 
test_sprites_url_access_tools_omit_or_include_public() -> None: + from agents.extensions.sandbox.sprites import SpritesUrlAccess + + cap = SpritesUrlAccess(allow_public=False) + tools = cap.tools() + assert len(tools) == 1 + cap_pub = SpritesUrlAccess(allow_public=True) + tools_pub = cap_pub.tools() + assert len(tools_pub) == 1 + + +# ---------- 16. SpritesCheckpoints capability ---------- + + +class _FakeCheckpoint: + def __init__(self, *, id: str, comment: str = "", create_time: Any = None) -> None: + from datetime import datetime, timezone + + self.id = id + self.comment = comment + self.create_time = create_time or datetime.now(timezone.utc) + + +from dataclasses import dataclass as _dataclass, field as _field # noqa: E402 + + +@_dataclass +class _CheckpointFakeOps: + """Tracks calls + state for the checkpoint stubs attached to a sprite.""" + + checkpoints: list[_FakeCheckpoint] = _field(default_factory=list) + create_calls: list[str] = _field(default_factory=list) + restore_calls: list[str] = _field(default_factory=list) + create_messages: list[Any] = _field(default_factory=list) + restore_messages: list[Any] = _field(default_factory=list) + + +def _attach_checkpoint_methods(sprite: _FakeSprite) -> _CheckpointFakeOps: + """Wire create/list/restore stubs onto ``sprite`` and return the tracking ops.""" + + ops = _CheckpointFakeOps() + + def _create(comment: str = "") -> Any: + ops.create_calls.append(comment) + from datetime import datetime, timedelta, timezone + + latest_time = max( + (c.create_time for c in ops.checkpoints), default=datetime.now(timezone.utc) + ) + timedelta(seconds=1) + ops.checkpoints.append( + _FakeCheckpoint( + id=f"ckpt-{len(ops.checkpoints) + 1}", + comment=comment, + create_time=latest_time, + ) + ) + return iter(ops.create_messages) + + def _list() -> list[_FakeCheckpoint]: + return list(ops.checkpoints) + + def _restore(checkpoint_id: str) -> Any: + ops.restore_calls.append(checkpoint_id) + return iter(ops.restore_messages) + + 
setattr(sprite, "create_checkpoint", _create) # noqa: B010 + setattr(sprite, "list_checkpoints", _list) # noqa: B010 + setattr(sprite, "restore_checkpoint", _restore) # noqa: B010 + return ops + + +@pytest.mark.asyncio +async def test_sprites_checkpoints_create_returns_id(patched_sprites: dict[str, Any]) -> None: + from agents.extensions.sandbox.sprites import SpritesCheckpoints + + sprite = _FakeSprite(name=SPRITE_NAME) + ops = _attach_checkpoint_methods(sprite) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"], sprite=sprite) + + cap = SpritesCheckpoints() + cap.bind(inner) + out = await cap._create("before-refactor") + assert "id='ckpt-1'" in out + assert "before-refactor" in out + assert ops.create_calls == ["before-refactor"] + + +@pytest.mark.asyncio +async def test_sprites_checkpoints_list_renders_rows(patched_sprites: dict[str, Any]) -> None: + from agents.extensions.sandbox.sprites import SpritesCheckpoints + + sprite = _FakeSprite(name=SPRITE_NAME) + _attach_checkpoint_methods(sprite) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"], sprite=sprite) + + cap = SpritesCheckpoints() + cap.bind(inner) + await cap._create("first") + await cap._create("second") + out = await cap._list() + assert "ckpt-1" in out + assert "ckpt-2" in out + + +@pytest.mark.asyncio +async def test_sprites_checkpoints_restore_blocked_by_default( + patched_sprites: dict[str, Any], +) -> None: + from agents.extensions.sandbox.sprites import SpritesCheckpoints + + sprite = _FakeSprite(name=SPRITE_NAME) + ops = _attach_checkpoint_methods(sprite) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"], sprite=sprite) + + cap = SpritesCheckpoints(allow_restore=False) + cap.bind(inner) + out = await cap._restore("ckpt-1") + assert 
"disabled" in out + assert ops.restore_calls == [] + + +@pytest.mark.asyncio +async def test_sprites_checkpoints_restore_when_enabled(patched_sprites: dict[str, Any]) -> None: + from agents.extensions.sandbox.sprites import SpritesCheckpoints + + sprite = _FakeSprite(name=SPRITE_NAME) + ops = _attach_checkpoint_methods(sprite) + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=patched_sprites["client"], sprite=sprite) + + cap = SpritesCheckpoints(allow_restore=True) + cap.bind(inner) + out = await cap._restore("ckpt-7") + assert "ckpt-7" in out + assert ops.restore_calls == ["ckpt-7"] + + +def test_sprites_checkpoints_tool_count_depends_on_allow_restore() -> None: + from agents.extensions.sandbox.sprites import SpritesCheckpoints + + cap_no = SpritesCheckpoints(allow_restore=False) + cap_yes = SpritesCheckpoints(allow_restore=True) + assert len(cap_no.tools()) == 2 # create + list + assert len(cap_yes.tools()) == 3 # + restore + + +# ---------- 17. Lazy wake-up ---------- + + +@pytest.mark.asyncio +async def test_named_attach_create_does_not_poll_for_running( + patched_sprites: dict[str, Any], +) -> None: + """create() with sprite_name should NOT call get_sprite during attach. + + The platform auto-wakes the sprite on first traffic; polling here would + pay wake-up latency just to hand back a session handle. The first I/O + operation drives the wake-up via _ensure_warm. + """ + + fake_client = patched_sprites["client"] + fake_client._sprites_by_name["existing"] = _FakeSprite(name="existing") + client = SpritesSandboxClient(token="tok") + options = SpritesSandboxClientOptions(sprite_name="existing") + session = await client.create(options=options) + inner = session._inner + assert isinstance(inner, SpritesSandboxSession) + # No get_sprite calls because we did not poll for warmth. + assert fake_client.get_sprite_calls == [] + # And the warmth flag stays False, so the next I/O will trigger the poll. 
+ assert inner._warmth_verified is False + + +@pytest.mark.asyncio +async def test_lazy_warm_polls_on_first_exec(patched_sprites: dict[str, Any]) -> None: + fake_client = patched_sprites["client"] + fake_client._sprites_by_name["existing"] = _FakeSprite(name="existing") + fake_control = patched_sprites["control"] + fake_control.next_ops.append(_FakeOpConn(stdout=b"", exit_code=0)) + + client = SpritesSandboxClient(token="tok") + session = await client.create(options=SpritesSandboxClientOptions(sprite_name="existing")) + inner = session._inner + assert isinstance(inner, SpritesSandboxSession) + assert inner._warmth_verified is False + assert fake_client.get_sprite_calls == [] + + # First exec drives the wake-up poll. + await inner._exec_internal("echo", "hi") + assert fake_client.get_sprite_calls == ["existing"] + assert inner._warmth_verified is True + + # Subsequent exec does NOT re-poll. + fake_control.next_ops.append(_FakeOpConn(stdout=b"", exit_code=0)) + await inner._exec_internal("echo", "hi2") + # Still just the one poll from the first call. + assert fake_client.get_sprite_calls == ["existing"] + + +@pytest.mark.asyncio +async def test_lazy_warm_invalidate_forces_repoll(patched_sprites: dict[str, Any]) -> None: + fake_client = patched_sprites["client"] + fake_client._sprites_by_name["existing"] = _FakeSprite(name="existing") + fake_control = patched_sprites["control"] + fake_control.next_ops.extend([_FakeOpConn(exit_code=0), _FakeOpConn(exit_code=0)]) + + client = SpritesSandboxClient(token="tok") + session = await client.create(options=SpritesSandboxClientOptions(sprite_name="existing")) + inner = session._inner + assert isinstance(inner, SpritesSandboxSession) + + await inner._exec_internal("echo", "1") + assert len(fake_client.get_sprite_calls) == 1 + + inner._invalidate_warmth() + await inner._exec_internal("echo", "2") + assert len(fake_client.get_sprite_calls) == 2 + + +# ---------- 18. 
Idle-close watcher ---------- + + +@pytest.mark.asyncio +async def test_idle_watch_closes_control_connections_after_threshold( + patched_sprites: dict[str, Any], +) -> None: + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + # Make the idle window vanishingly small so the test runs fast. + inner._idle_close_seconds = 0.01 + + # Touch activity to spawn the watcher, then wait long enough for the + # watcher's idle threshold to elapse and close the control connection. + inner._touch_activity() + assert inner._idle_watch_task is not None + await asyncio.wait_for(inner._idle_watch_task, timeout=1.0) + assert sprite.close_control_connection_calls == 1 + + +@pytest.mark.asyncio +async def test_idle_watch_disabled_when_seconds_is_zero( + patched_sprites: dict[str, Any], +) -> None: + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state(idle_close_seconds=0) + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + inner._idle_close_seconds = 0 # belt-and-braces + + inner._touch_activity() + assert inner._idle_watch_task is None + # Wait briefly to confirm no close ever fires. 
+ await asyncio.sleep(0.05) + assert sprite.close_control_connection_calls == 0 + + +@pytest.mark.asyncio +async def test_idle_watch_skipped_when_pty_active( + patched_sprites: dict[str, Any], +) -> None: + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + inner._idle_close_seconds = 0.01 + # Pretend a PTY is active so the watcher should refuse to close. + inner._pty_processes[123] = object() # type: ignore[assignment] + + await inner._close_idle_control_connections() + assert sprite.close_control_connection_calls == 0 + + +@pytest.mark.asyncio +async def test_activity_during_idle_window_keeps_connection_open( + patched_sprites: dict[str, Any], +) -> None: + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + inner._idle_close_seconds = 0.05 + + inner._touch_activity() + # Half the window: nudge activity forward so the deadline shifts. + await asyncio.sleep(0.025) + inner._touch_activity() + # Wait long enough for the original deadline to have passed had we not + # touched activity, but short of the new deadline. + await asyncio.sleep(0.04) + # The connection should still be open at this point. + assert sprite.close_control_connection_calls == 0 + # Now actually let it idle out fully. + watcher = inner._idle_watch_task + assert watcher is not None + await asyncio.wait_for(watcher, timeout=0.2) + assert sprite.close_control_connection_calls == 1 + + +# ---------- 19. 
Platform-context cache survives cloning ----------
+
+
+@pytest.mark.asyncio
+async def test_sprites_platform_context_cache_survives_clone(
+    patched_sprites: dict[str, Any],
+) -> None:
+    """Each agent turn re-clones capabilities; the cache must survive that.
+
+    Without a module-level cache, a new clone wakes the sprite every turn
+    just to re-read the (unchanged) platform-context file. With it, only
+    the first turn for a given sprite-name pays the exec.
+    """
+
+    fake_control = patched_sprites["control"]
+    fake_control.next_ops.append(_FakeOpConn(stdout=b"# Sprite\n", exit_code=0))
+    fake_client = patched_sprites["client"]
+    sprite = _FakeSprite(name=SPRITE_NAME)
+    fake_client._sprites_by_name[SPRITE_NAME] = sprite
+    state = _make_state()
+    inner = SpritesSandboxSession.from_state(state, token="tok")
+    _attach(inner, client=fake_client, sprite=sprite)
+
+    # Turn 1: a fresh clone fetches and caches.
+    cap1 = SpritesPlatformContext()
+    cap1.bind(inner)
+    out1 = await cap1.instructions(state.manifest)
+    assert out1 is not None
+    assert "# Sprite" in out1
+    assert len(fake_control.start_op_calls) == 1
+
+    # Turn 2: a NEW clone targeting the same sprite hits the module cache.
+    cap2 = SpritesPlatformContext()
+    cap2.bind(inner)
+    out2 = await cap2.instructions(state.manifest)
+    assert out2 == out1
+    # Still just the one exec — turn 2 didn't touch the sprite. 
+ assert len(fake_control.start_op_calls) == 1 + + +@pytest.mark.asyncio +async def test_sprites_platform_context_cache_clear_forces_refetch( + patched_sprites: dict[str, Any], +) -> None: + from agents.extensions.sandbox.sprites import clear_platform_context_cache + + fake_control = patched_sprites["control"] + fake_control.next_ops.extend( + [_FakeOpConn(stdout=b"v1\n", exit_code=0), _FakeOpConn(stdout=b"v2\n", exit_code=0)] + ) + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state() + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + + cap = SpritesPlatformContext() + cap.bind(inner) + out1 = await cap.instructions(state.manifest) + assert out1 is not None and "v1" in out1 + assert len(fake_control.start_op_calls) == 1 + + # Cache invalidation forces a re-fetch. + clear_platform_context_cache(SPRITE_NAME) + out2 = await cap.instructions(state.manifest) + assert out2 is not None and "v2" in out2 + assert len(fake_control.start_op_calls) == 2 + + +# ---------- 20. Platform context includes service working-directory hint ---------- + + +@pytest.mark.asyncio +async def test_platform_context_warns_about_service_cwd( + patched_sprites: dict[str, Any], +) -> None: + """The framing should warn the model that services run with cwd=$HOME by default. + + Without this warning, agents commonly create `python3 -m http.server` services + and serve from the home directory instead of the workspace. 
+ """ + + fake_control = patched_sprites["control"] + fake_control.next_ops.append(_FakeOpConn(stdout=b"# Sprite\n", exit_code=0)) + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state(manifest=Manifest(root="/workspace")) + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + + cap = SpritesPlatformContext() + cap.bind(inner) + out = await cap.instructions(state.manifest) + assert out is not None + assert "/workspace" in out + assert "--dir /workspace" in out + assert "sprite-env services create" in out + + +@pytest.mark.asyncio +async def test_platform_context_uses_actual_manifest_root( + patched_sprites: dict[str, Any], +) -> None: + """The hint must use the agent's actual manifest.root, not a hardcoded path.""" + + fake_control = patched_sprites["control"] + fake_control.next_ops.append(_FakeOpConn(stdout=b"# Sprite\n", exit_code=0)) + fake_client = patched_sprites["client"] + sprite = _FakeSprite(name=SPRITE_NAME) + fake_client._sprites_by_name[SPRITE_NAME] = sprite + state = _make_state(manifest=Manifest(root="/var/agent-home")) + inner = SpritesSandboxSession.from_state(state, token="tok") + _attach(inner, client=fake_client, sprite=sprite) + + cap = SpritesPlatformContext() + cap.bind(inner) + out = await cap.instructions(state.manifest) + assert out is not None + assert "/var/agent-home" in out + assert "--dir /var/agent-home" in out + + +# ---------- Cloud bucket mount strategy ---------- + + +from agents.extensions.sandbox.sprites.mounts import ( # noqa: E402 + _MISSING, + _MOUNTED, + _NOT_MOUNTED, + _PRESENT, + SpritesCloudBucketMountStrategy, + _assert_sprites_session, + _ensure_fuse_support, + _ensure_rclone, + _rclone_pattern_for_session, + _verify_mount_active, +) +from agents.sandbox.entries import ( # noqa: E402 + RcloneMountPattern, + S3Mount, +) +from agents.sandbox.errors import 
MountConfigError # noqa: E402 +from agents.sandbox.session.base_sandbox_session import ( # noqa: E402 + BaseSandboxSession, +) + + +class _FakeMountSession(BaseSandboxSession): + """Minimal SpritesSandboxSession-named fake driving canned exec results.""" + + def __init__(self, results: list[ExecResult] | None = None) -> None: + self.state = SpritesSandboxSessionState( + session_id=SESSION_UUID, + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + sprite_name=SPRITE_NAME, + ) + self._results = list(results or []) + self.exec_calls: list[str] = [] + + async def _exec_internal( + self, + *command: str | Path, + timeout: float | None = None, + ) -> ExecResult: + _ = timeout + cmd_str = " ".join(str(c) for c in command) + self.exec_calls.append(cmd_str) + if self._results: + return self._results.pop(0) + return ExecResult(stdout=b"", stderr=b"", exit_code=0) + + async def read(self, path: Path, *, user: str | User | None = None) -> io.IOBase: + _ = (path, user) + return io.BytesIO(b"") + + async def write(self, path: Path, data: io.IOBase, *, user: str | User | None = None) -> None: + _ = (path, data, user) + + async def persist_workspace(self) -> io.IOBase: + raise AssertionError("not expected") + + async def hydrate_workspace(self, data: io.IOBase) -> None: + _ = data + raise AssertionError("not expected") + + async def running(self) -> bool: + return True + + +# The cloud-bucket strategy guards on ``type(session).__name__``; rebrand the fake. 
+_FakeMountSession.__name__ = "SpritesSandboxSession" + + +def _ok(stdout: bytes = b"") -> ExecResult: + return ExecResult(stdout=stdout, stderr=b"", exit_code=0) + + +def _fail(exit_code: int = 1) -> ExecResult: + return ExecResult(stdout=b"", stderr=b"", exit_code=exit_code) + + +def _present() -> ExecResult: + """ExecResult mimicking the stdout sentinel for a successful detection.""" + + return ExecResult(stdout=f"{_PRESENT}\n".encode("ascii"), stderr=b"", exit_code=0) + + +def _missing() -> ExecResult: + """ExecResult mimicking the stdout sentinel for a failed detection.""" + + return ExecResult(stdout=f"{_MISSING}\n".encode("ascii"), stderr=b"", exit_code=0) + + +def _mounted() -> ExecResult: + return ExecResult(stdout=f"{_MOUNTED}\n".encode("ascii"), stderr=b"", exit_code=0) + + +def _not_mounted() -> ExecResult: + return ExecResult(stdout=f"{_NOT_MOUNTED}\n".encode("ascii"), stderr=b"", exit_code=0) + + +def test_sprites_cloud_bucket_strategy_type_and_default_pattern() -> None: + strategy = SpritesCloudBucketMountStrategy() + + assert strategy.type == "sprites_cloud_bucket" + assert isinstance(strategy.pattern, RcloneMountPattern) + assert strategy.pattern.mode == "fuse" + + +def test_sprites_cloud_bucket_strategy_round_trips_through_manifest() -> None: + manifest = Manifest.model_validate( + { + "root": "/workspace", + "entries": { + "bucket": { + "type": "s3_mount", + "bucket": "my-bucket", + "mount_strategy": {"type": "sprites_cloud_bucket"}, + } + }, + } + ) + + mount = manifest.entries["bucket"] + assert isinstance(mount, S3Mount) + assert isinstance(mount.mount_strategy, SpritesCloudBucketMountStrategy) + + +def test_sprites_cloud_bucket_strategy_round_trips_through_registry() -> None: + payload = SpritesCloudBucketMountStrategy().model_dump() + parsed = SpritesCloudBucketMountStrategy.model_validate(payload) + + assert isinstance(parsed, SpritesCloudBucketMountStrategy) + assert parsed.type == "sprites_cloud_bucket" + assert parsed.pattern.mode == 
"fuse"
+
+
+def test_sprites_session_guard_rejects_wrong_type() -> None:
+    class _NotASpritesSession:
+        pass
+
+    with pytest.raises(MountConfigError, match="SpritesSandboxSession"):
+        _assert_sprites_session(_NotASpritesSession())  # type: ignore[arg-type]
+
+
+def test_sprites_session_guard_accepts_correct_type() -> None:
+    _assert_sprites_session(_FakeMountSession())
+
+
+def test_sprites_extension_re_exports_cloud_bucket_strategy() -> None:
+    package_module = __import__(
+        "agents.extensions.sandbox",
+        fromlist=["SpritesCloudBucketMountStrategy"],
+    )
+    sprites_module = __import__(
+        "agents.extensions.sandbox.sprites",
+        fromlist=["SpritesCloudBucketMountStrategy"],
+    )
+
+    assert package_module.SpritesCloudBucketMountStrategy is SpritesCloudBucketMountStrategy
+    assert sprites_module.SpritesCloudBucketMountStrategy is SpritesCloudBucketMountStrategy
+
+
+@pytest.mark.asyncio
+async def test_ensure_rclone_returns_quickly_when_already_installed() -> None:
+    session = _FakeMountSession([_present()])
+
+    await _ensure_rclone(session)
+
+    # Single detection call wrapped in an if/then/echo so stdout, not exit code,
+    # is the source of truth. 
+ assert session.exec_calls == [ + "sh -lc if command -v rclone >/dev/null 2>&1 || test -x /usr/local/bin/rclone; " + f"then echo {_PRESENT}; else echo {_MISSING}; fi" + ] + + +@pytest.mark.asyncio +async def test_ensure_rclone_installs_via_sudo_apt() -> None: + session = _FakeMountSession( + [ + _missing(), # rclone missing + _present(), # apt-get present + _ok(), # apt update + _ok(), # apt install (with fuse package) + _ok(), # rclone install script + _present(), # rclone now present + ] + ) + + await _ensure_rclone(session) + + assert session.exec_calls[0].startswith( + "sh -lc if command -v rclone >/dev/null 2>&1 || test -x /usr/local/bin/rclone" + ) + assert session.exec_calls[1].startswith("sh -lc if command -v apt-get >/dev/null 2>&1") + assert session.exec_calls[2] == ( + "sh -lc sudo -n env DEBIAN_FRONTEND=noninteractive DEBCONF_NOWARNINGS=yes " + "apt-get -o Dpkg::Use-Pty=0 update -qq" + ) + assert session.exec_calls[3] == ( + "sh -lc sudo -n env DEBIAN_FRONTEND=noninteractive DEBCONF_NOWARNINGS=yes " + "apt-get -o Dpkg::Use-Pty=0 install -y -qq curl unzip ca-certificates fuse" + ) + assert session.exec_calls[4] == ( + "sh -lc curl -fsSL https://rclone.org/install.sh | sudo -n bash" + ) + assert session.exec_calls[5].startswith( + "sh -lc if command -v rclone >/dev/null 2>&1 || test -x /usr/local/bin/rclone" + ) + + +@pytest.mark.asyncio +async def test_ensure_rclone_raises_when_apt_unavailable() -> None: + session = _FakeMountSession([_missing(), _missing()]) + + with pytest.raises(MountConfigError, match="apt-get is unavailable"): + await _ensure_rclone(session) + + +@pytest.mark.asyncio +async def test_ensure_rclone_raises_when_post_install_recheck_still_missing() -> None: + """When the install commands silently no-op, the post-install probe catches it. + + With unreliable exec exit codes the install commands themselves can't tell + us whether they actually installed anything; only the stdout sentinel from + the recheck does. 
This is the only install-failure mode we surface. + """ + + session = _FakeMountSession( + [ + _missing(), # rclone missing + _present(), # apt-get present + _ok(), # apt update (exit code unused) + _ok(), # apt install (exit code unused) + _ok(), # rclone install script (exit code unused) + _missing(), # rclone STILL missing post-install + ] + ) + + with pytest.raises(MountConfigError, match="install attempt completed but rclone is still"): + await _ensure_rclone(session) + + +@pytest.mark.asyncio +async def test_ensure_fuse_support_passes_when_kernel_and_fusermount_present() -> None: + session = _FakeMountSession([_present(), _present(), _ok()]) + + await _ensure_fuse_support(session) + + # kernel probe, fusermount probe, allow_other configuration + assert len(session.exec_calls) == 3 + assert session.exec_calls[0].startswith( + "sh -lc if test -c /dev/fuse && grep -qw fuse /proc/filesystems" + ) + assert session.exec_calls[1].startswith( + "sh -lc if command -v fusermount3 >/dev/null 2>&1 || command -v fusermount >/dev/null 2>&1" + ) + assert "sudo -n chmod a+rw /dev/fuse" in session.exec_calls[2] + assert "user_allow_other" in session.exec_calls[2] + + +@pytest.mark.asyncio +async def test_ensure_fuse_support_lazy_installs_fuse_package() -> None: + session = _FakeMountSession( + [ + _present(), # kernel ok + _missing(), # fusermount missing + _present(), # apt-get present + _ok(), # apt update + _ok(), # apt install fuse + _present(), # fusermount now present + _ok(), # allow_other config + ] + ) + + await _ensure_fuse_support(session) + + assert session.exec_calls[3] == ( + "sh -lc sudo -n env DEBIAN_FRONTEND=noninteractive DEBCONF_NOWARNINGS=yes " + "apt-get -o Dpkg::Use-Pty=0 update -qq" + ) + assert session.exec_calls[4] == ( + "sh -lc sudo -n env DEBIAN_FRONTEND=noninteractive DEBCONF_NOWARNINGS=yes " + "apt-get -o Dpkg::Use-Pty=0 install -y -qq fuse" + ) + + +@pytest.mark.asyncio +async def test_ensure_fuse_support_raises_when_kernel_missing_fuse() -> 
None: + session = _FakeMountSession([_missing()]) + + with pytest.raises(MountConfigError, match="FUSE support"): + await _ensure_fuse_support(session) + + +@pytest.mark.asyncio +async def test_ensure_fuse_support_raises_when_post_install_fusermount_still_missing() -> None: + session = _FakeMountSession( + [ + _present(), # kernel ok + _missing(), # fusermount missing + _present(), # apt-get present + _ok(), # apt update + _ok(), # apt install + _missing(), # fusermount STILL missing + ] + ) + + with pytest.raises(MountConfigError, match="install attempt completed but fusermount"): + await _ensure_fuse_support(session) + + +@pytest.mark.asyncio +async def test_verify_mount_active_passes_when_mountpoint_reports_mounted() -> None: + session = _FakeMountSession([_mounted(), _ok()]) + + await _verify_mount_active(session, Path("/workspace/tigris")) + + # Two calls: the mountpoint probe followed by a directory-listing warm-up + # that forces rclone to populate its root readdir cache before the caller + # uses the mount. + assert len(session.exec_calls) == 2 + assert "mountpoint -q /workspace/tigris" in session.exec_calls[0] + assert _MOUNTED in session.exec_calls[0] + assert "ls /workspace/tigris >/dev/null 2>&1" in session.exec_calls[1] + + +@pytest.mark.asyncio +async def test_verify_mount_active_raises_when_path_is_not_a_mount() -> None: + session = _FakeMountSession([_not_mounted()]) + + with pytest.raises(MountConfigError, match="not a live mountpoint") as excinfo: + await _verify_mount_active(session, Path("/workspace/tigris")) + assert excinfo.value.context.get("path") == "/workspace/tigris" + + +@pytest.mark.asyncio +async def test_verify_mount_active_warmup_runs_after_mountpoint_check() -> None: + """The post-mountpoint listing warmup runs even if the directory is empty. + + Confirms the warmup exec is fired regardless of what ``ls`` returns — + we only care about the side effect of priming rclone's dir cache. 
+ """ + + session = _FakeMountSession([_mounted(), _ok()]) + + await _verify_mount_active(session, Path("/workspace/some/nested/mount")) + + assert "mountpoint -q /workspace/some/nested/mount" in session.exec_calls[0] + assert "ls /workspace/some/nested/mount" in session.exec_calls[1] + + +@pytest.mark.asyncio +async def test_rclone_pattern_appends_allow_other_and_user_ids() -> None: + session = _FakeMountSession([_ok(stdout=b"1001\n1001\n")]) + + pattern = await _rclone_pattern_for_session(session, RcloneMountPattern(mode="fuse")) + + assert pattern.extra_args == ["--allow-other", "--uid", "1001", "--gid", "1001"] + + +@pytest.mark.asyncio +async def test_rclone_pattern_preserves_explicit_extra_args() -> None: + session = _FakeMountSession([_ok(stdout=b"1001\n1001\n")]) + source = RcloneMountPattern( + mode="fuse", + extra_args=["--allow-other", "--uid", "9999", "--gid", "9999", "--buffer-size", "0"], + ) + + pattern = await _rclone_pattern_for_session(session, source) + + assert pattern.extra_args == [ + "--allow-other", + "--uid", + "9999", + "--gid", + "9999", + "--buffer-size", + "0", + ] + + +@pytest.mark.asyncio +async def test_rclone_pattern_skips_user_ids_when_id_command_fails() -> None: + session = _FakeMountSession([_fail()]) + + pattern = await _rclone_pattern_for_session(session, RcloneMountPattern(mode="fuse")) + + assert pattern.extra_args == ["--allow-other"] + + +@pytest.mark.asyncio +async def test_rclone_pattern_returns_unchanged_for_non_fuse_modes() -> None: + session = _FakeMountSession() + source = RcloneMountPattern(mode="nfs", nfs_addr="127.0.0.1:2049") + + pattern = await _rclone_pattern_for_session(session, source) + + assert pattern is source + assert session.exec_calls == [] diff --git a/tests/sandbox/test_compatibility_guards.py b/tests/sandbox/test_compatibility_guards.py index 7b85757f77..136b04c51e 100644 --- a/tests/sandbox/test_compatibility_guards.py +++ b/tests/sandbox/test_compatibility_guards.py @@ -332,6 +332,25 @@ def 
test_core_sandbox_public_export_surface_is_stable() -> None: "VercelSandboxSessionState", }, ), + ( + "agents.extensions.sandbox.sprites", + { + "DEFAULT_SPRITES_API_URL", + "DEFAULT_SPRITES_CONTEXT_PATH", + "DEFAULT_SPRITES_WAIT_FOR_RUNNING_TIMEOUT_S", + "DEFAULT_SPRITES_WORKSPACE_ROOT", + "SpritesCheckpoints", + "SpritesCloudBucketMountStrategy", + "SpritesPlatformContext", + "SpritesSandboxClient", + "SpritesSandboxClientOptions", + "SpritesSandboxSession", + "SpritesSandboxSessionState", + "SpritesUrlAccess", + "UrlVisibility", + "clear_platform_context_cache", + }, + ), ], ) def test_extension_sandbox_package_export_surfaces_are_stable( @@ -503,6 +522,24 @@ def test_optional_sandbox_dataclass_constructor_field_order_is_stable( "network_policy", ), ), + ( + "agents.extensions.sandbox.sprites", + "SpritesSandboxClientOptions", + ( + "sprite_name", + "url_auth", + "ram_mb", + "cpus", + "region", + "storage_gb", + "exposed_ports", + "env", + "timeout_ms", + "wait_for_running_timeout_s", + "workspace_persistence", + "idle_close_seconds", + ), + ), ], ) def test_optional_sandbox_client_options_positional_field_order_is_stable( @@ -738,6 +775,31 @@ def test_optional_sandbox_client_options_positional_field_order_is_stable( "network_policy", ), ), + ( + "agents.extensions.sandbox.sprites", + "SpritesSandboxSessionState", + ( + "type", + "session_id", + "snapshot", + "manifest", + "exposed_ports", + "snapshot_fingerprint", + "snapshot_fingerprint_version", + "workspace_root_ready", + "sprite_name", + "created_by_us", + "url_auth", + "ram_mb", + "cpus", + "region", + "storage_gb", + "env", + "timeout_ms", + "workspace_persistence", + "idle_close_seconds", + ), + ), ], ) def test_sandbox_session_state_field_order_is_stable( @@ -778,6 +840,7 @@ def test_sandbox_session_state_field_order_is_stable( ), ("agents.extensions.sandbox.daytona", "DaytonaSandboxClientOptions", (), "daytona"), ("agents.extensions.sandbox.runloop", "RunloopSandboxClientOptions", (), "runloop"), + 
("agents.extensions.sandbox.sprites", "SpritesSandboxClientOptions", (), "sprites"), ("agents.extensions.sandbox.vercel", "VercelSandboxClientOptions", (), "vercel"), ], ) @@ -839,6 +902,11 @@ def test_optional_sandbox_client_options_json_round_trip_preserves_type( "RunloopSandboxSessionState", {"devbox_id": "devbox-123"}, ), + ( + "agents.extensions.sandbox.sprites", + "SpritesSandboxSessionState", + {"sprite_name": "sprite-xyz"}, + ), ( "agents.extensions.sandbox.vercel", "VercelSandboxSessionState", @@ -894,6 +962,8 @@ def test_core_discriminator_type_strings_are_stable() -> None: ("agents.sandbox.sandboxes.unix_local", "UnixLocalSandboxSessionState", "unix_local"), ("agents.sandbox.sandboxes.docker", "DockerSandboxClientOptions", "docker"), ("agents.sandbox.sandboxes.docker", "DockerSandboxSessionState", "docker"), + ("agents.extensions.sandbox.sprites", "SpritesSandboxClientOptions", "sprites"), + ("agents.extensions.sandbox.sprites", "SpritesSandboxSessionState", "sprites"), ], ) def test_optional_sandbox_discriminator_type_strings_are_stable( @@ -952,6 +1022,11 @@ def test_mount_strategy_type_strings_round_trip_through_registry( "RunloopCloudBucketMountStrategy", "runloop_cloud_bucket", ), + ( + "agents.extensions.sandbox.sprites", + "SpritesCloudBucketMountStrategy", + "sprites_cloud_bucket", + ), ], ) def test_optional_mount_strategy_type_strings_round_trip_through_registry( diff --git a/uv.lock b/uv.lock index 75f5c3b065..83ee7675d4 100644 --- a/uv.lock +++ b/uv.lock @@ -9,7 +9,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-19T02:17:07.061414832Z" +exclude-newer = "2026-04-21T11:44:13.372823398Z" exclude-newer-span = "P7D" [[package]] @@ -2494,6 +2494,9 @@ runloop = [ s3 = [ { name = "boto3" }, ] +sprites = [ + { name = "sprites-py" }, +] sqlalchemy = [ { name = "asyncpg" }, { name = "sqlalchemy" }, @@ -2575,6 +2578,7 @@ requires-dist = [ { name = "redis", marker = "extra == 'redis'", specifier = ">=7" }, { name = "requests", specifier 
= ">=2.0,<3" }, { name = "runloop-api-client", marker = "extra == 'runloop'", specifier = ">=1.16.0,<2.0.0" }, + { name = "sprites-py", marker = "extra == 'sprites'", specifier = ">=0.0.1rc37,<0.2" }, { name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0" }, { name = "temporalio", marker = "extra == 'temporal'", specifier = "==1.26.0" }, { name = "textual", marker = "extra == 'temporal'", specifier = ">=8.2.3,<8.3" }, @@ -2585,7 +2589,7 @@ requires-dist = [ { name = "websockets", marker = "extra == 'realtime'", specifier = ">=15.0,<17" }, { name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<17" }, ] -provides-extras = ["voice", "viz", "litellm", "any-llm", "realtime", "sqlalchemy", "encrypt", "redis", "dapr", "mongodb", "docker", "blaxel", "daytona", "cloudflare", "e2b", "modal", "runloop", "vercel", "s3", "temporal"] +provides-extras = ["voice", "viz", "litellm", "any-llm", "realtime", "sqlalchemy", "encrypt", "redis", "dapr", "mongodb", "docker", "blaxel", "daytona", "cloudflare", "e2b", "modal", "runloop", "sprites", "vercel", "s3", "temporal"] [package.metadata.requires-dev] dev = [ @@ -3865,6 +3869,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/3e/61d88e6b0a7383127cdc779195cb9d83ebcf11d39bc961de5777e457075e/sounddevice-0.5.2-py3-none-win_amd64.whl", hash = "sha256:e18944b767d2dac3771a7771bdd7ff7d3acd7d334e72c4bedab17d1aed5dbc22", size = 363808, upload-time = "2025-05-16T18:12:26Z" }, ] +[[package]] +name = "sprites-py" +version = "0.0.1rc37" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/88/cd/24bbca8438fdc8a89e3365b927b20f511e1c4529a8574ea252e967f8071c/sprites_py-0.0.1rc37.tar.gz", hash = "sha256:2fc58aa80a9a99c1a12cb17725d13675d5e25dc2a33f245f21486473aede05b9", size = 35226, upload-time = "2026-02-19T21:38:51.643Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/b6/8b/87c87887ae3bde72425134efea163fdbdef610765a8db7244c7b4adeab46/sprites_py-0.0.1rc37-py3-none-any.whl", hash = "sha256:d032b0059e2881b16a35349c0f01adafa99143671103a43a79d7e39a0c1ef6c9", size = 39424, upload-time = "2026-02-19T21:38:49.9Z" }, +] + [[package]] name = "sqlalchemy" version = "2.0.43"