Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 75 additions & 7 deletions src/agents/sandbox/entries/mounts/patterns.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import abc
import hashlib
import io
import re
import shlex
Expand Down Expand Up @@ -108,6 +109,41 @@ async def _write_sensitive_config_file(
)


def _render_shell_exports(env_vars: list[tuple[str, str]]) -> bytes:
lines = [f"export {name}={shlex.quote(value)}" for name, value in env_vars]
return ("\n".join(lines) + "\n").encode("utf-8")


def _redact_sensitive_values(text: str, sensitive_values: list[str]) -> str:
redacted = text
for value in sensitive_values:
if not value:
continue
redacted = redacted.replace(value, "REDACTED")
quoted = shlex.quote(value)
if quoted != value:
redacted = redacted.replace(quoted, "REDACTED")
return redacted


async def _read_text_if_present(session: BaseSandboxSession, path: Path) -> str:
try:
handle = await session.read(path)
except Exception:
return ""

try:
raw = handle.read()
finally:
handle.close()

if isinstance(raw, bytes):
return raw.decode("utf-8", errors="replace")
if isinstance(raw, str):
return raw
return str(raw)


class MountPatternBase(BaseModel, abc.ABC):
@abc.abstractmethod
async def apply(
Expand Down Expand Up @@ -425,25 +461,57 @@ async def apply(
cmd.extend(["--prefix", mountpoint_config.prefix])
cmd.extend([bucket, sandbox_path_str(path)])

env_parts: list[str] = []
env_vars: list[tuple[str, str]] = []
access_key_id = mountpoint_config.access_key_id
secret_access_key = mountpoint_config.secret_access_key
session_token = mountpoint_config.session_token
if access_key_id and secret_access_key:
env_parts.append(f"AWS_ACCESS_KEY_ID={shlex.quote(access_key_id)}")
env_parts.append(f"AWS_SECRET_ACCESS_KEY={shlex.quote(secret_access_key)}")
env_vars.append(("AWS_ACCESS_KEY_ID", access_key_id))
env_vars.append(("AWS_SECRET_ACCESS_KEY", secret_access_key))
if session_token:
env_parts.append(f"AWS_SESSION_TOKEN={shlex.quote(session_token)}")
env_vars.append(("AWS_SESSION_TOKEN", session_token))

joined_cmd = " ".join(shlex.quote(part) for part in cmd)
if env_parts:
joined_cmd = f"{' '.join(env_parts)} {joined_cmd}"
stderr_path: Path | None = None
sensitive_values = [value for _name, value in env_vars]
if env_vars:
session_id = getattr(session.state, "session_id", None)
if session_id is None:
raise MountConfigError(
message="mount session is missing session_id",
context={"type": mountpoint_config.mount_type},
)
command_hash = hashlib.sha256(
f"{bucket}\0{sandbox_path_str(path)}".encode()
).hexdigest()[:16]
config_dir = posix_path_as_path(
coerce_posix_path(f".sandbox-mountpoint-env/{session_id.hex}")
)
env_path = config_dir / f"{command_hash}.env"
stdout_path = config_dir / f"{command_hash}.stdout"
stderr_path = config_dir / f"{command_hash}.stderr"

await session.mkdir(config_dir, parents=True)
session.register_persist_workspace_skip_path(config_dir)
Comment thread
seratch marked this conversation as resolved.
await _write_sensitive_config_file(session, env_path, _render_shell_exports(env_vars))

command_env_path = sandbox_path_str(session.normalize_path(env_path))
command_stdout_path = sandbox_path_str(session.normalize_path(stdout_path))
command_stderr_path = sandbox_path_str(session.normalize_path(stderr_path))
joined_cmd = (
f". {shlex.quote(command_env_path)} && exec {joined_cmd} "
f">{shlex.quote(command_stdout_path)} 2>{shlex.quote(command_stderr_path)}"
)

result = await session.exec("sh", "-lc", joined_cmd, shell=False)
if not result.ok():
stderr = result.stderr.decode("utf-8", errors="replace")
if stderr_path is not None:
stderr += await _read_text_if_present(session, stderr_path)
stderr = _redact_sensitive_values(stderr, sensitive_values)
raise MountCommandError(
command=joined_cmd,
stderr=result.stderr.decode("utf-8", errors="replace"),
stderr=stderr,
context={"bucket": bucket},
)

Expand Down
3 changes: 3 additions & 0 deletions src/agents/sandbox/session/sandbox_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ def _set_archive_limits(self, limits: SandboxArchiveLimits | None) -> None:
def normalize_path(self, path: Path | str, *, for_write: bool = False) -> Path:
return self._inner.normalize_path(path, for_write=for_write)

def register_persist_workspace_skip_path(self, path: Path | str) -> Path:
return self._inner.register_persist_workspace_skip_path(path)

def supports_pty(self) -> bool:
return self._inner.supports_pty()

Expand Down
133 changes: 112 additions & 21 deletions tests/sandbox/test_mounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@
RcloneMountConfig,
S3FilesMountConfig,
)
from agents.sandbox.errors import MountConfigError
from agents.sandbox.errors import MountCommandError, MountConfigError
from agents.sandbox.session.base_sandbox_session import BaseSandboxSession
from agents.sandbox.session.events import SandboxSessionEvent
from agents.sandbox.session.manager import Instrumentation
from agents.sandbox.session.sandbox_session import SandboxSession
from agents.sandbox.session.sinks import CallbackSink
from agents.sandbox.snapshot import NoopSnapshot
from agents.sandbox.types import ExecResult
from tests.utils.factories import TestSessionState
Expand Down Expand Up @@ -76,13 +80,16 @@ async def hydrate_workspace(self, data: io.IOBase) -> None:


class _MountpointApplySession(BaseSandboxSession):
def __init__(self) -> None:
def __init__(self, *, mount_exit_code: int = 0, mount_stderr: bytes = b"") -> None:
self.state = TestSessionState(
session_id=uuid.uuid4(),
manifest=Manifest(root="/workspace"),
snapshot=NoopSnapshot(id=str(uuid.uuid4())),
)
self._mount_exit_code = mount_exit_code
self._mount_stderr = mount_stderr
self.exec_calls: list[list[str]] = []
self.write_calls: list[tuple[Path, bytes]] = []

async def read(self, path: Path, *, user: object = None) -> io.BytesIO:
_ = (path, user)
Expand All @@ -92,12 +99,15 @@ async def shutdown(self) -> None:
return None

async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None:
_ = (path, data, user)
raise AssertionError("write() should not be called in these tests")
_ = user
self.write_calls.append((path, data.read()))

async def running(self) -> bool:
return True

def persist_workspace_skip_paths(self) -> set[Path]:
return self._persist_workspace_skip_relpaths()

async def _exec_internal(
self,
*command: str | Path,
Expand All @@ -106,6 +116,15 @@ async def _exec_internal(
_ = timeout
command_strs = [str(part) for part in command]
self.exec_calls.append(command_strs)
if (
len(command_strs) >= 3
and command_strs[:2] == ["sh", "-lc"]
and "mount-s3 " in command_strs[2]
and "command -v " not in command_strs[2]
):
return ExecResult(
exit_code=self._mount_exit_code, stdout=b"", stderr=self._mount_stderr
)
return ExecResult(exit_code=0, stdout=b"", stderr=b"")

async def persist_workspace(self) -> io.IOBase:
Expand Down Expand Up @@ -380,15 +399,23 @@ async def test_gcs_mount_uses_runtime_endpoint_override_without_mutating_pattern
["sh", "-lc", "command -v mount-s3 >/dev/null 2>&1"],
["mkdir", "-p", "/workspace/remote"],
]
assert len(session.exec_calls) == 3

mount_command = session.exec_calls[2]
assert len(session.exec_calls) == 5
assert len(session.write_calls) == 1
env_path, env_payload = session.write_calls[0]
assert env_path.as_posix().startswith(".sandbox-mountpoint-env/")
assert env_path.name.endswith(".env")
assert env_payload == b"export AWS_ACCESS_KEY_ID=access\nexport AWS_SECRET_ACCESS_KEY=secret\n"

mount_command = session.exec_calls[-1]
assert mount_command[:2] == ["sh", "-lc"]
assert "mount-s3" in mount_command[2]
assert "AWS_ACCESS_KEY_ID=access" not in mount_command[2]
assert "AWS_SECRET_ACCESS_KEY=secret" not in mount_command[2]
assert ".sandbox-mountpoint-env" in mount_command[2]
assert "--region us-east1" in mount_command[2]
assert "--endpoint-url https://storage.googleapis.com" in mount_command[2]
assert "--upload-checksums off" in mount_command[2]
assert mount_command[2].endswith("bucket /workspace/remote")
assert "bucket /workspace/remote" in mount_command[2]


@pytest.mark.asyncio
Expand Down Expand Up @@ -416,19 +443,29 @@ async def test_s3_mountpoint_writable_mode_enables_overwrite_and_delete() -> Non
["sh", "-lc", "command -v mount-s3 >/dev/null 2>&1"],
["mkdir", "-p", "/workspace/remote"],
]
assert len(session.exec_calls) == 3

mount_command = session.exec_calls[2]
assert len(session.exec_calls) == 5
assert len(session.write_calls) == 1
env_path, env_payload = session.write_calls[0]
assert env_path.as_posix().startswith(".sandbox-mountpoint-env/")
assert env_path.name.endswith(".env")
assert env_payload == (
b"export AWS_ACCESS_KEY_ID=access\n"
b"export AWS_SECRET_ACCESS_KEY=secret\n"
b"export AWS_SESSION_TOKEN=token\n"
)

mount_command = session.exec_calls[-1]
assert mount_command[:2] == ["sh", "-lc"]
assert "mount-s3" in mount_command[2]
assert "--read-only" not in mount_command[2]
assert "--allow-overwrite" in mount_command[2]
assert "--allow-delete" in mount_command[2]
assert "--region us-east-1" in mount_command[2]
assert "AWS_ACCESS_KEY_ID=access" in mount_command[2]
assert "AWS_SECRET_ACCESS_KEY=secret" in mount_command[2]
assert "AWS_SESSION_TOKEN=token" in mount_command[2]
assert mount_command[2].endswith("bucket /workspace/remote")
assert "AWS_ACCESS_KEY_ID=access" not in mount_command[2]
assert "AWS_SECRET_ACCESS_KEY=secret" not in mount_command[2]
assert "AWS_SESSION_TOKEN=token" not in mount_command[2]
assert ".sandbox-mountpoint-env" in mount_command[2]
assert "bucket /workspace/remote" in mount_command[2]


@pytest.mark.asyncio
Expand Down Expand Up @@ -456,9 +493,14 @@ async def test_gcs_mountpoint_writable_mode_enables_overwrite_and_delete() -> No
["sh", "-lc", "command -v mount-s3 >/dev/null 2>&1"],
["mkdir", "-p", "/workspace/remote"],
]
assert len(session.exec_calls) == 3

