def _render_shell_exports(env_vars: list[tuple[str, str]]) -> bytes:
    """Render ``(name, value)`` pairs as a sourceable POSIX-shell script.

    Each pair becomes one ``export NAME=<value>`` line with the value escaped
    via :func:`shlex.quote`. Names are emitted verbatim (not quoted), so
    callers must pass only trusted, valid shell identifiers.

    Args:
        env_vars: Ordered ``(name, value)`` pairs to export.

    Returns:
        The UTF-8 encoded script, always terminated by a newline (an empty
        input yields ``b"\\n"``).
    """
    lines = [f"export {name}={shlex.quote(value)}" for name, value in env_vars]
    return ("\n".join(lines) + "\n").encode("utf-8")


def _redact_sensitive_values(text: str, sensitive_values: list[str]) -> str:
    """Replace every occurrence of each sensitive value in *text* with ``REDACTED``.

    Both the raw value and its :func:`shlex.quote`-escaped form are masked,
    because the value may surface either way in shell diagnostics.

    All needles are matched in a single pass, longest first. Replacing them one
    by one in caller order is unsafe: a secret that is a prefix/substring of a
    longer secret would be replaced first and leave the remainder of the longer
    secret readable, and a later secret that happens to occur inside the
    ``REDACTED`` marker would corrupt markers already inserted.

    Args:
        text: Text that may contain secrets (for example captured stderr).
        sensitive_values: Secrets to mask; empty strings are ignored.

    Returns:
        *text* with every raw or shell-quoted secret replaced by ``REDACTED``.
    """
    needles: set[str] = set()
    for value in sensitive_values:
        if not value:
            # An empty needle would match at every position.
            continue
        needles.add(value)
        needles.add(shlex.quote(value))
    if not needles:
        return text

    # Longest-first so the regex alternation prefers the most specific match;
    # the secondary key only makes the generated pattern deterministic.
    ordered = sorted(needles, key=lambda needle: (-len(needle), needle))
    pattern = "|".join(re.escape(needle) for needle in ordered)
    return re.sub(pattern, "REDACTED", text)


async def _read_text_if_present(session: BaseSandboxSession, path: Path) -> str:
    """Best-effort read of *path* through *session*, returned as text.

    Args:
        session: Session whose ``read`` coroutine is used to open the file.
        path: Path handed to ``session.read``.

    Returns:
        The file content as ``str``. ``bytes`` are decoded as UTF-8 with
        undecodable bytes replaced, ``str`` is returned unchanged, and any
        other payload is converted with ``str()``. An empty string is returned
        when ``session.read`` raises.
    """
    try:
        handle = await session.read(path)
    except Exception:
        # Deliberately broad: this helper only enriches an error message, so a
        # missing or unreadable file must not mask the original failure.
        return ""

    try:
        raw = handle.read()
    finally:
        # Release the handle even if reading fails part-way.
        handle.close()

    if isinstance(raw, bytes):
        return raw.decode("utf-8", errors="replace")
    if isinstance(raw, str):
        return raw
    return str(raw)
