class _HashingReader(io.IOBase):
    """Read-only stream wrapper that SHA-256 hashes every chunk it returns.

    The digest is updated inside ``read`` with exactly the bytes handed to
    the caller, so ``hexdigest`` describes the bytes the consumer received
    rather than a separate pass over the source.

    Args:
        stream: Underlying binary stream to read from.
        read_error_factory: Optional callable that converts an ``OSError``
            raised by the underlying stream into the exception to raise
            instead. When ``None`` the ``OSError`` propagates unchanged.
    """

    def __init__(
        self,
        stream: io.BufferedReader,
        *,
        read_error_factory: Callable[[OSError], BaseException] | None = None,
    ) -> None:
        self._stream = stream
        self._digest = hashlib.sha256()
        # True once a read has returned (or attempted to return) data.
        self._started = False
        # True once end-of-stream has been observed.
        self._finished = False
        self._read_error_factory = read_error_factory

    def readable(self) -> bool:
        return True

    def read(self, size: int | None = -1) -> bytes:
        """Read up to ``size`` bytes and fold them into the digest.

        Args:
            size: Maximum number of bytes to read. ``None`` or a negative
                value reads to end-of-stream; ``0`` returns ``b""`` without
                touching the digest or the reader state.

        Returns:
            The bytes read, or ``b""`` at end-of-stream.

        Raises:
            OSError: If the underlying read fails and no
                ``read_error_factory`` was supplied; otherwise whatever the
                factory returns is raised, chained to the ``OSError``.
        """
        # The io API accepts None as "read everything"; normalise it so the
        # numeric comparisons below cannot raise TypeError after the bytes
        # have already been consumed from the underlying stream.
        if size is None:
            size = -1
        try:
            chunk = self._stream.read(size)
        except OSError as e:
            if self._read_error_factory is not None:
                raise self._read_error_factory(e) from e
            raise
        if size == 0:
            # A zero-length read yields b"" without reaching end-of-stream,
            # so it must neither mark the stream consumed nor block seek().
            return b""
        if chunk is None:
            # The underlying stream reported no data; treated as the end.
            self._finished = True
            return b""
        if isinstance(chunk, bytearray):
            chunk = bytes(chunk)
        self._started = True
        if not chunk:
            self._finished = True
            return b""
        self._digest.update(chunk)
        # A read-to-end request or a short read is taken as end-of-stream.
        if size < 0 or len(chunk) < size:
            self._finished = True
        return chunk

    def readinto(self, b: bytearray) -> int:
        """Fill ``b`` via ``read`` and return the number of bytes stored."""
        data = self.read(len(b))
        n = len(data)
        b[:n] = data
        return n

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        """Seek the underlying stream; refused once data has been read.

        Raises:
            io.UnsupportedOperation: If a read has already returned data,
                because the digest could no longer match the bytes read.
        """
        if self._started:
            raise io.UnsupportedOperation("cannot seek after reads begin")
        try:
            return int(self._stream.seek(offset, whence))
        except OSError as e:
            if self._read_error_factory is not None:
                raise self._read_error_factory(e) from e
            raise

    def tell(self) -> int:
        """Return the underlying stream position."""
        try:
            return int(self._stream.tell())
        except OSError as e:
            if self._read_error_factory is not None:
                raise self._read_error_factory(e) from e
            raise

    def hexdigest(self) -> str:
        """Return the SHA-256 hex digest of all bytes returned by ``read``.

        Raises:
            RuntimeError: If end-of-stream has not been observed yet.
        """
        if not self._finished:
            raise RuntimeError("checksum is not available until the stream is fully consumed")
        return self._digest.hexdigest()


def _find_nested_local_artifact_error(exc: BaseException) -> LocalArtifactError | None:
    """Return the first ``LocalArtifactError`` in ``exc``'s cause chain.

    Each step follows the exception's ``cause`` attribute when that is an
    exception, otherwise ``__cause__``. Returns ``None`` when the chain ends
    (or loops back on itself) without finding one.
    """
    seen: set[int] = set()  # ids already visited, guards against cycles
    current: BaseException | None = exc
    while current is not None and id(current) not in seen:
        if isinstance(current, LocalArtifactError):
            return current
        seen.add(id(current))
        next_exc = getattr(current, "cause", None)
        if not isinstance(next_exc, BaseException):
            next_exc = current.__cause__
        current = next_exc if isinstance(next_exc, BaseException) else None
    return None