mount_command = session.exec_calls[2]
assert len(session.exec_calls) == 5
assert len(session.write_calls) == 1
env_path, env_payload = session.write_calls[0]
assert env_path.as_posix().startswith(".sandbox-mountpoint-env/")
assert env_path.name.endswith(".env")
assert env_payload == b"export AWS_ACCESS_KEY_ID=access\nexport AWS_SECRET_ACCESS_KEY=secret\n"

mount_command = session.exec_calls[-1]
assert mount_command[:2] == ["sh", "-lc"]
assert "mount-s3" in mount_command[2]
assert "--read-only" not in mount_command[2]
Expand All @@ -467,9 +509,58 @@ async def test_gcs_mountpoint_writable_mode_enables_overwrite_and_delete() -> No
assert "--region us-east1" in mount_command[2]
assert "--endpoint-url https://storage.googleapis.com" in mount_command[2]
assert "--upload-checksums off" in mount_command[2]
assert "AWS_ACCESS_KEY_ID=access" in mount_command[2]
assert "AWS_SECRET_ACCESS_KEY=secret" in mount_command[2]
assert mount_command[2].endswith("bucket /workspace/remote")
assert "AWS_ACCESS_KEY_ID=access" not in mount_command[2]
assert "AWS_SECRET_ACCESS_KEY=secret" not in mount_command[2]
assert ".sandbox-mountpoint-env" in mount_command[2]
assert "bucket /workspace/remote" in mount_command[2]