env_vars.append(("AWS_SECRET_ACCESS_KEY", secret_access_key)) if session_token: - env_parts.append(f"AWS_SESSION_TOKEN={shlex.quote(session_token)}") + env_vars.append(("AWS_SESSION_TOKEN", session_token)) joined_cmd = " ".join(shlex.quote(part) for part in cmd) - if env_parts: - joined_cmd = f"{' '.join(env_parts)} {joined_cmd}" + stderr_path: Path | None = None + sensitive_values = [value for _name, value in env_vars] + if env_vars: + session_id = getattr(session.state, "session_id", None) + if session_id is None: + raise MountConfigError( + message="mount session is missing session_id", + context={"type": mountpoint_config.mount_type}, + ) + command_hash = hashlib.sha256( + f"{bucket}\0{sandbox_path_str(path)}".encode() + ).hexdigest()[:16] + config_dir = posix_path_as_path( + coerce_posix_path(f".sandbox-mountpoint-env/{session_id.hex}") + ) + env_path = config_dir / f"{command_hash}.env" + stdout_path = config_dir / f"{command_hash}.stdout" + stderr_path = config_dir / f"{command_hash}.stderr" + + await session.mkdir(config_dir, parents=True) + session.register_persist_workspace_skip_path(config_dir) + await _write_sensitive_config_file(session, env_path, _render_shell_exports(env_vars)) + + command_env_path = sandbox_path_str(session.normalize_path(env_path)) + command_stdout_path = sandbox_path_str(session.normalize_path(stdout_path)) + command_stderr_path = sandbox_path_str(session.normalize_path(stderr_path)) + joined_cmd = ( + f". 
def register_persist_workspace_skip_path(self, path: Path | str) -> Path:
    """Forward a persist-workspace skip-path registration to the wrapped session.

    This wrapper adds no logic of its own: the call is delegated to
    ``self._inner`` and whatever path that session returns is handed back
    unchanged.

    Args:
        path: Path to register, passed through to the inner session as-is.

    Returns:
        The ``Path`` produced by the inner session's
        ``register_persist_workspace_skip_path``.
    """
    wrapped_session = self._inner
    registered_path = wrapped_session.register_persist_workspace_skip_path(path)
    return registered_path
NoopSnapshot from agents.sandbox.types import ExecResult from tests.utils.factories import TestSessionState @@ -76,13 +80,16 @@ async def hydrate_workspace(self, data: io.IOBase) -> None: class _MountpointApplySession(BaseSandboxSession): - def __init__(self) -> None: + def __init__(self, *, mount_exit_code: int = 0, mount_stderr: bytes = b"") -> None: self.state = TestSessionState( session_id=uuid.uuid4(), manifest=Manifest(root="/workspace"), snapshot=NoopSnapshot(id=str(uuid.uuid4())), ) + self._mount_exit_code = mount_exit_code + self._mount_stderr = mount_stderr self.exec_calls: list[list[str]] = [] + self.write_calls: list[tuple[Path, bytes]] = [] async def read(self, path: Path, *, user: object = None) -> io.BytesIO: _ = (path, user) @@ -92,12 +99,15 @@ async def shutdown(self) -> None: return None async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None: - _ = (path, data, user) - raise AssertionError("write() should not be called in these tests") + _ = user + self.write_calls.append((path, data.read())) async def running(self) -> bool: return True + def persist_workspace_skip_paths(self) -> set[Path]: + return self._persist_workspace_skip_relpaths() + async def _exec_internal( self, *command: str | Path, @@ -106,6 +116,15 @@ async def _exec_internal( _ = timeout command_strs = [str(part) for part in command] self.exec_calls.append(command_strs) + if ( + len(command_strs) >= 3 + and command_strs[:2] == ["sh", "-lc"] + and "mount-s3 " in command_strs[2] + and "command -v " not in command_strs[2] + ): + return ExecResult( + exit_code=self._mount_exit_code, stdout=b"", stderr=self._mount_stderr + ) return ExecResult(exit_code=0, stdout=b"", stderr=b"") async def persist_workspace(self) -> io.IOBase: @@ -380,15 +399,23 @@ async def test_gcs_mount_uses_runtime_endpoint_override_without_mutating_pattern ["sh", "-lc", "command -v mount-s3 >/dev/null 2>&1"], ["mkdir", "-p", "/workspace/remote"], ] - assert len(session.exec_calls) == 3 - - 
mount_command = session.exec_calls[2] + assert len(session.exec_calls) == 5 + assert len(session.write_calls) == 1 + env_path, env_payload = session.write_calls[0] + assert env_path.as_posix().startswith(".sandbox-mountpoint-env/") + assert env_path.name.endswith(".env") + assert env_payload == b"export AWS_ACCESS_KEY_ID=access\nexport AWS_SECRET_ACCESS_KEY=secret\n" + + mount_command = session.exec_calls[-1] assert mount_command[:2] == ["sh", "-lc"] assert "mount-s3" in mount_command[2] + assert "AWS_ACCESS_KEY_ID=access" not in mount_command[2] + assert "AWS_SECRET_ACCESS_KEY=secret" not in mount_command[2] + assert ".sandbox-mountpoint-env" in mount_command[2] assert "--region us-east1" in mount_command[2] assert "--endpoint-url https://storage.googleapis.com" in mount_command[2] assert "--upload-checksums off" in mount_command[2] - assert mount_command[2].endswith("bucket /workspace/remote") + assert "bucket /workspace/remote" in mount_command[2] @pytest.mark.asyncio @@ -416,19 +443,29 @@ async def test_s3_mountpoint_writable_mode_enables_overwrite_and_delete() -> Non ["sh", "-lc", "command -v mount-s3 >/dev/null 2>&1"], ["mkdir", "-p", "/workspace/remote"], ] - assert len(session.exec_calls) == 3 - - mount_command = session.exec_calls[2] + assert len(session.exec_calls) == 5 + assert len(session.write_calls) == 1 + env_path, env_payload = session.write_calls[0] + assert env_path.as_posix().startswith(".sandbox-mountpoint-env/") + assert env_path.name.endswith(".env") + assert env_payload == ( + b"export AWS_ACCESS_KEY_ID=access\n" + b"export AWS_SECRET_ACCESS_KEY=secret\n" + b"export AWS_SESSION_TOKEN=token\n" + ) + + mount_command = session.exec_calls[-1] assert mount_command[:2] == ["sh", "-lc"] assert "mount-s3" in mount_command[2] assert "--read-only" not in mount_command[2] assert "--allow-overwrite" in mount_command[2] assert "--allow-delete" in mount_command[2] assert "--region us-east-1" in mount_command[2] - assert "AWS_ACCESS_KEY_ID=access" in 
mount_command[2] - assert "AWS_SECRET_ACCESS_KEY=secret" in mount_command[2] - assert "AWS_SESSION_TOKEN=token" in mount_command[2] - assert mount_command[2].endswith("bucket /workspace/remote") + assert "AWS_ACCESS_KEY_ID=access" not in mount_command[2] + assert "AWS_SECRET_ACCESS_KEY=secret" not in mount_command[2] + assert "AWS_SESSION_TOKEN=token" not in mount_command[2] + assert ".sandbox-mountpoint-env" in mount_command[2] + assert "bucket /workspace/remote" in mount_command[2] @pytest.mark.asyncio @@ -456,9 +493,14 @@ async def test_gcs_mountpoint_writable_mode_enables_overwrite_and_delete() -> No ["sh", "-lc", "command -v mount-s3 >/dev/null 2>&1"], ["mkdir", "-p", "/workspace/remote"], ] - assert len(session.exec_calls) == 3 - - mount_command = session.exec_calls[2] + assert len(session.exec_calls) == 5 + assert len(session.write_calls) == 1 + env_path, env_payload = session.write_calls[0] + assert env_path.as_posix().startswith(".sandbox-mountpoint-env/") + assert env_path.name.endswith(".env") + assert env_payload == b"export AWS_ACCESS_KEY_ID=access\nexport AWS_SECRET_ACCESS_KEY=secret\n" + + mount_command = session.exec_calls[-1] assert mount_command[:2] == ["sh", "-lc"] assert "mount-s3" in mount_command[2] assert "--read-only" not in mount_command[2] @@ -467,9 +509,58 @@ async def test_gcs_mountpoint_writable_mode_enables_overwrite_and_delete() -> No assert "--region us-east1" in mount_command[2] assert "--endpoint-url https://storage.googleapis.com" in mount_command[2] assert "--upload-checksums off" in mount_command[2] - assert "AWS_ACCESS_KEY_ID=access" in mount_command[2] - assert "AWS_SECRET_ACCESS_KEY=secret" in mount_command[2] - assert mount_command[2].endswith("bucket /workspace/remote") + assert "AWS_ACCESS_KEY_ID=access" not in mount_command[2] + assert "AWS_SECRET_ACCESS_KEY=secret" not in mount_command[2] + assert ".sandbox-mountpoint-env" in mount_command[2] + assert "bucket /workspace/remote" in mount_command[2] + + 
@pytest.mark.asyncio
async def test_s3_mountpoint_failure_redacts_credentials_from_errors_and_events() -> None:
    """A failing credentialed mount must not leak secrets via errors or events."""
    recorded_events: list[SandboxSessionEvent] = []

    def _collect(event: SandboxSessionEvent, _session: object) -> None:
        recorded_events.append(event)

    # The fake mount command fails and echoes all three secrets on stderr.
    failing_inner = _MountpointApplySession(
        mount_exit_code=1,
        mount_stderr=b"bad credentials: access secret token",
    )
    wrapped_session = SandboxSession(
        failing_inner,
        instrumentation=Instrumentation(sinks=[CallbackSink(_collect)]),
    )
    mount_pattern = MountpointMountPattern()

    with pytest.raises(MountCommandError) as raised:
        await mount_pattern.apply(
            wrapped_session,
            Path("/workspace/remote"),
            MountpointMountConfig(
                bucket="bucket",
                access_key_id="access",
                secret_access_key="secret",
                session_token="token",
                prefix=None,
                region="us-east-1",
                endpoint_url=None,
                mount_type="s3_mount",
                read_only=False,
            ),
        )

    error_context = raised.value.context
    reported_command = str(error_context["command"])
    reported_stderr = str(error_context["stderr"])

    assert "REDACTED" in reported_stderr
    # Credentials are sourced from an env file instead of the command line.
    assert ".sandbox-mountpoint-env" in reported_command

    # The directory holding the env file must be registered as a skip path.
    skip_paths = failing_inner.persist_workspace_skip_paths()
    assert any(
        skip_path.as_posix().startswith(".sandbox-mountpoint-env/") for skip_path in skip_paths
    )

    events_json = "\n".join(event.model_dump_json() for event in recorded_events)
    for secret in ("access", "secret", "token"):
        for haystack in (reported_command, reported_stderr, events_json):
            assert secret not in haystack