Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 140 additions & 20 deletions src/agents/sandbox/entries/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
GitCloneError,
GitCopyError,
GitMissingInImageError,
LocalChecksumError,
LocalArtifactError,
LocalDirReadError,
LocalFileReadError,
WorkspaceArchiveWriteError,
)
from ..materialization import MaterializedFile, gather_in_order
from ..types import ExecResult, User
from ..util.checksums import sha256_file
from .base import BaseEntry

if TYPE_CHECKING:
Expand All @@ -34,14 +34,111 @@
_HAS_O_DIRECTORY = hasattr(os, "O_DIRECTORY")


def _sha256_handle(handle: io.BufferedReader) -> str:
digest = hashlib.sha256()
while True:
chunk = handle.read(1024 * 1024)
class _HashingReader(io.IOBase):
    """Read-only file-like wrapper that SHA-256-hashes every byte it yields.

    Wraps a buffered binary stream so a caller can stream the data (e.g. into
    a sandbox write) and obtain its checksum in a single pass.  The digest is
    only available via :meth:`hexdigest` once the stream has been fully
    consumed.

    If *read_error_factory* is given, any ``OSError`` raised by the underlying
    stream is translated through that factory (chained with ``from``), letting
    callers surface a domain-specific error instead of a raw ``OSError``.
    """

    def __init__(
        self,
        stream: io.BufferedReader,
        *,
        read_error_factory: Callable[[OSError], BaseException] | None = None,
    ) -> None:
        self._stream = stream
        self._digest = hashlib.sha256()
        # _started: at least one read() returned data; seeking afterwards
        # would desynchronize the digest from the bytes already delivered.
        self._started = False
        # _finished: end of input observed; hexdigest() becomes available.
        self._finished = False
        self._read_error_factory = read_error_factory

    def readable(self) -> bool:
        return True

    def read(self, size: int = -1) -> bytes:
        """Read up to *size* bytes, folding them into the running digest."""
        try:
            chunk = self._stream.read(size)
        except OSError as e:
            if self._read_error_factory is not None:
                raise self._read_error_factory(e) from e
            raise
        if chunk is None:
            # Non-blocking stream with no data available; treat as end of input.
            self._finished = True
            return b""
        if isinstance(chunk, bytearray):
            # Normalize so callers always receive immutable bytes.
            chunk = bytes(chunk)
        self._started = True
        if not chunk:
            self._finished = True
            return b""
        self._digest.update(chunk)
        # A full read (size < 0) or a short read means the end was reached.
        if size < 0 or len(chunk) < size:
            self._finished = True
        return chunk

    def readinto(self, b: bytearray) -> int:
        """Fill *b* via read() so the digest also covers readinto() traffic."""
        data = self.read(len(b))
        n = len(data)
        b[:n] = data
        return n

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        # Only permitted before any data has been hashed; a seek after that
        # would make the digest disagree with the bytes the consumer saw.
        if self._started:
            raise io.UnsupportedOperation("cannot seek after reads begin")
        try:
            return int(self._stream.seek(offset, whence))
        except OSError as e:
            if self._read_error_factory is not None:
                raise self._read_error_factory(e) from e
            raise

    def tell(self) -> int:
        try:
            return int(self._stream.tell())
        except OSError as e:
            if self._read_error_factory is not None:
                raise self._read_error_factory(e) from e
            raise

    def hexdigest(self) -> str:
        """Return the SHA-256 hex digest; valid only after full consumption."""
        if not self._finished:
            raise RuntimeError("checksum is not available until the stream is fully consumed")
        return self._digest.hexdigest()