def _reraise_nested_local_artifact_error(exc: BaseException) -> None:
    """Raise the ``LocalArtifactError`` nested in ``exc``, if there is one.

    Returns normally when no such error is found, so the caller can re-raise
    the original exception.
    """
    nested_local_artifact_error = _find_nested_local_artifact_error(exc)
    if nested_local_artifact_error is not None:
        raise nested_local_artifact_error


async def _write_hashed_local_artifact(
    *,
    session: BaseSandboxSession,
    dest: Path,
    src: Path,
    src_handle: io.BufferedReader,
    user: str | User | None = None,
) -> str:
    """Write ``src_handle`` to ``dest`` via ``session`` and return its SHA-256.

    The handle is wrapped in a ``_HashingReader`` so the checksum is computed
    from the bytes ``session.write`` actually reads. ``OSError`` raised while
    reading is converted to ``LocalFileReadError`` for ``src``.

    Raises:
        LocalArtifactError: When ``session.write`` fails with a
            ``WorkspaceArchiveWriteError`` whose cause chain contains one;
            that nested error is raised in its place.
        WorkspaceArchiveWriteError: When the write fails for any other reason.
        RuntimeError: From ``hexdigest`` if the write returned without
            reading the stream to its end.
    """
    hashing_reader = _HashingReader(
        src_handle,
        read_error_factory=lambda e: LocalFileReadError(src=src, cause=e),
    )
    try:
        await session.write(dest, hashing_reader, user=user)
    except WorkspaceArchiveWriteError as e:
        # Surface a local read failure that the session wrapped as a write error.
        _reraise_nested_local_artifact_error(e)
        raise
    return hashing_reader.hexdigest()
