From 2a5b79544473eb79b9484bd9ee4eaa422ebd3daa Mon Sep 17 00:00:00 2001 From: Sean Pryor Date: Wed, 24 Jun 2026 09:23:11 -0400 Subject: [PATCH 1/2] feat(cache): add unified CacheManager subsystem with remote PEP 503 support Introduces a new centralized cache management layer that replaces the ad-hoc per-directory lookups with a structured, hierarchical system. Key components: - `CacheManager` with collection-based lookup across local and remote backends - `RemotePEP503Backend` with lazy per-package fetching and session-scoped index - `StoreRouter` for directing artifacts to the correct collection - Short-circuit optimization in `_phase_prepare_source` that skips source download, build env, and build dep resolution on cache hits - `update_wheel_mirror` called after remote cache downloads so wheels are indexed for subsequent build dependency resolution via the internal server - CLI commands (`fromager cache list/stats/verify/invalidate/gc`) - `--use-cache-manager` and `--cache-wheel-server-url` bootstrap options Co-Authored-By: Claude Signed-off-by: Sean Pryor Co-authored-by: Cursor --- pyproject.toml | 1 + src/fromager/bootstrapper.py | 56 +- src/fromager/cache.py | 752 ++++++++++++++++ src/fromager/commands/bootstrap.py | 17 + src/fromager/commands/cache_cmd.py | 392 +++++++++ src/fromager/context.py | 13 +- src/fromager/requirements_file.py | 1 + tests/test_bootstrapper_iterative.py | 484 ++++++++++ tests/test_cache.py | 1222 ++++++++++++++++++++++++++ 9 files changed, 2932 insertions(+), 6 deletions(-) create mode 100644 src/fromager/cache.py create mode 100644 src/fromager/commands/cache_cmd.py create mode 100644 tests/test_cache.py diff --git a/pyproject.toml b/pyproject.toml index adf93290e..f76f71af1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ canonicalize = "fromager.commands.canonicalize:canonicalize" download-sequence = "fromager.commands.download_sequence:download_sequence" wheel-server = "fromager.commands.server:wheel_server" 
lint-requirements = "fromager.commands.lint_requirements:lint_requirements" +cache = "fromager.commands.cache_cmd:cache" [tool.coverage.run] branch = true diff --git a/src/fromager/bootstrapper.py b/src/fromager/bootstrapper.py index 3389a8626..88bfc9494 100644 --- a/src/fromager/bootstrapper.py +++ b/src/fromager/bootstrapper.py @@ -655,15 +655,44 @@ def _find_cached_wheel( ) -> tuple[pathlib.Path | None, pathlib.Path | None]: """Look for cached wheel in 3 locations. - Checks for cached wheels in order: - 1. wheels_build directory (previously built) - 2. wheels_downloads directory (previously downloaded) - 3. Cache server (remote cache) + When a CacheManager is configured on the context, delegates to it + for a unified hierarchical lookup across collections. Otherwise + falls back to the legacy per-directory search. Returns: Tuple of (cached_wheel_filename, unpacked_cached_wheel). Both None if no cache hit. """ + if self.ctx.cache is not None: + return self._find_cached_wheel_via_manager(req, resolved_version) + return self._find_cached_wheel_legacy(req, resolved_version) + + def _find_cached_wheel_via_manager( + self, + req: Requirement, + resolved_version: Version, + ) -> tuple[pathlib.Path | None, pathlib.Path | None]: + """Cache lookup using the CacheManager.""" + assert self.ctx.cache is not None + pbi = self.ctx.package_build_info(req) + build_tag = pbi.build_tag(resolved_version) + + result = self.ctx.cache.lookup_wheel(req, resolved_version, build_tag) + if not result.hit: + return None, None + + assert result.path is not None + metadata_dir = self._unpack_metadata_from_wheel( + req, resolved_version, result.path + ) + return result.path, metadata_dir + + def _find_cached_wheel_legacy( + self, + req: Requirement, + resolved_version: Version, + ) -> tuple[pathlib.Path | None, pathlib.Path | None]: + """Legacy cache lookup: check build dir, downloads dir, remote cache.""" # Check if we have previously built a wheel and still have it on the # local 
filesystem. cached_wheel, unpacked = self._look_for_existing_wheel( @@ -1493,12 +1522,29 @@ def _phase_prepare_source(self, item: WorkItem) -> list[WorkItem]: item.phase = BootstrapPhase.PROCESS_INSTALL_DEPS return [item] - # Source build path + # Source build path: try cache first cached_wheel, unpacked = self._find_cached_wheel( item.req, item.resolved_version ) item.cached_wheel_filename = cached_wheel + # Short-circuit: when CacheManager provides a hit, skip directly to + # PROCESS_INSTALL_DEPS -- no source download, no build env, no build + # deps resolution needed. Install deps are extracted from the wheel. + if cached_wheel and self.ctx.cache is not None: + server.update_wheel_mirror(self.ctx) + unpack_dir = self._create_unpack_dir(item.req, item.resolved_version) + item.build_result = SourceBuildResult( + wheel_filename=cached_wheel, + sdist_filename=None, + unpack_dir=unpack_dir, + sdist_root_dir=None, + build_env=None, + source_type=SourceType.CACHED, + ) + item.phase = BootstrapPhase.PROCESS_INSTALL_DEPS + return [item] + if not unpacked: logger.debug("no cached wheel, downloading sources") source_filename = self._download_source( diff --git a/src/fromager/cache.py b/src/fromager/cache.py new file mode 100644 index 000000000..dae284007 --- /dev/null +++ b/src/fromager/cache.py @@ -0,0 +1,752 @@ +"""Unified cache subsystem for Fromager artifact management. + +Provides a layered cache with collection-based organization, supporting +hierarchical lookup across local directories and remote PEP 503 repositories. + +Collections represent logically grouped artifacts (e.g., "default", "cuda", +"rocm"). Each collection has one or more backends (local filesystem, remote +index). Lookups traverse collections in priority order, and store routing +determines which collection receives newly built artifacts. 
+""" + +from __future__ import annotations + +import dataclasses +import logging +import pathlib +import re +import shutil +import time +import typing +from urllib.parse import urlparse + +from packaging.requirements import Requirement +from packaging.utils import ( + BuildTag, + NormalizedName, + canonicalize_name, + parse_wheel_filename, +) +from packaging.version import Version + +from .request_session import session + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Cache Keys +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class WheelCacheKey: + """Identifies a cached wheel artifact. + + The key is intentionally simple -- collection routing is handled + externally by the CacheManager, not embedded in the key. + """ + + package: NormalizedName + version: Version + build_tag: BuildTag # (int, str) from changelog; () if untagged + + def __str__(self) -> str: + tag_str = f"-{self.build_tag[0]}{self.build_tag[1]}" if self.build_tag else "" + return f"{self.package}=={self.version}{tag_str}" + + +@dataclasses.dataclass(frozen=True) +class SdistCacheKey: + """Identifies a cached sdist artifact.""" + + package: NormalizedName + version: Version + + def __str__(self) -> str: + return f"{self.package}=={self.version}" + + +# --------------------------------------------------------------------------- +# Artifact Metadata +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class ArtifactInfo: + """Lightweight metadata for a cached artifact. + + Produced by scanning backends. For local backends, ``url_or_path`` is + an absolute filesystem path. For remote backends, it is a download URL. 
+ """ + + filename: str + url_or_path: str + size_bytes: int | None = None + sha256: str | None = None + + +# --------------------------------------------------------------------------- +# Cache Result +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class CacheResult: + """Result of a cache lookup operation.""" + + hit: bool + path: pathlib.Path | None = None + collection: str = "" + backend_name: str = "" + build_tag: BuildTag = () + was_downloaded: bool = False + + @property + def miss(self) -> bool: + return not self.hit + + +# --------------------------------------------------------------------------- +# Observability +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class CacheEvent: + """A single cache interaction event.""" + + timestamp: float + action: typing.Literal["hit", "miss", "store"] + artifact_type: typing.Literal["wheel", "sdist"] + package: str + version: str + collection: str + backend: str + duration_ms: float | None = None + + +@dataclasses.dataclass +class CacheStats: + """Accumulates cache events for a single run.""" + + events: list[CacheEvent] = dataclasses.field(default_factory=list) + + def record_hit( + self, + req: Requirement, + version: Version, + collection: str, + backend: str, + artifact_type: typing.Literal["wheel", "sdist"] = "wheel", + duration_ms: float | None = None, + ) -> None: + self.events.append( + CacheEvent( + timestamp=time.monotonic(), + action="hit", + artifact_type=artifact_type, + package=str(req.name), + version=str(version), + collection=collection, + backend=backend, + duration_ms=duration_ms, + ) + ) + + def record_miss( + self, + req: Requirement, + version: Version, + reason: str, + artifact_type: typing.Literal["wheel", "sdist"] = "wheel", + ) -> None: + self.events.append( + CacheEvent( + timestamp=time.monotonic(), + action="miss", + artifact_type=artifact_type, + package=str(req.name), + 
version=str(version), + collection="", + backend=reason, + ) + ) + + def record_store( + self, + req: Requirement, + version: Version, + collection: str, + artifact_type: typing.Literal["wheel", "sdist"] = "wheel", + ) -> None: + self.events.append( + CacheEvent( + timestamp=time.monotonic(), + action="store", + artifact_type=artifact_type, + package=str(req.name), + version=str(version), + collection=collection, + backend="local", + ) + ) + + @property + def hits(self) -> int: + return sum(1 for e in self.events if e.action == "hit") + + @property + def misses(self) -> int: + return sum(1 for e in self.events if e.action == "miss") + + @property + def stores(self) -> int: + return sum(1 for e in self.events if e.action == "store") + + @property + def hit_rate(self) -> float: + total = self.hits + self.misses + if total == 0: + return 0.0 + return self.hits / total + + def summary(self) -> dict[str, typing.Any]: + """Return a structured summary suitable for JSON serialization.""" + hits_by_collection: dict[str, int] = {} + hits_by_backend: dict[str, int] = {} + for e in self.events: + if e.action == "hit": + hits_by_collection[e.collection] = ( + hits_by_collection.get(e.collection, 0) + 1 + ) + hits_by_backend[e.backend] = hits_by_backend.get(e.backend, 0) + 1 + return { + "hits": { + "total": self.hits, + "by_collection": hits_by_collection, + "by_backend": hits_by_backend, + }, + "misses": self.misses, + "stores": self.stores, + "hit_rate": round(self.hit_rate, 4), + } + + +# --------------------------------------------------------------------------- +# Cache Backend Protocol +# --------------------------------------------------------------------------- + + +class CacheBackend(typing.Protocol): + """Protocol for a single storage location that can find and store artifacts.""" + + @property + def name(self) -> str: + """Human-readable identifier (e.g., 'local:default', 'remote:https://...').""" + ... 
+ + @property + def writable(self) -> bool: + """Whether this backend supports store operations.""" + ... + + def scan(self) -> dict[WheelCacheKey, ArtifactInfo]: + """Bulk index at startup. Local backends return full inventory; + remote backends fetch the top-level package list only and return empty. + """ + ... + + def lookup(self, key: WheelCacheKey) -> ArtifactInfo | None: + """Find a specific artifact by key. + + For local backends, checks the in-memory index. + For remote backends, lazily fetches the project page on first access. + """ + ... + + def fetch( + self, key: WheelCacheKey, info: ArtifactInfo, dest: pathlib.Path + ) -> pathlib.Path: + """Retrieve artifact to a local path. + + For local backends, returns the existing path (no-op). + For remote backends, downloads the file to ``dest``. + """ + ... + + def store(self, key: WheelCacheKey, artifact: pathlib.Path) -> ArtifactInfo: + """Store a newly built artifact. Only valid if ``writable`` is True.""" + ... + + +# --------------------------------------------------------------------------- +# Local Directory Backend +# --------------------------------------------------------------------------- + + +class LocalDirectoryBackend: + """Cache backend backed by a local filesystem directory. + + Scans at startup to populate an in-memory index from existing wheel files. + New stores are reflected immediately in the index. 
+ """ + + def __init__( + self, + directory: pathlib.Path, + backend_name: str = "local", + ) -> None: + self._directory = directory + self._backend_name = backend_name + self._index: dict[WheelCacheKey, ArtifactInfo] = {} + + @property + def name(self) -> str: + return self._backend_name + + @property + def writable(self) -> bool: + return True + + @property + def directory(self) -> pathlib.Path: + return self._directory + + def scan(self) -> dict[WheelCacheKey, ArtifactInfo]: + """Scan the directory for wheel files and populate the index.""" + self._index.clear() + if not self._directory.exists(): + return self._index + + for wheel_file in self._directory.glob("*.whl"): + try: + name, version, build_tag, _ = parse_wheel_filename(wheel_file.name) + key = WheelCacheKey( + package=name, + version=version, + build_tag=build_tag, + ) + info = ArtifactInfo( + filename=wheel_file.name, + url_or_path=str(wheel_file.resolve()), + size_bytes=wheel_file.stat().st_size, + ) + self._index[key] = info + except Exception: + logger.debug("skipping unparseable wheel file: %s", wheel_file.name) + logger.debug("scanned %d wheels in %s", len(self._index), self._directory) + return dict(self._index) + + def lookup(self, key: WheelCacheKey) -> ArtifactInfo | None: + """Look up artifact in the in-memory index.""" + info = self._index.get(key) + if info is not None: + file_path = pathlib.Path(info.url_or_path) + if file_path.exists(): + return info + # File was removed since scan -- evict from index + del self._index[key] + return None + + def fetch( + self, key: WheelCacheKey, info: ArtifactInfo, dest: pathlib.Path + ) -> pathlib.Path: + """Return the existing local path (no-op for local backends).""" + return pathlib.Path(info.url_or_path) + + def store(self, key: WheelCacheKey, artifact: pathlib.Path) -> ArtifactInfo: + """Register an artifact in this backend's directory. + + If the artifact is not already in the directory, it is moved there. + Updates the in-memory index. 
+ """ + dest = self._directory / artifact.name + if not dest.exists(): + self._directory.mkdir(parents=True, exist_ok=True) + shutil.move(str(artifact), str(dest)) + + info = ArtifactInfo( + filename=dest.name, + url_or_path=str(dest.resolve()), + size_bytes=dest.stat().st_size, + ) + self._index[key] = info + return info + + +# --------------------------------------------------------------------------- +# Remote PEP 503 Backend +# --------------------------------------------------------------------------- + + +class RemotePEP503Backend: + """Cache backend backed by a remote PEP 503 (Simple Repository API) server. + + At startup, fetches the top-level package list. Individual project pages + are fetched lazily on first lookup per package and memoized for the run. + """ + + def __init__( + self, + server_url: str, + download_dir: pathlib.Path, + backend_name: str | None = None, + ) -> None: + self._server_url = server_url.rstrip("/") + self._download_dir = download_dir + self._backend_name = backend_name or f"remote:{self._server_url}" + self._available_packages: set[NormalizedName] | None = None + self._project_cache: dict[NormalizedName, list[ArtifactInfo]] = {} + + @property + def name(self) -> str: + return self._backend_name + + @property + def writable(self) -> bool: + return False + + def scan(self) -> dict[WheelCacheKey, ArtifactInfo]: + """Fetch top-level index to learn which packages exist.""" + self._available_packages = self._fetch_package_list() + logger.debug( + "remote %s has %d packages available", + self._server_url, + len(self._available_packages) if self._available_packages else 0, + ) + return {} + + def lookup(self, key: WheelCacheKey) -> ArtifactInfo | None: + """Lazy per-package lookup with short-circuit for unknown packages.""" + if ( + self._available_packages is not None + and key.package not in self._available_packages + ): + return None + + if key.package not in self._project_cache: + self._project_cache[key.package] = 
self._fetch_project_page(key.package) + + for info in self._project_cache[key.package]: + try: + name, version, build_tag, _ = parse_wheel_filename(info.filename) + except Exception: + continue + candidate_key = WheelCacheKey( + package=name, version=version, build_tag=build_tag + ) + if candidate_key == key: + return info + + return None + + def fetch( + self, key: WheelCacheKey, info: ArtifactInfo, dest: pathlib.Path + ) -> pathlib.Path: + """Download the wheel from the remote server.""" + dest.mkdir(parents=True, exist_ok=True) + target = dest / info.filename + if target.exists(): + return target + + url = info.url_or_path + logger.info( + "downloading cached wheel %s from %s", info.filename, self._server_url + ) + resp = session.get(url, stream=True) + resp.raise_for_status() + with open(target, "wb") as f: + for chunk in resp.iter_content(chunk_size=1024 * 1024): + f.write(chunk) + return target + + def store(self, key: WheelCacheKey, artifact: pathlib.Path) -> ArtifactInfo: + """Not supported for remote backends.""" + raise NotImplementedError("Remote backends are read-only") + + def _fetch_package_list(self) -> set[NormalizedName]: + """Fetch the top-level /simple/ index and extract package names.""" + url = f"{self._server_url}/" + try: + resp = session.get(url) + resp.raise_for_status() + except Exception as err: + logger.warning("failed to fetch remote index %s: %s", url, err) + return set() + + return self._parse_index_page(resp.text) + + def _fetch_project_page(self, package: NormalizedName) -> list[ArtifactInfo]: + """Fetch a project's page and extract wheel artifact info.""" + url = f"{self._server_url}/{package}/" + try: + resp = session.get(url) + resp.raise_for_status() + except Exception as err: + logger.debug("failed to fetch project page %s: %s", url, err) + return [] + + return self._parse_project_page(resp.text, url) + + @staticmethod + def _parse_index_page(html: str) -> set[NormalizedName]: + """Extract package names from a PEP 503 index 
page.""" + names: set[NormalizedName] = set() + for match in re.finditer(r'<a\s+href="[^"]*">([^<]+)</a>', html): + name = match.group(1).strip().rstrip("/") + if name: + names.add(canonicalize_name(name)) + return names + + @staticmethod + def _parse_project_page(html: str, base_url: str) -> list[ArtifactInfo]: + """Extract wheel artifact info from a PEP 503 project page.""" + artifacts: list[ArtifactInfo] = [] + for match in re.finditer(r'<a\s+href="([^"]+)"[^>]*>([^<]+)</a>', html): + href = match.group(1) + filename = match.group(2).strip() + if not filename.endswith(".whl"): + continue + + # Resolve relative URLs + if href.startswith("http://") or href.startswith("https://"): + url = href + elif href.startswith("/"): + parsed = urlparse(base_url) + url = f"{parsed.scheme}://{parsed.netloc}{href}" + else: + url = base_url.rstrip("/") + "/" + href + + # Strip hash fragment for the URL but extract sha256 if present + sha256 = None + if "#" in url: + url_part, fragment = url.rsplit("#", 1) + if fragment.startswith("sha256="): + sha256 = fragment[7:] + url = url_part + + artifacts.append( + ArtifactInfo( + filename=filename, + url_or_path=url, + sha256=sha256, + ) + ) + return artifacts + + +# --------------------------------------------------------------------------- +# Collection and Store Router +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class CacheCollection: + """A named group of artifacts with one or more backends.""" + + name: str + backends: list[CacheBackend] + store_backend: LocalDirectoryBackend + + def scan_all(self) -> None: + """Scan all backends in this collection.""" + for backend in self.backends: + backend.scan() + + +class StoreRouter: + """Determines which collection receives a newly built artifact. + + Routing priority: + 1. Explicit per-package override (from overrides.yaml) + 2. Listed in the variant's accelerated requirements file + 3. 
Default collection + """ + + def __init__( + self, + overrides: dict[NormalizedName, str], + accelerated_packages: set[NormalizedName], + active_variant: str, + default_collection: str = "default", + ) -> None: + self._overrides = overrides + self._accelerated_packages = accelerated_packages + self._active_variant = active_variant + self._default_collection = default_collection + + def route(self, req: Requirement) -> str: + """Return the collection name where this package should be stored.""" + name = canonicalize_name(req.name) + + if name in self._overrides: + return self._overrides[name] + + if name in self._accelerated_packages: + return self._active_variant + + return self._default_collection + + +# --------------------------------------------------------------------------- +# Cache Manager +# --------------------------------------------------------------------------- + + +class CacheManager: + """Unified entry point for all cache operations. + + Owns collections, handles hierarchical lookup, routes stores, + and tracks cache events for observability. + """ + + def __init__( + self, + collections: dict[str, CacheCollection], + search_order: list[str], + store_routing: StoreRouter, + force: bool = False, + ) -> None: + self._collections = collections + self._search_order = search_order + self._store_routing = store_routing + self._force = force + self._stats = CacheStats() + + def initialize(self) -> None: + """Scan all backends at build start. + + Local backends populate their in-memory index from disk. + Remote backends fetch the top-level package list. + """ + for name in self._search_order: + if name not in self._collections: + logger.warning("collection %r in search order but not configured", name) + continue + self._collections[name].scan_all() + + def lookup_wheel( + self, + req: Requirement, + version: Version, + build_tag: BuildTag, + ) -> CacheResult: + """Search collections in priority order for a matching wheel. + + Returns the first hit found. 
On a remote hit, the wheel is + downloaded to the collection's local store backend. + """ + if self._force: + self._stats.record_miss(req, version, "forced") + return CacheResult(hit=False) + + key = WheelCacheKey( + package=canonicalize_name(req.name), + version=version, + build_tag=build_tag, + ) + + for collection_name in self._search_order: + collection = self._collections.get(collection_name) + if collection is None: + continue + + for backend in collection.backends: + t0 = time.monotonic() + info = backend.lookup(key) + if info is None: + continue + + # Hit -- fetch the artifact to a local path + local_path = backend.fetch( + key, info, collection.store_backend.directory + ) + duration_ms = (time.monotonic() - t0) * 1000 + was_downloaded = not backend.writable + + self._stats.record_hit( + req, + version, + collection_name, + backend.name, + duration_ms=duration_ms, + ) + logger.info( + "cache hit for %s==%s in %s/%s", + req.name, + version, + collection_name, + backend.name, + ) + return CacheResult( + hit=True, + path=local_path, + collection=collection_name, + backend_name=backend.name, + build_tag=build_tag, + was_downloaded=was_downloaded, + ) + + self._stats.record_miss(req, version, "not_found") + logger.debug("cache miss for %s==%s", req.name, version) + return CacheResult(hit=False) + + def lookup_sdist( + self, + req: Requirement, + version: Version, + ) -> CacheResult: + """Search for a cached sdist across collections. + + Uses the same search order as wheel lookups. Sdist keys do not + include build tags. 
+ """ + if self._force: + self._stats.record_miss(req, version, "forced", artifact_type="sdist") + return CacheResult(hit=False) + + # Sdist lookup reuses wheel key matching against .tar.gz/.zip files + # For now, delegate to a simple filename-based check in local backends + # TODO: extend backends with sdist-specific scan/lookup + self._stats.record_miss(req, version, "not_implemented", artifact_type="sdist") + return CacheResult(hit=False) + + def store_wheel( + self, + req: Requirement, + version: Version, + build_tag: BuildTag, + wheel_path: pathlib.Path, + ) -> pathlib.Path: + """Route and store a newly built wheel in the appropriate collection.""" + collection_name = self._store_routing.route(req) + collection = self._collections.get(collection_name) + if collection is None: + raise ValueError( + f"store routing returned unknown collection {collection_name!r} " + f"for {req.name}" + ) + + key = WheelCacheKey( + package=canonicalize_name(req.name), + version=version, + build_tag=build_tag, + ) + + info = collection.store_backend.store(key, wheel_path) + self._stats.record_store(req, version, collection_name) + logger.info("stored %s in collection %r", info.filename, collection_name) + return pathlib.Path(info.url_or_path) + + @property + def stats(self) -> CacheStats: + return self._stats + + @property + def collections(self) -> dict[str, CacheCollection]: + return self._collections + + @property + def search_order(self) -> list[str]: + return list(self._search_order) diff --git a/src/fromager/commands/bootstrap.py b/src/fromager/commands/bootstrap.py index 9baeb36ce..6fc45d86c 100644 --- a/src/fromager/commands/bootstrap.py +++ b/src/fromager/commands/bootstrap.py @@ -22,6 +22,7 @@ ) from ..log import requirement_ctxvar from .build import build_parallel +from .cache_cmd import _build_cache_manager from .graph import find_why, show_explain_duplicates # Map child_name==child_version to list of (parent_name==parent_version, Requirement) @@ -117,6 +118,12 @@ def 
_get_requirements_from_args( default=None, help="Reject package versions published more than this many days ago.", ) +@click.option( + "--use-cache-manager/--no-cache-manager", + "use_cache_manager", + default=False, + help="Enable the new unified CacheManager for hierarchical cache lookups.", +) @click.argument("toplevel", nargs=-1) @click.pass_obj def bootstrap( @@ -129,6 +136,7 @@ def bootstrap( test_mode: bool, multiple_versions: bool, max_release_age: int | None, + use_cache_manager: bool, toplevel: list[str], ) -> None: """Compute and build the dependencies of a set of requirements recursively @@ -187,6 +195,15 @@ def bootstrap( if pre_built: logger.info("treating %s as pre-built wheels", sorted(pre_built)) + if use_cache_manager: + cache_mgr = _build_cache_manager(wkctx, cache_url=cache_wheel_server_url) + wkctx.cache = cache_mgr + logger.info( + "cache manager enabled with %d collection(s): %s", + len(cache_mgr.collections), + list(cache_mgr.collections.keys()), + ) + server.start_wheel_server(wkctx) with progress.progress_context(total=len(to_build * 2)) as progressbar: diff --git a/src/fromager/commands/cache_cmd.py b/src/fromager/commands/cache_cmd.py new file mode 100644 index 000000000..8c43141ac --- /dev/null +++ b/src/fromager/commands/cache_cmd.py @@ -0,0 +1,392 @@ +"""CLI commands for cache management and observability.""" + +import json +import logging +import pathlib + +import click +import rich +import rich.box +from packaging.utils import canonicalize_name +from rich.table import Table + +from fromager import context +from fromager.cache import ( + CacheBackend, + CacheCollection, + CacheManager, + LocalDirectoryBackend, + RemotePEP503Backend, + StoreRouter, +) + +logger = logging.getLogger(__name__) + + +def _build_cache_manager( + wkctx: context.WorkContext, + cache_url: str | None = None, +) -> CacheManager: + """Construct a CacheManager from the WorkContext configuration. + + If the context already has a cache configured, return it. 
+ Otherwise, build one from the standard filesystem layout. + + Args: + wkctx: The work context providing local paths and variant info. + cache_url: Optional URL to a remote PEP 503 cache server. + """ + if wkctx.cache is not None: + return wkctx.cache + + local_backend = LocalDirectoryBackend( + wkctx.wheels_downloads, backend_name="local:downloads" + ) + prebuilt_backend = LocalDirectoryBackend( + wkctx.wheels_prebuilt, backend_name="local:prebuilt" + ) + + backends: list[CacheBackend] = [local_backend, prebuilt_backend] + + if cache_url: + remote_backend = RemotePEP503Backend( + server_url=cache_url, + download_dir=wkctx.wheels_downloads, + backend_name=f"remote:{cache_url}", + ) + backends.append(remote_backend) + + collection = CacheCollection( + name="default", + backends=backends, + store_backend=local_backend, + ) + + router = StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant=wkctx.variant, + ) + + manager = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + manager.initialize() + return manager + + +@click.group() +def cache() -> None: + """Manage the fromager wheel cache.""" + pass + + +@cache.command(name="list") +@click.option( + "--collection", + default=None, + help="Only list artifacts from this collection.", +) +@click.option( + "--format", + "output_format", + type=click.Choice(["table", "json"], case_sensitive=False), + default="table", + help="Output format (default: table).", +) +@click.pass_obj +def cache_list( + wkctx: context.WorkContext, + collection: str | None, + output_format: str, +) -> None: + """List all cached wheel artifacts.""" + manager = _build_cache_manager(wkctx) + + entries = [] + for coll_name, coll in manager.collections.items(): + if collection and coll_name != collection: + continue + for backend in coll.backends: + if not hasattr(backend, "_index"): + continue + for key, info in backend._index.items(): + entries.append( + { + 
"collection": coll_name, + "backend": backend.name, + "package": str(key.package), + "version": str(key.version), + "build_tag": ( + f"{key.build_tag[0]}{key.build_tag[1]}" + if key.build_tag + else "" + ), + "filename": info.filename, + "size_bytes": info.size_bytes or 0, + } + ) + + entries.sort(key=lambda e: (e["collection"], e["package"], e["version"])) + + if output_format == "json": + click.echo(json.dumps(entries, indent=2)) + return + + if not entries: + click.echo("No cached wheels found.") + return + + table = Table(title="Cached Wheels", box=rich.box.SIMPLE) + table.add_column("Collection", no_wrap=True) + table.add_column("Package", no_wrap=True) + table.add_column("Version", no_wrap=True) + table.add_column("Build Tag", no_wrap=True) + table.add_column("Size", justify="right", no_wrap=True) + table.add_column("Backend", no_wrap=True) + + for entry in entries: + size = _format_size(entry["size_bytes"]) + table.add_row( + entry["collection"], + entry["package"], + entry["version"], + entry["build_tag"], + size, + entry["backend"], + ) + + console = rich.get_console() + console.print(table) + console.print(f"\nTotal: {len(entries)} wheel(s)") + + +@cache.command(name="stats") +@click.pass_obj +def cache_stats(wkctx: context.WorkContext) -> None: + """Show cache statistics (hit/miss rates from last run).""" + manager = _build_cache_manager(wkctx) + + table = Table(title="Cache Statistics", box=rich.box.SIMPLE) + table.add_column("Metric", no_wrap=True) + table.add_column("Value", justify="right", no_wrap=True) + + summary = manager.stats.summary() + + table.add_row("Total lookups", str(summary["hits"]["total"] + summary["misses"])) + table.add_row("Hits", str(summary["hits"]["total"])) + table.add_row("Misses", str(summary["misses"])) + table.add_row("Hit rate", f"{summary['hit_rate']:.1%}") + table.add_row("Stores", str(summary["stores"])) + + if summary["hits"]["by_collection"]: + table.add_section() + for coll, count in 
summary["hits"]["by_collection"].items(): + table.add_row(f" Hits from {coll}", str(count)) + + # Show per-collection inventory counts + table.add_section() + total_wheels = 0 + total_size = 0 + for coll_name, coll in manager.collections.items(): + coll_count = 0 + coll_size = 0 + for backend in coll.backends: + if hasattr(backend, "_index"): + coll_count += len(backend._index) + coll_size += sum( + (info.size_bytes or 0) for info in backend._index.values() + ) + table.add_row(f" {coll_name} wheels", str(coll_count)) + table.add_row(f" {coll_name} size", _format_size(coll_size)) + total_wheels += coll_count + total_size += coll_size + + table.add_section() + table.add_row("Total wheels on disk", str(total_wheels)) + table.add_row("Total size on disk", _format_size(total_size)) + + console = rich.get_console() + console.print(table) + + +@cache.command() +@click.option( + "--remove-missing", + is_flag=True, + default=False, + help="Remove index entries for files that no longer exist on disk.", +) +@click.pass_obj +def verify(wkctx: context.WorkContext, remove_missing: bool) -> None: + """Verify cache integrity: check that indexed files exist on disk.""" + manager = _build_cache_manager(wkctx) + + missing = [] + checked = 0 + + for coll_name, coll in manager.collections.items(): + for backend in coll.backends: + if not isinstance(backend, LocalDirectoryBackend): + continue + for key, info in list(backend._index.items()): + checked += 1 + file_path = pathlib.Path(info.url_or_path) + if not file_path.exists(): + missing.append( + { + "collection": coll_name, + "backend": backend.name, + "key": str(key), + "path": str(file_path), + } + ) + if remove_missing: + del backend._index[key] + + if not missing: + click.echo(f"All {checked} cached artifacts verified OK.") + return + + click.echo(f"Found {len(missing)} missing artifact(s) out of {checked} checked:") + for m in missing: + action = " [removed from index]" if remove_missing else "" + click.echo(f" 
{m['collection']}/{m['key']}: {m['path']}{action}") + + +@cache.command() +@click.argument("packages", nargs=-1) +@click.option( + "--all", + "invalidate_all", + is_flag=True, + default=False, + help="Invalidate the entire cache.", +) +@click.option( + "--collection", + default=None, + help="Only invalidate within this collection.", +) +@click.pass_obj +def invalidate( + wkctx: context.WorkContext, + packages: tuple[str, ...], + invalidate_all: bool, + collection: str | None, +) -> None: + """Invalidate (remove) cached wheels for specific packages. + + Pass package names as arguments, or use --all to clear everything. + """ + if not packages and not invalidate_all: + raise click.UsageError("Specify package names or use --all.") + + manager = _build_cache_manager(wkctx) + removed = 0 + + target_packages = {canonicalize_name(p) for p in packages} if packages else None + + for coll_name, coll in manager.collections.items(): + if collection and coll_name != collection: + continue + for backend in coll.backends: + if not isinstance(backend, LocalDirectoryBackend): + continue + keys_to_remove = [] + for key, info in backend._index.items(): + if target_packages and key.package not in target_packages: + continue + keys_to_remove.append((key, info)) + + for key, info in keys_to_remove: + file_path = pathlib.Path(info.url_or_path) + if file_path.exists(): + file_path.unlink() + logger.info("removed %s", file_path) + del backend._index[key] + removed += 1 + + click.echo(f"Invalidated {removed} cached artifact(s).") + + +@cache.command() +@click.option( + "--dry-run", + is_flag=True, + default=False, + help="Show what would be removed without actually deleting.", +) +@click.option( + "--keep-latest", + type=int, + default=1, + help="Keep this many build tags per package+version (default: 1).", +) +@click.pass_obj +def gc( + wkctx: context.WorkContext, + dry_run: bool, + keep_latest: int, +) -> None: + """Garbage-collect old builds, keeping only the latest build tags. 
+ + For each package+version, removes all but the --keep-latest most + recent builds (highest build tag number). + """ + manager = _build_cache_manager(wkctx) + removed = 0 + freed_bytes = 0 + + for _coll_name, coll in manager.collections.items(): + for backend in coll.backends: + if not isinstance(backend, LocalDirectoryBackend): + continue + + # Group by (package, version) + groups: dict[tuple, list] = {} + for key, info in backend._index.items(): + group_key = (key.package, key.version) + groups.setdefault(group_key, []).append((key, info)) + + for _group_key, entries in groups.items(): + if len(entries) <= keep_latest: + continue + + # Sort by build tag number descending + entries.sort( + key=lambda e: e[0].build_tag[0] if e[0].build_tag else 0, + reverse=True, + ) + to_remove = entries[keep_latest:] + + for key, info in to_remove: + file_path = pathlib.Path(info.url_or_path) + size = info.size_bytes or 0 + if dry_run: + click.echo( + f" would remove: {info.filename} ({_format_size(size)})" + ) + else: + if file_path.exists(): + file_path.unlink() + del backend._index[key] + logger.info("gc removed %s", file_path) + removed += 1 + freed_bytes += size + + verb = "Would remove" if dry_run else "Removed" + click.echo(f"{verb} {removed} old build(s), freeing {_format_size(freed_bytes)}.") + + +def _format_size(size_bytes: int) -> str: + """Format bytes as human-readable size.""" + if size_bytes == 0: + return "0 B" + for unit in ("B", "KB", "MB", "GB"): + if abs(size_bytes) < 1024: + return f"{size_bytes:.1f} {unit}" + size_bytes /= 1024 # type: ignore[assignment] + return f"{size_bytes:.1f} TB" diff --git a/src/fromager/context.py b/src/fromager/context.py index 58866aedb..45c45a570 100644 --- a/src/fromager/context.py +++ b/src/fromager/context.py @@ -22,7 +22,7 @@ ) if typing.TYPE_CHECKING: - from . import build_environment, candidate + from . 
import build_environment, cache, candidate logger = logging.getLogger(__name__) @@ -98,6 +98,8 @@ def __init__( self.cooldown: candidate.Cooldown | None = cooldown self._max_release_age: datetime.timedelta | None = max_release_age + self._cache: cache.CacheManager | None = None + @property def max_release_age(self) -> datetime.timedelta | None: return self._max_release_age @@ -108,6 +110,15 @@ def set_max_release_age(self, days: int) -> None: raise ValueError(f"max_release_age must be positive, got {days}") self._max_release_age = datetime.timedelta(days=days) + @property + def cache(self) -> cache.CacheManager | None: + """The cache manager, if configured.""" + return self._cache + + @cache.setter + def cache(self, value: cache.CacheManager) -> None: + self._cache = value + def enable_parallel_builds(self) -> None: self._parallel_builds = True diff --git a/src/fromager/requirements_file.py b/src/fromager/requirements_file.py index a588d20a3..02f9f4f0a 100644 --- a/src/fromager/requirements_file.py +++ b/src/fromager/requirements_file.py @@ -34,6 +34,7 @@ class SourceType(StrEnum): SDIST = "sdist" OVERRIDE = "override" GIT = "git" + CACHED = "cached" def parse_requirements_file( diff --git a/tests/test_bootstrapper_iterative.py b/tests/test_bootstrapper_iterative.py index 27692dc95..933055e17 100644 --- a/tests/test_bootstrapper_iterative.py +++ b/tests/test_bootstrapper_iterative.py @@ -1126,6 +1126,338 @@ def test_constraint_logged_when_present( assert "matches constraint" in caplog.text +class TestPhasePrepareSourceCacheManager: + """Tests for _phase_prepare_source using the CacheManager short-circuit path. + + When a CacheManager is configured on WorkContext and provides a cache hit, + PREPARE_SOURCE should skip directly to PROCESS_INSTALL_DEPS without + downloading source, creating a build environment, or resolving build deps. 
+ """ + + def test_cache_hit_short_circuits_to_process_install_deps( + self, tmp_context: WorkContext + ) -> None: + """Cache hit via CacheManager skips all build phases.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + # Set up a CacheManager with a wheel in the default collection + wheels_dir = tmp_context.wheels_downloads + wheel_file = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + wheel_file.write_bytes(b"fake wheel") + + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + item = _make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object( + bt, + "_find_cached_wheel_via_manager", + return_value=(wheel_file, tmp_context.work_dir / "testpkg-1.0"), + ), + patch.object( + bt, + "_create_unpack_dir", + return_value=tmp_context.work_dir / "testpkg-1.0", + ), + patch.object(bt, "_download_source") as mock_dl_src, + patch.object(bt, "_prepare_source") as mock_prep, + patch.object(bt, "_create_build_env") as mock_create_env, + patch("fromager.bootstrapper.server.update_wheel_mirror") as mock_mirror, + ): + result = bt._phase_prepare_source(item) + + # Short-circuit: jumped to PROCESS_INSTALL_DEPS + assert item.phase == BootstrapPhase.PROCESS_INSTALL_DEPS + assert item.build_result is not None + assert item.build_result.wheel_filename == wheel_file + assert item.build_result.source_type == SourceType.CACHED + assert item.build_result.build_env is None + assert 
item.build_result.sdist_filename is None + + # Nothing from the build path was called + mock_dl_src.assert_not_called() + mock_prep.assert_not_called() + mock_create_env.assert_not_called() + + # Wheel mirror was updated so cached wheel is indexed for build deps + mock_mirror.assert_called_once_with(tmp_context) + + # Only the continuation item, no build dep items + assert len(result) == 1 + assert result[0] is item + + def test_cache_hit_updates_wheel_mirror_for_build_dep_resolution( + self, tmp_context: WorkContext + ) -> None: + """Cache hit calls update_wheel_mirror so downloaded wheels are available + to subsequent build environments via the internal wheel server.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + wheel_file = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + wheel_file.write_bytes(b"fake wheel") + + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + item = _make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object( + bt, + "_find_cached_wheel_via_manager", + return_value=(wheel_file, tmp_context.work_dir / "testpkg-1.0"), + ), + patch.object( + bt, + "_create_unpack_dir", + return_value=tmp_context.work_dir / "testpkg-1.0", + ), + patch("fromager.bootstrapper.server.update_wheel_mirror") as mock_mirror, + ): + bt._phase_prepare_source(item) + + # update_wheel_mirror called before returning, ensuring the 
wheel is + # symlinked into the simple/ index for uv pip install to find it + mock_mirror.assert_called_once_with(tmp_context) + + def test_cache_miss_with_manager_falls_through_to_build( + self, tmp_context: WorkContext + ) -> None: + """Cache miss via CacheManager proceeds to normal build path.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + # Empty cache — no wheels + wheels_dir = tmp_context.wheels_downloads + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + item = _make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + sdist_root = tmp_context.work_dir / "testpkg-1.0" / "testpkg-1.0" + mock_env = Mock() + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object( + bt, "_download_source", return_value=tmp_context.work_dir / "src.tar.gz" + ) as mock_dl_src, + patch.object(bt, "_prepare_source", return_value=sdist_root) as mock_prep, + patch.object(bt, "_create_build_env", return_value=mock_env), + patch( + "fromager.dependencies.get_build_system_dependencies", + return_value=set(), + ), + patch.object(bt, "_create_unresolved_work_items", return_value=[]), + ): + bt._phase_prepare_source(item) + + # Normal build path — advances to PREPARE_BUILD + assert item.phase == BootstrapPhase.PREPARE_BUILD + assert item.build_env is mock_env + mock_dl_src.assert_called_once() + mock_prep.assert_called_once() + + def test_cache_hit_records_stats(self, tmp_context: WorkContext) -> None: + """Cache hit via CacheManager records 
a hit event in stats.""" + + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + wheel_file = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + wheel_file.write_bytes(b"fake wheel") + + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + item = _make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object( + bt, + "_find_cached_wheel_via_manager", + return_value=(wheel_file, tmp_context.work_dir / "testpkg-1.0"), + ), + patch.object( + bt, + "_create_unpack_dir", + return_value=tmp_context.work_dir / "testpkg-1.0", + ), + patch("fromager.bootstrapper.server.update_wheel_mirror"), + ): + bt._phase_prepare_source(item) + + # Stats should show the cache was consulted + assert item.phase == BootstrapPhase.PROCESS_INSTALL_DEPS + + def test_no_cache_manager_uses_legacy_path(self, tmp_context: WorkContext) -> None: + """Without CacheManager, legacy path is used (no short-circuit).""" + # Explicitly no cache manager + assert tmp_context.cache is None + + bt = bootstrapper.Bootstrapper(tmp_context) + item = _make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + unpacked = tmp_context.work_dir / "testpkg-1.0" + unpacked.mkdir(parents=True) + cached_wheel = tmp_context.work_dir / "testpkg-1.0-py3-none-any.whl" + mock_env = Mock() + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object( + bt, 
"_find_cached_wheel", return_value=(cached_wheel, unpacked) + ), + patch.object(bt, "_download_source") as mock_dl_src, + patch.object(bt, "_prepare_source") as mock_prep, + patch.object(bt, "_create_build_env", return_value=mock_env), + patch( + "fromager.dependencies.get_build_system_dependencies", + return_value=set(), + ), + patch.object(bt, "_create_unresolved_work_items", return_value=[]), + ): + bt._phase_prepare_source(item) + + # Legacy path: cached wheel found but no short-circuit — goes to PREPARE_BUILD + assert item.phase == BootstrapPhase.PREPARE_BUILD + assert item.cached_wheel_filename == cached_wheel + mock_dl_src.assert_not_called() + mock_prep.assert_not_called() + + def test_force_mode_skips_cache(self, tmp_context: WorkContext) -> None: + """--force flag causes CacheManager to return miss, triggering full build.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + # Put a wheel in the cache + wheel_file = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + wheel_file.write_bytes(b"fake wheel") + + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + force=True, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + item = _make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + sdist_root = tmp_context.work_dir / "testpkg-1.0" / "testpkg-1.0" + mock_env = Mock() + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object( + bt, "_download_source", return_value=tmp_context.work_dir / "src.tar.gz" + ) as mock_dl_src, + 
patch.object(bt, "_prepare_source", return_value=sdist_root), + patch.object(bt, "_create_build_env", return_value=mock_env), + patch( + "fromager.dependencies.get_build_system_dependencies", + return_value=set(), + ), + patch.object(bt, "_create_unresolved_work_items", return_value=[]), + ): + bt._phase_prepare_source(item) + + # Force mode: goes through full build path despite wheel existing + assert item.phase == BootstrapPhase.PREPARE_BUILD + mock_dl_src.assert_called_once() + + class TestPhasePrepareBuild: """Tests for _phase_prepare_build: dep installation and extraction.""" @@ -1672,3 +2004,155 @@ def test_build_order_called_with_correct_args( prebuilt=True, constraint=constraint, ) + + +class TestFindCachedWheelDispatch: + """Tests for _find_cached_wheel dispatch to manager vs legacy.""" + + def test_dispatches_to_manager_when_configured( + self, tmp_context: WorkContext + ) -> None: + """With CacheManager set, _find_cached_wheel calls via_manager.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + req = Requirement("testpkg") + version = Version("1.0") + + with ( + patch.object( + bt, "_find_cached_wheel_via_manager", return_value=(None, None) + ) as mock_mgr, + patch.object(bt, "_find_cached_wheel_legacy") as mock_legacy, + ): + bt._find_cached_wheel(req, version) + + mock_mgr.assert_called_once_with(req, version) + mock_legacy.assert_not_called() + + def 
test_dispatches_to_legacy_when_no_manager( + self, tmp_context: WorkContext + ) -> None: + """Without CacheManager, _find_cached_wheel calls legacy.""" + assert tmp_context.cache is None + + bt = bootstrapper.Bootstrapper(tmp_context) + req = Requirement("testpkg") + version = Version("1.0") + + with ( + patch.object(bt, "_find_cached_wheel_via_manager") as mock_mgr, + patch.object( + bt, "_find_cached_wheel_legacy", return_value=(None, None) + ) as mock_legacy, + ): + bt._find_cached_wheel(req, version) + + mock_legacy.assert_called_once_with(req, version) + mock_mgr.assert_not_called() + + def test_via_manager_returns_hit_path(self, tmp_context: WorkContext) -> None: + """_find_cached_wheel_via_manager returns path on cache hit.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + wheel_file = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + wheel_file.write_bytes(b"fake wheel content") + + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + req = Requirement("testpkg") + version = Version("1.0") + + # build_tag must be a BuildTag tuple matching what parse_wheel_filename produces + mock_pbi = Mock() + mock_pbi.build_tag.return_value = (1, "") + metadata_dir = tmp_context.work_dir / "testpkg-1.0.dist-info" + + with ( + patch.object(tmp_context, "package_build_info", return_value=mock_pbi), + patch.object(bt, "_unpack_metadata_from_wheel", return_value=metadata_dir), + ): + wheel_path, meta_dir = 
bt._find_cached_wheel_via_manager(req, version) + + assert wheel_path == wheel_file + assert meta_dir == metadata_dir + + def test_via_manager_returns_none_on_miss(self, tmp_context: WorkContext) -> None: + """_find_cached_wheel_via_manager returns (None, None) on cache miss.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + # Empty cache + wheels_dir = tmp_context.wheels_downloads + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + req = Requirement("testpkg") + version = Version("1.0") + + mock_pbi = Mock() + mock_pbi.build_tag.return_value = (1, "") + + with patch.object(tmp_context, "package_build_info", return_value=mock_pbi): + wheel_path, meta_dir = bt._find_cached_wheel_via_manager(req, version) + + assert wheel_path is None + assert meta_dir is None diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 000000000..7831ef97c --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,1222 @@ +"""Unit tests for the cache subsystem.""" + +import pathlib + +import pytest +import requests_mock +from packaging.requirements import Requirement +from packaging.utils import canonicalize_name +from packaging.version import Version + +from fromager.cache import ( + CacheCollection, + CacheManager, + CacheResult, + CacheStats, + LocalDirectoryBackend, + RemotePEP503Backend, + SdistCacheKey, + StoreRouter, + WheelCacheKey, +) + +# --------------------------------------------------------------------------- +# WheelCacheKey tests +# 
--------------------------------------------------------------------------- + + +class TestWheelCacheKey: + def test_creation(self) -> None: + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert key.package == "numpy" + assert key.version == Version("1.26.4") + assert key.build_tag == (2, "") + + def test_equality(self) -> None: + key1 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + key2 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert key1 == key2 + + def test_inequality_version(self) -> None: + key1 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + key2 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.5"), + build_tag=(2, ""), + ) + assert key1 != key2 + + def test_inequality_build_tag(self) -> None: + key1 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + key2 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(3, ""), + ) + assert key1 != key2 + + def test_hashable(self) -> None: + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + d: dict[WheelCacheKey, str] = {key: "found"} + assert d[key] == "found" + + def test_str_with_build_tag(self) -> None: + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert str(key) == "numpy==1.26.4-2" + + def test_str_without_build_tag(self) -> None: + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(), + ) + assert str(key) == "numpy==1.26.4" + + def test_name_normalization(self) -> None: + key1 = WheelCacheKey( + package=canonicalize_name("Flask-RESTful"), 
+ version=Version("0.3.10"), + build_tag=(), + ) + key2 = WheelCacheKey( + package=canonicalize_name("flask_restful"), + version=Version("0.3.10"), + build_tag=(), + ) + assert key1 == key2 + + +class TestSdistCacheKey: + def test_creation(self) -> None: + key = SdistCacheKey( + package=canonicalize_name("requests"), + version=Version("2.31.0"), + ) + assert key.package == "requests" + assert str(key) == "requests==2.31.0" + + +# --------------------------------------------------------------------------- +# LocalDirectoryBackend tests +# --------------------------------------------------------------------------- + + +def _create_wheel_file(directory: pathlib.Path, filename: str) -> pathlib.Path: + """Create a fake wheel file for testing.""" + directory.mkdir(parents=True, exist_ok=True) + wheel_path = directory / filename + wheel_path.write_bytes(b"fake wheel content") + return wheel_path + + +class TestLocalDirectoryBackend: + def test_name(self, tmp_path: pathlib.Path) -> None: + backend = LocalDirectoryBackend( + tmp_path / "wheels", backend_name="local:default" + ) + assert backend.name == "local:default" + + def test_writable(self, tmp_path: pathlib.Path) -> None: + backend = LocalDirectoryBackend(tmp_path / "wheels") + assert backend.writable is True + + def test_scan_empty_directory(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + backend = LocalDirectoryBackend(wheels_dir) + result = backend.scan() + assert result == {} + + def test_scan_nonexistent_directory(self, tmp_path: pathlib.Path) -> None: + backend = LocalDirectoryBackend(tmp_path / "nonexistent") + result = backend.scan() + assert result == {} + + def test_scan_finds_wheels(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + _create_wheel_file(wheels_dir, "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl") + backend = LocalDirectoryBackend(wheels_dir) + result = backend.scan() + + expected_key = WheelCacheKey( + 
package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert expected_key in result + assert ( + result[expected_key].filename + == "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + ) + + def test_scan_skips_non_wheel_files(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + (wheels_dir / "readme.txt").write_text("not a wheel") + (wheels_dir / "numpy-1.26.4.tar.gz").write_bytes(b"sdist") + backend = LocalDirectoryBackend(wheels_dir) + result = backend.scan() + assert result == {} + + def test_scan_skips_unparseable_wheels(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + (wheels_dir / "totally-invalid-name.whl").write_bytes(b"bad") + backend = LocalDirectoryBackend(wheels_dir) + result = backend.scan() + assert result == {} + + def test_lookup_hit(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + _create_wheel_file(wheels_dir, "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl") + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + result = backend.lookup(key) + assert result is not None + assert result.filename == "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + + def test_lookup_miss(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert backend.lookup(key) is None + + def test_lookup_evicts_deleted_file(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + whl = _create_wheel_file( + wheels_dir, "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + ) + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + key = WheelCacheKey( + 
package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + # File exists initially + assert backend.lookup(key) is not None + # Delete it + whl.unlink() + # Now lookup should return None and evict from index + assert backend.lookup(key) is None + + def test_fetch_returns_local_path(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + whl = _create_wheel_file( + wheels_dir, "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + ) + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + info = backend.lookup(key) + assert info is not None + result = backend.fetch(key, info, tmp_path / "dest") + assert result == whl.resolve() + + def test_store_moves_file(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + # Create a wheel in a "build" directory + build_dir = tmp_path / "build" + whl = _create_wheel_file(build_dir, "requests-2.31.0-1-py3-none-any.whl") + + key = WheelCacheKey( + package=canonicalize_name("requests"), + version=Version("2.31.0"), + build_tag=(1, ""), + ) + info = backend.store(key, whl) + + assert info.filename == "requests-2.31.0-1-py3-none-any.whl" + assert (wheels_dir / "requests-2.31.0-1-py3-none-any.whl").exists() + assert not whl.exists() # Moved, not copied + + def test_store_updates_index(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + build_dir = tmp_path / "build" + whl = _create_wheel_file(build_dir, "requests-2.31.0-1-py3-none-any.whl") + + key = WheelCacheKey( + package=canonicalize_name("requests"), + version=Version("2.31.0"), + build_tag=(1, ""), + ) + backend.store(key, whl) + + # Should be findable via lookup now + result = backend.lookup(key) 
+ assert result is not None + assert result.filename == "requests-2.31.0-1-py3-none-any.whl" + + def test_store_no_move_if_already_exists(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + existing = _create_wheel_file(wheels_dir, "requests-2.31.0-1-py3-none-any.whl") + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + key = WheelCacheKey( + package=canonicalize_name("requests"), + version=Version("2.31.0"), + build_tag=(1, ""), + ) + # Store with same filename that already exists + info = backend.store(key, existing) + assert info.filename == "requests-2.31.0-1-py3-none-any.whl" + assert existing.exists() + + +# --------------------------------------------------------------------------- +# RemotePEP503Backend tests +# --------------------------------------------------------------------------- + + +class TestRemotePEP503Backend: + def test_name_default(self) -> None: + backend = RemotePEP503Backend( + server_url="https://cache.example.com/simple", + download_dir=pathlib.Path("/tmp/downloads"), + ) + assert backend.name == "remote:https://cache.example.com/simple" + + def test_not_writable(self) -> None: + backend = RemotePEP503Backend( + server_url="https://cache.example.com/simple", + download_dir=pathlib.Path("/tmp/downloads"), + ) + assert backend.writable is False + + def test_store_raises(self) -> None: + backend = RemotePEP503Backend( + server_url="https://cache.example.com/simple", + download_dir=pathlib.Path("/tmp/downloads"), + ) + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + with pytest.raises(NotImplementedError): + backend.store(key, pathlib.Path("/fake.whl")) + + def test_parse_index_page(self) -> None: + html = """ + + + numpy + requests + Flask + + """ + result = RemotePEP503Backend._parse_index_page(html) + assert canonicalize_name("numpy") in result + assert canonicalize_name("requests") in result + assert canonicalize_name("flask") in 
result + + def test_parse_project_page(self) -> None: + html = """ + + + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + numpy-1.26.4.tar.gz + numpy-1.25.0-cp312-cp312-linux_x86_64.whl + + """ + result = RemotePEP503Backend._parse_project_page( + html, "https://cache.test/simple/numpy/" + ) + assert len(result) == 2 # Only .whl files + assert result[0].filename == "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + assert result[0].sha256 == "abc123" + assert result[1].filename == "numpy-1.25.0-cp312-cp312-linux_x86_64.whl" + assert "cache.test" in result[1].url_or_path + + def test_parse_project_page_absolute_url(self) -> None: + html = 'numpy-1.26.4-cp312-cp312-linux_x86_64.whl' + result = RemotePEP503Backend._parse_project_page( + html, "https://cache.test/simple/numpy/" + ) + assert len(result) == 1 + assert ( + result[0].url_or_path + == "https://other.test/numpy-1.26.4-cp312-cp312-linux_x86_64.whl" + ) + + def test_lookup_short_circuits_unknown_package(self) -> None: + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + # Simulate scan() having populated the available packages set + backend._available_packages = {canonicalize_name("requests")} + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert backend.lookup(key) is None + + def test_scan_populates_available_packages( + self, requests_mock: requests_mock.Mocker + ) -> None: + """scan() fetches the index page and populates available_packages.""" + index_html = """ + + numpy + torch + + """ + requests_mock.get("https://cache.test/simple/", text=index_html) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + result = backend.scan() + + assert result == {} + assert backend._available_packages is not None + assert canonicalize_name("numpy") in backend._available_packages + assert canonicalize_name("torch") in 
backend._available_packages + assert len(backend._available_packages) == 2 + + def test_scan_handles_network_error( + self, requests_mock: requests_mock.Mocker + ) -> None: + """scan() gracefully handles a network error.""" + import requests + + requests_mock.get( + "https://cache.test/simple/", exc=requests.ConnectionError("timeout") + ) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + result = backend.scan() + + assert result == {} + assert backend._available_packages == set() + + def test_lookup_fetches_project_page_lazily( + self, requests_mock: requests_mock.Mocker + ) -> None: + """lookup() fetches the project page on first access and caches it.""" + project_html = """ + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + backend._available_packages = {canonicalize_name("numpy")} + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + info = backend.lookup(key) + + assert info is not None + assert info.filename == "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + assert info.sha256 == "abc" + # Project page is now cached + assert canonicalize_name("numpy") in backend._project_cache + + def test_lookup_returns_none_for_unmatched_version( + self, requests_mock: requests_mock.Mocker + ) -> None: + """lookup() returns None when version doesn't match any wheel.""" + project_html = """ + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + backend._available_packages = {canonicalize_name("numpy")} + + key = WheelCacheKey( + 
package=canonicalize_name("numpy"), + version=Version("2.0.0"), + build_tag=(1, ""), + ) + assert backend.lookup(key) is None + + def test_lookup_caches_project_page( + self, requests_mock: requests_mock.Mocker + ) -> None: + """Second lookup() for same package does not fetch again.""" + project_html = """ + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + backend._available_packages = {canonicalize_name("numpy")} + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + backend.lookup(key) + backend.lookup(key) + + # Only one request made to the project page + assert requests_mock.call_count == 1 + + def test_fetch_downloads_wheel( + self, tmp_path: pathlib.Path, requests_mock: requests_mock.Mocker + ) -> None: + """fetch() downloads wheel content to the destination directory.""" + wheel_content = b"PK\x03\x04fake wheel archive content" + requests_mock.get( + "https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + content=wheel_content, + ) + + from fromager.cache import ArtifactInfo + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + info = ArtifactInfo( + filename="numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + url_or_path="https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + ) + + dest = tmp_path / "dest" + result_path = backend.fetch(key, info, dest) + + assert result_path.exists() + assert result_path.name == "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + assert result_path.read_bytes() == wheel_content + + def test_fetch_skips_existing_file(self, tmp_path: pathlib.Path) -> None: + 
"""fetch() returns existing path without downloading if file exists.""" + from fromager.cache import ArtifactInfo + + dest = tmp_path / "dest" + dest.mkdir() + existing = dest / "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + existing.write_bytes(b"existing content") + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + info = ArtifactInfo( + filename="numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + url_or_path="https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + ) + + result_path = backend.fetch(key, info, dest) + + assert result_path == existing + assert result_path.read_bytes() == b"existing content" + + def test_full_scan_lookup_fetch_flow( + self, + tmp_path: pathlib.Path, + requests_mock: requests_mock.Mocker, + ) -> None: + """End-to-end: scan -> lookup -> fetch for a remote backend.""" + index_html = 'numpy' + project_html = """ + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + wheel_content = b"wheel bytes" + + requests_mock.get("https://cache.test/simple/", text=index_html) + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + requests_mock.get( + "https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + content=wheel_content, + ) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + + # Step 1: scan + backend.scan() + assert backend._available_packages is not None + assert canonicalize_name("numpy") in backend._available_packages + + # Step 2: lookup + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + info = backend.lookup(key) + assert info is not None + assert info.sha256 == "deadbeef" + + # Step 3: fetch + dest = tmp_path / "local-cache" + result_path = backend.fetch(key, info, 
dest) + assert result_path.exists() + assert result_path.read_bytes() == wheel_content + + +# --------------------------------------------------------------------------- +# StoreRouter tests +# --------------------------------------------------------------------------- + + +class TestStoreRouter: + def test_override_wins(self) -> None: + router = StoreRouter( + overrides={canonicalize_name("torch"): "cuda"}, + accelerated_packages=set(), + active_variant="cuda", + ) + req = Requirement("torch>=2.0") + assert router.route(req) == "cuda" + + def test_accelerated_package(self) -> None: + router = StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("flash-attn")}, + active_variant="cuda", + ) + req = Requirement("flash-attn>=2.0") + assert router.route(req) == "cuda" + + def test_default_fallback(self) -> None: + router = StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("torch")}, + active_variant="cuda", + ) + req = Requirement("requests>=2.0") + assert router.route(req) == "default" + + def test_override_takes_priority_over_accelerated(self) -> None: + router = StoreRouter( + overrides={canonicalize_name("numpy"): "default"}, + accelerated_packages={canonicalize_name("numpy")}, + active_variant="cuda", + ) + req = Requirement("numpy>=1.0") + assert router.route(req) == "default" + + def test_custom_default_collection(self) -> None: + router = StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="rocm", + default_collection="base", + ) + req = Requirement("six") + assert router.route(req) == "base" + + +# --------------------------------------------------------------------------- +# CacheManager tests +# --------------------------------------------------------------------------- + + +def _make_collection(tmp_path: pathlib.Path, name: str) -> CacheCollection: + """Create a CacheCollection with a single local backend.""" + wheels_dir = tmp_path / f"wheels-{name}" + wheels_dir.mkdir(parents=True, 
exist_ok=True) + backend = LocalDirectoryBackend(wheels_dir, backend_name=f"local:{name}") + return CacheCollection( + name=name, + backends=[backend], + store_backend=backend, + ) + + +class TestCacheManager: + def test_lookup_miss_empty_cache(self, tmp_path: pathlib.Path) -> None: + default = _make_collection(tmp_path, "default") + manager = CacheManager( + collections={"default": default}, + search_order=["default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="cpu", + ), + ) + manager.initialize() + + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + assert result.hit is False + assert result.miss is True + + def test_lookup_hit_local(self, tmp_path: pathlib.Path) -> None: + default = _make_collection(tmp_path, "default") + _create_wheel_file( + tmp_path / "wheels-default", + "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + ) + manager = CacheManager( + collections={"default": default}, + search_order=["default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="cpu", + ), + ) + manager.initialize() + + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + assert result.hit is True + assert result.collection == "default" + assert result.backend_name == "local:default" + assert result.path is not None + + def test_lookup_respects_search_order(self, tmp_path: pathlib.Path) -> None: + """CUDA collection is searched first, default second.""" + cuda = _make_collection(tmp_path, "cuda") + default = _make_collection(tmp_path, "default") + + # Put the wheel only in default + _create_wheel_file( + tmp_path / "wheels-default", + "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + ) + + manager = CacheManager( + collections={"cuda": cuda, "default": default}, + search_order=["cuda", "default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + 
active_variant="cuda", + ), + ) + manager.initialize() + + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + assert result.hit is True + assert result.collection == "default" + + def test_lookup_variant_collection_takes_priority( + self, tmp_path: pathlib.Path + ) -> None: + """When wheel exists in both, variant collection wins.""" + cuda = _make_collection(tmp_path, "cuda") + default = _make_collection(tmp_path, "default") + + _create_wheel_file( + tmp_path / "wheels-cuda", + "torch-2.10.0-7-cp312-cp312-linux_x86_64.whl", + ) + _create_wheel_file( + tmp_path / "wheels-default", + "torch-2.10.0-7-cp312-cp312-linux_x86_64.whl", + ) + + manager = CacheManager( + collections={"cuda": cuda, "default": default}, + search_order=["cuda", "default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="cuda", + ), + ) + manager.initialize() + + result = manager.lookup_wheel( + Requirement("torch"), + Version("2.10.0"), + build_tag=(7, ""), + ) + assert result.hit is True + assert result.collection == "cuda" + + def test_force_skips_lookup(self, tmp_path: pathlib.Path) -> None: + default = _make_collection(tmp_path, "default") + _create_wheel_file( + tmp_path / "wheels-default", + "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + ) + manager = CacheManager( + collections={"default": default}, + search_order=["default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="cpu", + ), + force=True, + ) + manager.initialize() + + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + assert result.hit is False + + def test_store_routes_to_correct_collection(self, tmp_path: pathlib.Path) -> None: + cuda = _make_collection(tmp_path, "cuda") + default = _make_collection(tmp_path, "default") + + build_dir = tmp_path / "build" + whl = _create_wheel_file( + build_dir, 
"torch-2.10.0-7-cp312-cp312-linux_x86_64.whl" + ) + + manager = CacheManager( + collections={"cuda": cuda, "default": default}, + search_order=["cuda", "default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("torch")}, + active_variant="cuda", + ), + ) + manager.initialize() + + result_path = manager.store_wheel( + Requirement("torch"), + Version("2.10.0"), + build_tag=(7, ""), + wheel_path=whl, + ) + + # Should be stored in cuda collection + assert "wheels-cuda" in str(result_path) + assert ( + tmp_path / "wheels-cuda" / "torch-2.10.0-7-cp312-cp312-linux_x86_64.whl" + ).exists() + + def test_store_default_collection(self, tmp_path: pathlib.Path) -> None: + cuda = _make_collection(tmp_path, "cuda") + default = _make_collection(tmp_path, "default") + + build_dir = tmp_path / "build" + whl = _create_wheel_file(build_dir, "requests-2.31.0-1-py3-none-any.whl") + + manager = CacheManager( + collections={"cuda": cuda, "default": default}, + search_order=["cuda", "default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("torch")}, + active_variant="cuda", + ), + ) + manager.initialize() + + result_path = manager.store_wheel( + Requirement("requests"), + Version("2.31.0"), + build_tag=(1, ""), + wheel_path=whl, + ) + + assert "wheels-default" in str(result_path) + + def test_store_then_lookup(self, tmp_path: pathlib.Path) -> None: + """A stored wheel is immediately findable via lookup.""" + default = _make_collection(tmp_path, "default") + + build_dir = tmp_path / "build" + whl = _create_wheel_file(build_dir, "requests-2.31.0-1-py3-none-any.whl") + + manager = CacheManager( + collections={"default": default}, + search_order=["default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="cpu", + ), + ) + manager.initialize() + + manager.store_wheel( + Requirement("requests"), + Version("2.31.0"), + build_tag=(1, ""), + wheel_path=whl, + ) + + result 
= manager.lookup_wheel( + Requirement("requests"), + Version("2.31.0"), + build_tag=(1, ""), + ) + assert result.hit is True + + def test_store_unknown_collection_raises(self, tmp_path: pathlib.Path) -> None: + default = _make_collection(tmp_path, "default") + build_dir = tmp_path / "build" + whl = _create_wheel_file( + build_dir, "torch-2.10.0-7-cp312-cp312-linux_x86_64.whl" + ) + + manager = CacheManager( + collections={"default": default}, + search_order=["default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("torch")}, + active_variant="cuda", # No "cuda" collection configured + ), + ) + manager.initialize() + + with pytest.raises(ValueError, match="unknown collection"): + manager.store_wheel( + Requirement("torch"), + Version("2.10.0"), + build_tag=(7, ""), + wheel_path=whl, + ) + + +# --------------------------------------------------------------------------- +# CacheStats tests +# --------------------------------------------------------------------------- + + +class TestCacheStats: + def test_empty_stats(self) -> None: + stats = CacheStats() + assert stats.hits == 0 + assert stats.misses == 0 + assert stats.stores == 0 + assert stats.hit_rate == 0.0 + + def test_record_hit(self) -> None: + stats = CacheStats() + stats.record_hit(Requirement("numpy"), Version("1.26.4"), "default", "local") + assert stats.hits == 1 + assert stats.misses == 0 + + def test_record_miss(self) -> None: + stats = CacheStats() + stats.record_miss(Requirement("numpy"), Version("1.26.4"), "not_found") + assert stats.misses == 1 + assert stats.hits == 0 + + def test_hit_rate(self) -> None: + stats = CacheStats() + stats.record_hit(Requirement("numpy"), Version("1.26.4"), "default", "local") + stats.record_hit(Requirement("requests"), Version("2.31.0"), "default", "local") + stats.record_miss(Requirement("torch"), Version("2.10.0"), "not_found") + assert stats.hit_rate == pytest.approx(2 / 3) + + def test_summary(self) -> None: + stats = 
CacheStats() + stats.record_hit( + Requirement("numpy"), Version("1.26.4"), "default", "local:default" + ) + stats.record_miss(Requirement("torch"), Version("2.10.0"), "not_found") + stats.record_store(Requirement("torch"), Version("2.10.0"), "cuda") + + summary = stats.summary() + assert summary["hits"]["total"] == 1 + assert summary["hits"]["by_collection"]["default"] == 1 + assert summary["misses"] == 1 + assert summary["stores"] == 1 + + +# --------------------------------------------------------------------------- +# CacheResult tests +# --------------------------------------------------------------------------- + + +class TestCacheResult: + def test_miss_property(self) -> None: + result = CacheResult(hit=False) + assert result.miss is True + assert result.hit is False + + def test_hit_result(self) -> None: + result = CacheResult( + hit=True, + path=pathlib.Path("/some/wheel.whl"), + collection="default", + backend_name="local:default", + ) + assert result.miss is False + assert result.path == pathlib.Path("/some/wheel.whl") + + +# --------------------------------------------------------------------------- +# CacheManager + RemotePEP503Backend integration tests +# --------------------------------------------------------------------------- + + +class TestCacheManagerRemoteIntegration: + """Integration tests verifying CacheManager with remote backends.""" + + def test_remote_hit_downloads_to_local_store( + self, + tmp_path: pathlib.Path, + requests_mock: requests_mock.Mocker, + ) -> None: + """CacheManager lookup via remote backend downloads wheel locally.""" + + # Set up remote backend with a wheel available + project_html = """ + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + wheel_content = b"remote wheel content" + requests_mock.get( + "https://cache.test/simple/", text='numpy' + ) + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + requests_mock.get( + "https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + 
content=wheel_content, + ) + + local_dir = tmp_path / "local-wheels" + local_dir.mkdir() + local_backend = LocalDirectoryBackend(local_dir, backend_name="local:default") + + remote_backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + + collection = CacheCollection( + name="default", + backends=[local_backend, remote_backend], + store_backend=local_backend, + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + manager = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + manager.initialize() + + # Lookup should miss local, hit remote, and download + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + + assert result.hit is True + assert result.path is not None + assert result.path.exists() + assert result.path.read_bytes() == wheel_content + assert result.backend_name == "remote:https://cache.test/simple" + assert result.was_downloaded is True + + def test_local_hit_takes_priority_over_remote( + self, + tmp_path: pathlib.Path, + requests_mock: requests_mock.Mocker, + ) -> None: + """Local backend hit means remote is never consulted.""" + local_dir = tmp_path / "local-wheels" + local_dir.mkdir() + # Put a wheel in the local cache + wheel_file = local_dir / "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + wheel_file.write_bytes(b"local wheel") + + local_backend = LocalDirectoryBackend(local_dir, backend_name="local:default") + remote_backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + + # Don't register any requests_mock responses — remote should not be hit + requests_mock.get( + "https://cache.test/simple/", text='numpy' + ) + + collection = CacheCollection( + name="default", + backends=[local_backend, remote_backend], + store_backend=local_backend, + ) + router = StoreRouter( + 
overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + manager = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + manager.initialize() + + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + + assert result.hit is True + assert result.path == wheel_file + assert result.backend_name == "local:default" + assert result.was_downloaded is False + # Remote project page never fetched + assert not any("simple/numpy/" in h.url for h in requests_mock.request_history) + + def test_hierarchical_search_across_collections_with_remote( + self, + tmp_path: pathlib.Path, + requests_mock: requests_mock.Mocker, + ) -> None: + """CacheManager searches variant collection first, then falls through to default.""" + # CUDA collection: empty (no local, remote has nothing matching) + cuda_local = tmp_path / "cuda-wheels" + cuda_local.mkdir() + cuda_backend = LocalDirectoryBackend(cuda_local, backend_name="local:cuda") + + # Default collection: has numpy via remote + default_local = tmp_path / "default-wheels" + default_local.mkdir() + default_backend = LocalDirectoryBackend( + default_local, backend_name="local:default" + ) + + project_html = """ + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + requests_mock.get( + "https://cache.test/simple/", text='numpy' + ) + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + requests_mock.get( + "https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + content=b"remote numpy", + ) + + remote_default = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + + cuda_collection = CacheCollection( + name="cuda", + backends=[cuda_backend], + store_backend=cuda_backend, + ) + default_collection = CacheCollection( + name="default", + backends=[default_backend, remote_default], + store_backend=default_backend, + ) + + router = StoreRouter( + 
overrides={}, + accelerated_packages={canonicalize_name("torch")}, + active_variant="cuda", + ) + manager = CacheManager( + collections={"cuda": cuda_collection, "default": default_collection}, + search_order=["cuda", "default"], + store_routing=router, + ) + manager.initialize() + + # numpy is not in cuda, should be found in default via remote + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + + assert result.hit is True + assert result.collection == "default" + assert result.path is not None + assert result.path.read_bytes() == b"remote numpy" From 28b92ada0fe6779768d622416cbc692c861fe26e Mon Sep 17 00:00:00 2001 From: Sean Pryor Date: Wed, 24 Jun 2026 09:32:16 -0400 Subject: [PATCH 2/2] feat(cache): route built wheels to variant or shared collection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the CacheManager is active and a non-cpu variant has top-level requirements, newly built wheels are routed to the appropriate collection directory: - Packages listed in the variant's requirements file → wheels-repo/variants/<variant>/ - Unlisted transitive dependencies → wheels-repo/downloads/ (shared default) This ensures the variant collection contains exactly the packages explicitly requested in the requirements file, while common deps discovered during dependency resolution stay in the shared collection for reuse across variants.
Changes: - `_build_cache_manager` accepts `toplevel_reqs` and uses them as the variant package set for store routing - `StoreRouter` routes based on variant_packages (from requirements file), not just pre-built packages - `_phase_build` calls `store_wheel` after building (not for cached hits) - `LocalDirectoryBackend.store()` copies (not moves) to preserve the original in downloads/ for the internal wheel server Co-Authored-By: Claude Signed-off-by: Sean Pryor Co-authored-by: Cursor Signed-off-by: Sean Pryor Co-authored-by: Cursor --- src/fromager/bootstrapper.py | 20 +++ src/fromager/cache.py | 19 +-- src/fromager/commands/bootstrap.py | 4 +- src/fromager/commands/cache_cmd.py | 54 ++++++-- tests/test_bootstrapper_iterative.py | 181 +++++++++++++++++++++++++++ tests/test_cache.py | 5 +- 6 files changed, 263 insertions(+), 20 deletions(-) diff --git a/src/fromager/bootstrapper.py b/src/fromager/bootstrapper.py index 88bfc9494..cb6b2732d 100644 --- a/src/fromager/bootstrapper.py +++ b/src/fromager/bootstrapper.py @@ -1533,6 +1533,12 @@ def _phase_prepare_source(self, item: WorkItem) -> list[WorkItem]: # deps resolution needed. Install deps are extracted from the wheel. if cached_wheel and self.ctx.cache is not None: server.update_wheel_mirror(self.ctx) + # Route to the correct collection (e.g., variant dir for listed packages) + pbi = self.ctx.package_build_info(item.req) + build_tag = pbi.build_tag(item.resolved_version) + self.ctx.cache.store_wheel( + item.req, item.resolved_version, build_tag, cached_wheel + ) unpack_dir = self._create_unpack_dir(item.req, item.resolved_version) item.build_result = SourceBuildResult( wheel_filename=cached_wheel, @@ -1664,6 +1670,20 @@ def _phase_build(self, item: WorkItem) -> list[WorkItem]: cached_wheel_filename=item.cached_wheel_filename, ) + # Route newly built wheels to the appropriate collection directory. 
+ # This copies the wheel into the routed collection's storage while + # keeping the original in downloads/ for the internal wheel server. + if ( + wheel_filename is not None + and self.ctx.cache is not None + and not item.cached_wheel_filename + ): + pbi = self.ctx.package_build_info(item.req) + build_tag = pbi.build_tag(item.resolved_version) + self.ctx.cache.store_wheel( + item.req, item.resolved_version, build_tag, wheel_filename + ) + source_type = sources.get_source_type(self.ctx, item.req) item.build_result = SourceBuildResult( diff --git a/src/fromager/cache.py b/src/fromager/cache.py index dae284007..f9122eb37 100644 --- a/src/fromager/cache.py +++ b/src/fromager/cache.py @@ -358,13 +358,14 @@ def fetch( def store(self, key: WheelCacheKey, artifact: pathlib.Path) -> ArtifactInfo: """Register an artifact in this backend's directory. - If the artifact is not already in the directory, it is moved there. + If the artifact is not already in the directory, it is copied there + (preserving the original for the internal wheel server index). Updates the in-memory index. """ dest = self._directory / artifact.name if not dest.exists(): self._directory.mkdir(parents=True, exist_ok=True) - shutil.move(str(artifact), str(dest)) + shutil.copy2(str(artifact), str(dest)) info = ArtifactInfo( filename=dest.name, @@ -560,19 +561,21 @@ class StoreRouter: Routing priority: 1. Explicit per-package override (from overrides.yaml) - 2. Listed in the variant's accelerated requirements file - 3. Default collection + 2. Listed in the variant's requirements file (variant_packages) + 3. 
Default collection (shared/common dependencies) """ def __init__( self, overrides: dict[NormalizedName, str], - accelerated_packages: set[NormalizedName], - active_variant: str, + variant_packages: set[NormalizedName] | None = None, + active_variant: str = "cpu", default_collection: str = "default", + # Keep old kwarg name for backward compatibility + accelerated_packages: set[NormalizedName] | None = None, ) -> None: self._overrides = overrides - self._accelerated_packages = accelerated_packages + self._variant_packages = variant_packages or accelerated_packages or set() self._active_variant = active_variant self._default_collection = default_collection @@ -583,7 +586,7 @@ def route(self, req: Requirement) -> str: if name in self._overrides: return self._overrides[name] - if name in self._accelerated_packages: + if name in self._variant_packages: return self._active_variant return self._default_collection diff --git a/src/fromager/commands/bootstrap.py b/src/fromager/commands/bootstrap.py index 6fc45d86c..e63795ddf 100644 --- a/src/fromager/commands/bootstrap.py +++ b/src/fromager/commands/bootstrap.py @@ -196,7 +196,9 @@ def bootstrap( logger.info("treating %s as pre-built wheels", sorted(pre_built)) if use_cache_manager: - cache_mgr = _build_cache_manager(wkctx, cache_url=cache_wheel_server_url) + cache_mgr = _build_cache_manager( + wkctx, cache_url=cache_wheel_server_url, toplevel_reqs=to_build + ) wkctx.cache = cache_mgr logger.info( "cache manager enabled with %d collection(s): %s", diff --git a/src/fromager/commands/cache_cmd.py b/src/fromager/commands/cache_cmd.py index 8c43141ac..3f7b03218 100644 --- a/src/fromager/commands/cache_cmd.py +++ b/src/fromager/commands/cache_cmd.py @@ -7,6 +7,7 @@ import click import rich import rich.box +from packaging.requirements import Requirement from packaging.utils import canonicalize_name from rich.table import Table @@ -26,27 +27,40 @@ def _build_cache_manager( wkctx: context.WorkContext, cache_url: str | None = None, 
+ toplevel_reqs: list[Requirement] | None = None, ) -> CacheManager: """Construct a CacheManager from the WorkContext configuration. If the context already has a cache configured, return it. Otherwise, build one from the standard filesystem layout. + When a non-default variant is active and top-level requirements are provided, + two collections are created: + - A variant-specific collection for packages listed in the variant's + requirements file (the "main" packages for this build) + - A shared "default" collection for unlisted transitive dependencies + + Newly built wheels are routed by the StoreRouter: packages listed in the + variant requirements go to the variant collection, unlisted deps to default. + Args: wkctx: The work context providing local paths and variant info. cache_url: Optional URL to a remote PEP 503 cache server. + toplevel_reqs: Top-level requirements from the variant's requirements + file. These define which packages belong to the variant collection. """ if wkctx.cache is not None: return wkctx.cache - local_backend = LocalDirectoryBackend( + # Shared (default) collection: downloads + prebuilt + optional remote + shared_backend = LocalDirectoryBackend( wkctx.wheels_downloads, backend_name="local:downloads" ) prebuilt_backend = LocalDirectoryBackend( wkctx.wheels_prebuilt, backend_name="local:prebuilt" ) - backends: list[CacheBackend] = [local_backend, prebuilt_backend] + shared_backends: list[CacheBackend] = [shared_backend, prebuilt_backend] if cache_url: remote_backend = RemotePEP503Backend( @@ -54,23 +68,45 @@ def _build_cache_manager( download_dir=wkctx.wheels_downloads, backend_name=f"remote:{cache_url}", ) - backends.append(remote_backend) + shared_backends.append(remote_backend) - collection = CacheCollection( + default_collection = CacheCollection( name="default", - backends=backends, - store_backend=local_backend, + backends=shared_backends, + store_backend=shared_backend, ) + collections: dict[str, CacheCollection] = {"default": 
default_collection} + search_order: list[str] = ["default"] + + # Variant-specific collection for packages listed in the requirements file + variant_packages = {canonicalize_name(r.name) for r in (toplevel_reqs or [])} + + if variant_packages and wkctx.variant != "cpu": + variant_dir = wkctx.wheels_repo / "variants" / wkctx.variant + variant_dir.mkdir(parents=True, exist_ok=True) + variant_backend = LocalDirectoryBackend( + variant_dir, backend_name=f"local:{wkctx.variant}" + ) + + variant_backends: list[CacheBackend] = [variant_backend, prebuilt_backend] + variant_collection = CacheCollection( + name=wkctx.variant, + backends=variant_backends, + store_backend=variant_backend, + ) + collections[wkctx.variant] = variant_collection + search_order = [wkctx.variant, "default"] + router = StoreRouter( overrides={}, - accelerated_packages=set(), + accelerated_packages=variant_packages, active_variant=wkctx.variant, ) manager = CacheManager( - collections={"default": collection}, - search_order=["default"], + collections=collections, + search_order=search_order, store_routing=router, ) manager.initialize() diff --git a/tests/test_bootstrapper_iterative.py b/tests/test_bootstrapper_iterative.py index 933055e17..43889d281 100644 --- a/tests/test_bootstrapper_iterative.py +++ b/tests/test_bootstrapper_iterative.py @@ -1787,6 +1787,187 @@ def test_returns_single_item_at_process_install_deps( assert result[0] is item assert item.phase == BootstrapPhase.PROCESS_INSTALL_DEPS + def test_store_wheel_called_for_freshly_built_wheel( + self, tmp_context: WorkContext + ) -> None: + """When CacheManager is active and a wheel is built (not cached), + store_wheel routes it to the appropriate collection.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + variant_dir = tmp_context.wheels_repo / "variants" / "gaudi-ubi9" + variant_dir.mkdir(parents=True) + + shared_backend = 
LocalDirectoryBackend(wheels_dir, backend_name="local:shared") + variant_backend = LocalDirectoryBackend( + variant_dir, backend_name="local:gaudi-ubi9" + ) + + default_coll = CacheCollection( + name="default", backends=[shared_backend], store_backend=shared_backend + ) + variant_coll = CacheCollection( + name="gaudi-ubi9", + backends=[variant_backend], + store_backend=variant_backend, + ) + + router = StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("testpkg")}, + active_variant="gaudi-ubi9", + ) + cache_mgr = CacheManager( + collections={"default": default_coll, "gaudi-ubi9": variant_coll}, + search_order=["gaudi-ubi9", "default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + mock_env = Mock() + sdist_root = tmp_context.work_dir / "testpkg-1.0" / "testpkg-1.0" + sdist_root.parent.mkdir(parents=True, exist_ok=True) + item = _make_build_item( + phase=BootstrapPhase.BUILD, + build_env=mock_env, + sdist_root_dir=sdist_root, + ) + + built_wheel = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + built_wheel.write_bytes(b"fake built wheel") + + with ( + patch.object(bt, "_do_build", return_value=(built_wheel, None)), + patch("fromager.sources.get_source_type", return_value=SourceType.SDIST), + ): + bt._phase_build(item) + + # Wheel was copied to the variant collection directory + assert (variant_dir / "testpkg-1.0-1-py3-none-any.whl").exists() + # Original in downloads preserved for internal wheel server + assert built_wheel.exists() + + def test_store_wheel_routes_unlisted_to_default( + self, tmp_context: WorkContext + ) -> None: + """Unlisted packages (not accelerated) are stored in the default collection.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + variant_dir = tmp_context.wheels_repo / "variants" / "gaudi-ubi9" + 
variant_dir.mkdir(parents=True) + + shared_backend = LocalDirectoryBackend(wheels_dir, backend_name="local:shared") + variant_backend = LocalDirectoryBackend( + variant_dir, backend_name="local:gaudi-ubi9" + ) + + default_coll = CacheCollection( + name="default", backends=[shared_backend], store_backend=shared_backend + ) + variant_coll = CacheCollection( + name="gaudi-ubi9", + backends=[variant_backend], + store_backend=variant_backend, + ) + + # testpkg is NOT in accelerated_packages + router = StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("torch")}, + active_variant="gaudi-ubi9", + ) + cache_mgr = CacheManager( + collections={"default": default_coll, "gaudi-ubi9": variant_coll}, + search_order=["gaudi-ubi9", "default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + mock_env = Mock() + sdist_root = tmp_context.work_dir / "testpkg-1.0" / "testpkg-1.0" + sdist_root.parent.mkdir(parents=True, exist_ok=True) + item = _make_build_item( + phase=BootstrapPhase.BUILD, + build_env=mock_env, + sdist_root_dir=sdist_root, + ) + + built_wheel = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + built_wheel.write_bytes(b"fake built wheel") + + with ( + patch.object(bt, "_do_build", return_value=(built_wheel, None)), + patch("fromager.sources.get_source_type", return_value=SourceType.SDIST), + ): + bt._phase_build(item) + + # Wheel stays in downloads (default collection dir == downloads) + assert built_wheel.exists() + # NOT copied to variant dir + assert not (variant_dir / "testpkg-1.0-1-py3-none-any.whl").exists() + + def test_store_wheel_skipped_when_cached(self, tmp_context: WorkContext) -> None: + """When a cached wheel is used (not freshly built), store_wheel is not called.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + shared_backend = 
LocalDirectoryBackend(wheels_dir, backend_name="local:shared") + default_coll = CacheCollection( + name="default", backends=[shared_backend], store_backend=shared_backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": default_coll}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + mock_env = Mock() + sdist_root = tmp_context.work_dir / "testpkg-1.0" / "testpkg-1.0" + sdist_root.parent.mkdir(parents=True, exist_ok=True) + cached_wheel = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + cached_wheel.write_bytes(b"cached wheel") + item = _make_build_item( + phase=BootstrapPhase.BUILD, + build_env=mock_env, + sdist_root_dir=sdist_root, + cached_wheel_filename=cached_wheel, + ) + + with ( + patch.object(bt, "_do_build", return_value=(cached_wheel, None)), + patch("fromager.sources.get_source_type", return_value=SourceType.SDIST), + patch.object(cache_mgr, "store_wheel") as mock_store, + ): + bt._phase_build(item) + + mock_store.assert_not_called() + class TestPhaseProcessInstallDeps: """Tests for _phase_process_install_deps: hooks, dep extraction, error modes.""" diff --git a/tests/test_cache.py b/tests/test_cache.py index 7831ef97c..d54010c08 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -260,7 +260,8 @@ def test_fetch_returns_local_path(self, tmp_path: pathlib.Path) -> None: result = backend.fetch(key, info, tmp_path / "dest") assert result == whl.resolve() - def test_store_moves_file(self, tmp_path: pathlib.Path) -> None: + def test_store_copies_file(self, tmp_path: pathlib.Path) -> None: + """Store copies the wheel to the collection directory, preserving the original.""" wheels_dir = tmp_path / "wheels" wheels_dir.mkdir() backend = LocalDirectoryBackend(wheels_dir) @@ -279,7 +280,7 @@ def test_store_moves_file(self, tmp_path: 
pathlib.Path) -> None: assert info.filename == "requests-2.31.0-1-py3-none-any.whl" assert (wheels_dir / "requests-2.31.0-1-py3-none-any.whl").exists() - assert not whl.exists() # Moved, not copied + assert whl.exists() # Original preserved for internal wheel server def test_store_updates_index(self, tmp_path: pathlib.Path) -> None: wheels_dir = tmp_path / "wheels"