def _find_nested_local_artifact_error(exc: BaseException) -> LocalArtifactError | None:
    """Walk *exc*'s cause chain and return the first ``LocalArtifactError``.

    Follows an explicit ``cause`` attribute when present, falling back to the
    standard ``__cause__`` link.  Tracks visited exception ids so a cyclic
    chain cannot loop forever.  Returns ``None`` when no match is found.
    """
    visited: set[int] = set()
    candidate: BaseException | None = exc
    while candidate is not None and id(candidate) not in visited:
        if isinstance(candidate, LocalArtifactError):
            return candidate
        visited.add(id(candidate))
        # Prefer a project-level `cause` attribute over Python's __cause__.
        linked = getattr(candidate, "cause", None)
        if not isinstance(linked, BaseException):
            linked = candidate.__cause__
        candidate = linked if isinstance(linked, BaseException) else None
    return None


def _reraise_nested_local_artifact_error(exc: BaseException) -> None:
    """Raise the ``LocalArtifactError`` buried in *exc*'s cause chain, if any.

    A no-op when the chain contains no ``LocalArtifactError``; the caller is
    then expected to re-raise the original exception itself.
    """
    found = _find_nested_local_artifact_error(exc)
    if found is None:
        return
    raise found


async def _write_hashed_local_artifact(
    *,
    session: BaseSandboxSession,
    dest: Path,
    src: Path,
    src_handle: io.BufferedReader,
    user: str | User | None = None,
) -> str:
    """Stream *src_handle* into the sandbox at *dest* and return its SHA-256.

    The checksum is computed while writing (single pass over the file).  An
    ``OSError`` during reading is rewrapped as ``LocalFileReadError`` for
    *src*; if the sandbox write fails with ``WorkspaceArchiveWriteError`` that
    merely wraps a local-artifact error, the more specific nested error is
    re-raised instead.
    """
    reader = _HashingReader(
        src_handle,
        read_error_factory=lambda err: LocalFileReadError(src=src, cause=err),
    )
    try:
        await session.write(dest, reader, user=user)
    except WorkspaceArchiveWriteError as exc:
        # Surface the underlying local read error when it caused the failure.
        _reraise_nested_local_artifact_error(exc)
        raise
    return reader.hexdigest()


class Dir(BaseEntry):
Expand Down Expand Up @@ -109,17 +206,36 @@ async def apply(
dest: Path,
base_dir: Path,
) -> list[MaterializedFile]:
src = (base_dir / self.src).resolve()
try:
checksum = sha256_file(src)
except OSError as e:
raise LocalChecksumError(src=src, cause=e) from e
await session.mkdir(Path(dest).parent, parents=True)
src = base_dir / self.src
src = src if src.is_absolute() else src.absolute()
local_dir = LocalDir(src=self.src.parent)
rel_child = Path(self.src.name)
fd: int | None = None
try:
with src.open("rb") as f:
await session.write(dest, f)
src_root = local_dir._resolve_local_dir_src_root(base_dir)
fd = local_dir._open_local_dir_file_for_copy(
base_dir=base_dir,
src_root=src_root,
rel_child=rel_child,
)
with os.fdopen(fd, "rb") as f:
fd = None
await session.mkdir(Path(dest).parent, parents=True)
checksum = await _write_hashed_local_artifact(
session=session,
dest=dest,
src=src,
src_handle=f,
)
except LocalDirReadError as e:
context = dict(e.context)
context.pop("src", None)
raise LocalFileReadError(src=src, context=context, cause=e.cause) from e
except OSError as e:
raise LocalFileReadError(src=src, cause=e) from e
finally:
if fd is not None:
os.close(fd)
await self._apply_metadata(session, dest)
return [MaterializedFile(path=dest, sha256=checksum)]

Expand Down Expand Up @@ -349,10 +465,14 @@ async def _copy_local_dir_file(
)
with os.fdopen(fd, "rb") as f:
fd = None
checksum = _sha256_handle(f)
f.seek(0)
await session.mkdir(child_dest.parent, parents=True, user=user)
await session.write(child_dest, f, user=user)
checksum = await _write_hashed_local_artifact(
session=session,
dest=child_dest,
src=src,
src_handle=f,
user=user,
)
except OSError as e:
raise LocalFileReadError(src=src, cause=e) from e
finally:
Expand Down
Loading