LocalChecksumError(src=src, cause=e) from e - await session.mkdir(Path(dest).parent, parents=True) + src = base_dir / self.src + src = src if src.is_absolute() else src.absolute() + local_dir = LocalDir(src=self.src.parent) + rel_child = Path(self.src.name) + fd: int | None = None try: - with src.open("rb") as f: - await session.write(dest, f) + src_root = local_dir._resolve_local_dir_src_root(base_dir) + fd = local_dir._open_local_dir_file_for_copy( + base_dir=base_dir, + src_root=src_root, + rel_child=rel_child, + ) + with os.fdopen(fd, "rb") as f: + fd = None + await session.mkdir(Path(dest).parent, parents=True) + checksum = await _write_hashed_local_artifact( + session=session, + dest=dest, + src=src, + src_handle=f, + ) + except LocalDirReadError as e: + context = dict(e.context) + context.pop("src", None) + raise LocalFileReadError(src=src, context=context, cause=e.cause) from e except OSError as e: raise LocalFileReadError(src=src, cause=e) from e + finally: + if fd is not None: + os.close(fd) await self._apply_metadata(session, dest) return [MaterializedFile(path=dest, sha256=checksum)] @@ -349,10 +465,14 @@ async def _copy_local_dir_file( ) with os.fdopen(fd, "rb") as f: fd = None - checksum = _sha256_handle(f) - f.seek(0) await session.mkdir(child_dest.parent, parents=True, user=user) - await session.write(child_dest, f, user=user) + checksum = await _write_hashed_local_artifact( + session=session, + dest=child_dest, + src=src, + src_handle=f, + user=user, + ) except OSError as e: raise LocalFileReadError(src=src, cause=e) from e finally: diff --git a/tests/sandbox/test_entries.py b/tests/sandbox/test_entries.py index ecba9d5b2c..a1ce36506d 100644 --- a/tests/sandbox/test_entries.py +++ b/tests/sandbox/test_entries.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib import io import os from collections.abc import Awaitable, Callable, Sequence @@ -10,10 +11,16 @@ import agents.sandbox.entries.artifacts as artifacts_module from 
agents.sandbox import SandboxConcurrencyLimits from agents.sandbox.entries import Dir, File, GitRepo, LocalDir, LocalFile -from agents.sandbox.errors import ExecNonZeroError, LocalDirReadError +from agents.sandbox.errors import ( + ExecNonZeroError, + LocalDirReadError, + LocalFileReadError, + WorkspaceArchiveWriteError, +) from agents.sandbox.manifest import Manifest from agents.sandbox.materialization import MaterializedFile from agents.sandbox.session.base_sandbox_session import BaseSandboxSession +from agents.sandbox.session.workspace_payloads import coerce_write_payload from agents.sandbox.snapshot import NoopSnapshot from agents.sandbox.types import ExecResult, User from tests.utils.factories import TestSessionState @@ -98,6 +105,120 @@ async def _exec_internal( return ExecResult(stdout=b"", stderr=b"", exit_code=0) +class _MutatingWriteSession(_RecordingSession): + def __init__(self, mutate_before_read: Callable[[], None]) -> None: + super().__init__() + self._mutate_before_read = mutate_before_read + self._mutated = False + + async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None: + if not self._mutated: + self._mutate_before_read() + self._mutated = True + await super().write(path, data, user=user) + + +class _ChunkedMutatingWriteSession(_RecordingSession): + def __init__(self, mutate_after_first_chunk: Callable[[], None]) -> None: + super().__init__() + self._mutate_after_first_chunk = mutate_after_first_chunk + self._mutated = False + + async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None: + _ = user + chunks: list[bytes] = [] + first = data.read(4) + if isinstance(first, bytes): + chunks.append(first) + if not self._mutated: + self._mutate_after_first_chunk() + self._mutated = True + rest = data.read() + if isinstance(rest, bytes): + chunks.append(rest) + self.writes[path] = b"".join(chunks) + + +class _PayloadWrappingWriteSession(_RecordingSession): + async def write(self, path: Path, data: 
io.IOBase, *, user: object = None) -> None: + _ = user + payload = coerce_write_payload(path=path, data=data) + chunks: list[bytes] = [] + try: + while True: + chunk = payload.stream.read(4) + if not chunk: + break + chunks.append(chunk) + except Exception as e: + raise WorkspaceArchiveWriteError(path=path, cause=e) from e + self.writes[path] = b"".join(chunks) + + +class _StagedFailureAfterReadSession(_RecordingSession): + def __init__(self) -> None: + super().__init__() + self.removed: list[Path] = [] + self.staged_writes: dict[Path, bytes] = {} + + async def write(self, path: Path, data: io.IOBase, *, user: object = None) -> None: + _ = user + chunks: list[bytes] = [] + while True: + chunk = data.read(4) + if not chunk: + break + chunks.append(chunk) + staged_path = path.with_name(f".{path.name}.staged") + self.staged_writes[staged_path] = b"".join(chunks) + raise WorkspaceArchiveWriteError( + path=path, + context={"reason": "final_install_failed"}, + ) + + async def rm( + self, + path: Path | str, + *, + recursive: bool = False, + user: object = None, + ) -> None: + _ = recursive, user + normalized = Path(path) + self.removed.append(normalized) + self.writes.pop(normalized, None) + + +class _FailAfterChunkStream(io.BytesIO): + def __init__(self, data: bytes, *, owned_fd: int | None = None) -> None: + super().__init__(data) + self._owned_fd = owned_fd + self._read_count = 0 + + def read(self, size: int | None = -1) -> bytes: + if self._read_count > 0: + raise OSError("source read failed") + self._read_count += 1 + return super().read(-1 if size is None else size) + + def close(self) -> None: + try: + super().close() + finally: + if self._owned_fd is not None: + os.close(self._owned_fd) + self._owned_fd = None + + +def _symlink_or_skip(path: Path, target: Path, *, target_is_directory: bool = False) -> None: + try: + path.symlink_to(target, target_is_directory=target_is_directory) + except OSError as e: + if os.name == "nt" and getattr(e, "winerror", None) == 
1314: + pytest.skip("symlink creation requires elevated privileges on Windows") + raise + + @pytest.mark.asyncio async def test_base_sandbox_session_uses_current_working_directory_for_local_file_sources( monkeypatch: pytest.MonkeyPatch, @@ -115,9 +236,110 @@ async def test_base_sandbox_session_uses_current_working_directory_for_local_fil result = await session.apply_manifest() assert result.files[0].path == Path("/workspace/copied.txt") + assert result.files[0].sha256 == hashlib.sha256(b"hello").hexdigest() assert session.writes[Path("/workspace/copied.txt")] == b"hello" +@pytest.mark.asyncio +async def test_local_file_checksum_matches_written_bytes_when_source_changes( + tmp_path: Path, +) -> None: + source = tmp_path / "source.txt" + source.write_bytes(b"original") + + def mutate_source() -> None: + source.write_bytes(b"mutated") + + session = _ChunkedMutatingWriteSession(mutate_source) + + result = await LocalFile(src=Path("source.txt")).apply( + session, + Path("/workspace/copied.txt"), + tmp_path, + ) + + written = session.writes[Path("/workspace/copied.txt")] + assert result[0].sha256 == hashlib.sha256(written).hexdigest() + + +@pytest.mark.asyncio +async def test_local_file_does_not_remove_existing_destination_when_staged_write_fails( + tmp_path: Path, +) -> None: + source = tmp_path / "source.txt" + source.write_bytes(b"new content") + dest = Path("/workspace/copied.txt") + session = _StagedFailureAfterReadSession() + session.writes[dest] = b"old content" + + with pytest.raises(WorkspaceArchiveWriteError): + await LocalFile(src=Path("source.txt")).apply(session, dest, tmp_path) + + assert session.writes[dest] == b"old content" + assert session.removed == [] + assert session.staged_writes[Path("/workspace/.copied.txt.staged")] == b"new content" + + +@pytest.mark.asyncio +async def test_local_file_rejects_symlinked_source_ancestors(tmp_path: Path) -> None: + target_dir = tmp_path / "secret-dir" + target_dir.mkdir() + nested_dir = target_dir / "sub" + 
nested_dir.mkdir() + (nested_dir / "secret.txt").write_text("secret", encoding="utf-8") + _symlink_or_skip(tmp_path / "link", target_dir, target_is_directory=True) + session = _RecordingSession() + + with pytest.raises(LocalFileReadError) as excinfo: + await LocalFile(src=Path("link/sub/secret.txt")).apply( + session, + Path("/workspace/copied.txt"), + tmp_path, + ) + + assert excinfo.value.context["reason"] == "symlink_not_supported" + assert excinfo.value.context["child"] == "link" + assert session.writes == {} + + +@pytest.mark.asyncio +async def test_local_file_rejects_symlinked_source_leaf(tmp_path: Path) -> None: + secret = tmp_path / "secret.txt" + secret.write_text("secret", encoding="utf-8") + _symlink_or_skip(tmp_path / "link.txt", secret) + session = _RecordingSession() + + with pytest.raises(LocalFileReadError) as excinfo: + await LocalFile(src=Path("link.txt")).apply( + session, + Path("/workspace/copied.txt"), + tmp_path, + ) + + assert excinfo.value.context["reason"] == "symlink_not_supported" + assert excinfo.value.context["child"] == "link.txt" + assert session.writes == {} + + +@pytest.mark.asyncio +async def test_local_file_rejects_symlinked_source_before_checksum(tmp_path: Path) -> None: + target_dir = tmp_path / "secret-dir" + target_dir.mkdir() + _symlink_or_skip(tmp_path / "link.txt", target_dir, target_is_directory=True) + session = _RecordingSession() + + with pytest.raises(LocalFileReadError) as excinfo: + await LocalFile(src=Path("link.txt")).apply( + session, + Path("/workspace/copied.txt"), + tmp_path, + ) + + assert excinfo.value.context["reason"] == "symlink_not_supported" + assert excinfo.value.context["child"] == "link.txt" + assert session.writes == {} + + @pytest.mark.asyncio async def test_local_dir_copy_falls_back_when_safe_dir_fd_open_unavailable( monkeypatch: pytest.MonkeyPatch, @@ -145,11 +367,122 @@ async def test_local_dir_copy_falls_back_when_safe_dir_fd_open_unavailable( assert 
session.writes[Path("/workspace/copied/safe.txt")] == b"safe" +@pytest.mark.asyncio +async def test_local_dir_checksum_matches_written_bytes_when_source_changes( + tmp_path: Path, +) -> None: + src_root = tmp_path / "src" + src_root.mkdir() + src_file = src_root / "safe.txt" + src_file.write_bytes(b"original") + + def mutate_source() -> None: + src_file.write_bytes(b"mutated") + + session = _ChunkedMutatingWriteSession(mutate_source) + local_dir = LocalDir(src=Path("src")) + + result = await local_dir._copy_local_dir_file( + base_dir=tmp_path, + session=session, + src_root=src_root, + src=src_file, + dest_root=Path("/workspace/copied"), + ) + + written = session.writes[Path("/workspace/copied/safe.txt")] + assert result.sha256 == hashlib.sha256(written).hexdigest() + + +@pytest.mark.asyncio +async def test_local_dir_does_not_remove_existing_destination_when_staged_write_fails( + tmp_path: Path, +) -> None: + src_root = tmp_path / "src" + src_root.mkdir() + src_file = src_root / "safe.txt" + src_file.write_bytes(b"new content") + dest = Path("/workspace/copied") + child_dest = dest / "safe.txt" + session = _StagedFailureAfterReadSession() + session.writes[child_dest] = b"old content" + + with pytest.raises(WorkspaceArchiveWriteError): + await LocalDir(src=Path("src")).apply(session, dest, tmp_path) + + assert session.writes[child_dest] == b"old content" + assert session.removed == [] + assert session.staged_writes[Path("/workspace/copied/.safe.txt.staged")] == b"new content" + + +@pytest.mark.asyncio +async def test_local_file_preserves_local_read_error_when_write_wraps_stream_failures( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + source = (tmp_path / "source.txt").resolve() + source.write_bytes(b"original") + session = _PayloadWrappingWriteSession() + + def failing_fdopen( + fd: int, + *args: object, + **kwargs: object, + ) -> io.IOBase: + _ = args, kwargs + return _FailAfterChunkStream(b"original", owned_fd=fd) + + 
monkeypatch.setattr("agents.sandbox.entries.artifacts.os.fdopen", failing_fdopen) + + with pytest.raises(LocalFileReadError) as excinfo: + await LocalFile(src=Path("source.txt")).apply( + session, + Path("/workspace/copied.txt"), + tmp_path, + ) + + assert excinfo.value.context["src"] == str(source) + assert isinstance(excinfo.value.cause, OSError) + + +@pytest.mark.asyncio +async def test_local_dir_copy_preserves_local_read_error_when_write_wraps_stream_failures( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + src_root = tmp_path / "src" + src_root.mkdir() + src_file = (src_root / "safe.txt").resolve() + src_file.write_bytes(b"original") + session = _PayloadWrappingWriteSession() + local_dir = LocalDir(src=Path("src")) + + def failing_fdopen(fd: int, *args: object, **kwargs: object) -> io.IOBase: + return _FailAfterChunkStream(b"original", owned_fd=fd) + + monkeypatch.setattr("agents.sandbox.entries.artifacts.os.fdopen", failing_fdopen) + + with pytest.raises(LocalFileReadError) as excinfo: + await local_dir._copy_local_dir_file( + base_dir=tmp_path, + session=session, + src_root=src_root, + src=src_file, + dest_root=Path("/workspace/copied"), + ) + + assert excinfo.value.context["src"] == str(src_file) + assert isinstance(excinfo.value.cause, OSError) + + @pytest.mark.asyncio async def test_local_dir_copy_revalidates_swapped_paths_during_open( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: + if not artifacts_module._OPEN_SUPPORTS_DIR_FD or not artifacts_module._HAS_O_DIRECTORY: + pytest.skip("safe dir_fd open pinning is unavailable on this platform") + src_root = tmp_path / "src" src_root.mkdir() src_file = src_root / "safe.txt" @@ -171,7 +504,7 @@ def swap_then_open( nonlocal swapped if path == "safe.txt" and not swapped: src_file.unlink() - src_file.symlink_to(secret) + _symlink_or_skip(src_file, secret) swapped = True if dir_fd is None: return original_open(path, flags, mode) @@ -201,6 +534,9 @@ async def 
test_local_dir_copy_pins_parent_directories_during_open( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: + if not artifacts_module._OPEN_SUPPORTS_DIR_FD or not artifacts_module._HAS_O_DIRECTORY: + pytest.skip("safe dir_fd open pinning is unavailable on this platform") + src_root = tmp_path / "src" src_root.mkdir() nested_dir = src_root / "nested" @@ -225,7 +561,7 @@ def swap_parent_then_open( nonlocal swapped if path == "safe.txt" and not swapped: (src_root / "nested").rename(src_root / "nested-original") - (src_root / "nested").symlink_to(secret_dir, target_is_directory=True) + _symlink_or_skip(src_root / "nested", secret_dir, target_is_directory=True) swapped = True if dir_fd is None: return original_open(path, flags, mode) @@ -250,6 +586,9 @@ async def test_local_dir_apply_rejects_source_root_swapped_to_symlink_after_vali monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: + if not artifacts_module._OPEN_SUPPORTS_DIR_FD or not artifacts_module._HAS_O_DIRECTORY: + pytest.skip("safe dir_fd open pinning is unavailable on this platform") + src_root = tmp_path / "src" src_root.mkdir() (src_root / "safe.txt").write_text("safe", encoding="utf-8") @@ -271,7 +610,7 @@ def swap_root_then_open( nonlocal swapped if path == "src" and dir_fd is not None and not swapped: src_root.rename(tmp_path / "src-original") - (tmp_path / "src").symlink_to(secret_dir, target_is_directory=True) + _symlink_or_skip(tmp_path / "src", secret_dir, target_is_directory=True) swapped = True if dir_fd is None: return original_open(path, flags, mode) @@ -343,7 +682,7 @@ async def test_local_dir_rejects_symlinked_source_ancestors(tmp_path: Path) -> N nested_dir = target_dir / "sub" nested_dir.mkdir() (nested_dir / "secret.txt").write_text("secret", encoding="utf-8") - (tmp_path / "link").symlink_to(target_dir, target_is_directory=True) + _symlink_or_skip(tmp_path / "link", target_dir, target_is_directory=True) session = _RecordingSession() with pytest.raises(LocalDirReadError) as 
excinfo: @@ -359,7 +698,7 @@ async def test_local_dir_rejects_symlinked_source_root(tmp_path: Path) -> None: target_dir = tmp_path / "secret-dir" target_dir.mkdir() (target_dir / "secret.txt").write_text("secret", encoding="utf-8") - (tmp_path / "src").symlink_to(target_dir, target_is_directory=True) + _symlink_or_skip(tmp_path / "src", target_dir, target_is_directory=True) session = _RecordingSession() with pytest.raises(LocalDirReadError) as excinfo: @@ -377,7 +716,7 @@ async def test_local_dir_rejects_symlinked_files(tmp_path: Path) -> None: (src_root / "safe.txt").write_text("safe", encoding="utf-8") secret = tmp_path / "secret.txt" secret.write_text("secret", encoding="utf-8") - (src_root / "link.txt").symlink_to(secret) + _symlink_or_skip(src_root / "link.txt", secret) session = _RecordingSession() with pytest.raises(LocalDirReadError) as excinfo: @@ -396,7 +735,7 @@ async def test_local_dir_rejects_symlinked_directories(tmp_path: Path) -> None: target_dir = tmp_path / "secret-dir" target_dir.mkdir() (target_dir / "secret.txt").write_text("secret", encoding="utf-8") - (src_root / "linked-dir").symlink_to(target_dir, target_is_directory=True) + _symlink_or_skip(src_root / "linked-dir", target_dir, target_is_directory=True) session = _RecordingSession() with pytest.raises(LocalDirReadError) as excinfo: @@ -442,10 +781,11 @@ async def test_git_repo_uses_fetch_checkout_path_for_commit_refs() -> None: @pytest.mark.asyncio async def test_dir_metadata_strips_file_type_bits_before_chmod() -> None: session = _RecordingSession() + dest = Path("/workspace/dir") - await Dir()._apply_metadata(session, Path("/workspace/dir")) + await Dir()._apply_metadata(session, dest) - assert ("chmod", "0755", "/workspace/dir") in session.exec_calls + assert ("chmod", "0755", str(dest)) in session.exec_calls @pytest.mark.asyncio @@ -476,5 +816,5 @@ async def test_apply_manifest_raises_on_chgrp_failure() -> None: with pytest.raises(ExecNonZeroError): await session.apply_manifest() - 
assert ("chgrp", "sandbox-user", "/workspace/copied.txt") in session.exec_calls + assert ("chgrp", "sandbox-user", str(Path("/workspace/copied.txt"))) in session.exec_calls assert not any(call[0] == "chmod" for call in session.exec_calls)