@pytest.mark.asyncio
async def test_s3_mountpoint_failure_redacts_credentials_from_errors_and_events() -> None:
events: list[SandboxSessionEvent] = []
inner = _MountpointApplySession(
mount_exit_code=1,
mount_stderr=b"bad credentials: access secret token",
)
session = SandboxSession(
inner,
instrumentation=Instrumentation(
sinks=[CallbackSink(lambda event, _session: events.append(event))]
),
)
pattern = MountpointMountPattern()

with pytest.raises(MountCommandError) as exc_info:
await pattern.apply(
session,
Path("/workspace/remote"),
MountpointMountConfig(
bucket="bucket",
access_key_id="access",
secret_access_key="secret",
session_token="token",
prefix=None,
region="us-east-1",
endpoint_url=None,
mount_type="s3_mount",
read_only=False,
),
)

context = exc_info.value.context
command = str(context["command"])
stderr = str(context["stderr"])
assert "REDACTED" in stderr
assert ".sandbox-mountpoint-env" in command
assert any(
path.as_posix().startswith(".sandbox-mountpoint-env/")
for path in inner.persist_workspace_skip_paths()
)
serialized_events = "\n".join(event.model_dump_json() for event in events)
for sensitive_value in ("access", "secret", "token"):
assert sensitive_value not in command
assert sensitive_value not in stderr
assert sensitive_value not in serialized_events


@pytest.mark.asyncio
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading