diff --git a/pyproject.toml b/pyproject.toml index adf93290..f76f71af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ canonicalize = "fromager.commands.canonicalize:canonicalize" download-sequence = "fromager.commands.download_sequence:download_sequence" wheel-server = "fromager.commands.server:wheel_server" lint-requirements = "fromager.commands.lint_requirements:lint_requirements" +cache = "fromager.commands.cache_cmd:cache" [tool.coverage.run] branch = true diff --git a/src/fromager/bootstrapper.py b/src/fromager/bootstrapper.py index 3389a862..c6163a9b 100644 --- a/src/fromager/bootstrapper.py +++ b/src/fromager/bootstrapper.py @@ -655,15 +655,44 @@ def _find_cached_wheel( ) -> tuple[pathlib.Path | None, pathlib.Path | None]: """Look for cached wheel in 3 locations. - Checks for cached wheels in order: - 1. wheels_build directory (previously built) - 2. wheels_downloads directory (previously downloaded) - 3. Cache server (remote cache) + When a CacheManager is configured on the context, delegates to it + for a unified hierarchical lookup across collections. Otherwise + falls back to the legacy per-directory search. Returns: Tuple of (cached_wheel_filename, unpacked_cached_wheel). Both None if no cache hit. 
""" + if self.ctx.cache is not None: + return self._find_cached_wheel_via_manager(req, resolved_version) + return self._find_cached_wheel_legacy(req, resolved_version) + + def _find_cached_wheel_via_manager( + self, + req: Requirement, + resolved_version: Version, + ) -> tuple[pathlib.Path | None, pathlib.Path | None]: + """Cache lookup using the CacheManager.""" + assert self.ctx.cache is not None + pbi = self.ctx.package_build_info(req) + build_tag = pbi.build_tag(resolved_version) + + result = self.ctx.cache.lookup_wheel(req, resolved_version, build_tag) + if not result.hit: + return None, None + + assert result.path is not None + metadata_dir = self._unpack_metadata_from_wheel( + req, resolved_version, result.path + ) + return result.path, metadata_dir + + def _find_cached_wheel_legacy( + self, + req: Requirement, + resolved_version: Version, + ) -> tuple[pathlib.Path | None, pathlib.Path | None]: + """Legacy cache lookup: check build dir, downloads dir, remote cache.""" # Check if we have previously built a wheel and still have it on the # local filesystem. cached_wheel, unpacked = self._look_for_existing_wheel( @@ -1493,12 +1522,35 @@ def _phase_prepare_source(self, item: WorkItem) -> list[WorkItem]: item.phase = BootstrapPhase.PROCESS_INSTALL_DEPS return [item] - # Source build path + # Source build path: try cache first cached_wheel, unpacked = self._find_cached_wheel( item.req, item.resolved_version ) item.cached_wheel_filename = cached_wheel + # Short-circuit: when CacheManager provides a hit, skip directly to + # PROCESS_INSTALL_DEPS -- no source download, no build env, no build + # deps resolution needed. Install deps are extracted from the wheel. 
+ if cached_wheel and self.ctx.cache is not None: + # Route to the correct collection (e.g., variant dir for listed packages) + pbi = self.ctx.package_build_info(item.req) + build_tag = pbi.build_tag(item.resolved_version) + self.ctx.cache.store_wheel( + item.req, item.resolved_version, build_tag, cached_wheel + ) + server.update_wheel_mirror(self.ctx) + unpack_dir = self._create_unpack_dir(item.req, item.resolved_version) + item.build_result = SourceBuildResult( + wheel_filename=cached_wheel, + sdist_filename=None, + unpack_dir=unpack_dir, + sdist_root_dir=None, + build_env=None, + source_type=SourceType.CACHED, + ) + item.phase = BootstrapPhase.PROCESS_INSTALL_DEPS + return [item] + if not unpacked: logger.debug("no cached wheel, downloading sources") source_filename = self._download_source( @@ -1618,6 +1670,20 @@ def _phase_build(self, item: WorkItem) -> list[WorkItem]: cached_wheel_filename=item.cached_wheel_filename, ) + # Route newly built wheels to the appropriate collection directory. + # This copies the wheel into the routed collection's storage while + # keeping the original in downloads/ for the internal wheel server. + if ( + wheel_filename is not None + and self.ctx.cache is not None + and not item.cached_wheel_filename + ): + pbi = self.ctx.package_build_info(item.req) + build_tag = pbi.build_tag(item.resolved_version) + self.ctx.cache.store_wheel( + item.req, item.resolved_version, build_tag, wheel_filename + ) + source_type = sources.get_source_type(self.ctx, item.req) item.build_result = SourceBuildResult( diff --git a/src/fromager/cache.py b/src/fromager/cache.py new file mode 100644 index 00000000..13122450 --- /dev/null +++ b/src/fromager/cache.py @@ -0,0 +1,814 @@ +"""Unified cache subsystem for Fromager artifact management. + +Provides a layered cache with collection-based organization, supporting +hierarchical lookup across local directories and remote PEP 503 repositories. 
@dataclasses.dataclass(frozen=True)
class WheelCacheKey:
    """Immutable, hashable identity of one cached wheel.

    Only the package name, version and build tag take part in the key.
    Which collection the wheel belongs to is decided by the CacheManager
    and is deliberately not stored here.
    """

    package: NormalizedName
    version: Version
    build_tag: BuildTag  # (int, str) from changelog; () if untagged

    def __str__(self) -> str:
        """Render ``name==version``, plus ``-<number><suffix>`` when tagged."""
        label = f"{self.package}=={self.version}"
        if not self.build_tag:
            # Untagged wheels carry an empty tuple: nothing to append.
            return label
        return f"{label}-{self.build_tag[0]}{self.build_tag[1]}"
@dataclasses.dataclass
class CacheStats:
    """Run-scoped journal of cache interactions.

    Each ``record_*`` call appends one ``CacheEvent`` to ``events``. The
    counters and ``summary()`` are recomputed from that list on every access.
    """

    events: list[CacheEvent] = dataclasses.field(default_factory=list)

    def _count(self, action: str) -> int:
        """Return how many recorded events carry the given ``action``."""
        return len([evt for evt in self.events if evt.action == action])

    def record_hit(
        self,
        req: Requirement,
        version: Version,
        collection: str,
        backend: str,
        artifact_type: typing.Literal["wheel", "sdist"] = "wheel",
        duration_ms: float | None = None,
    ) -> None:
        """Append a ``hit`` event for ``req`` at ``version``."""
        event = CacheEvent(
            timestamp=time.monotonic(),
            action="hit",
            artifact_type=artifact_type,
            package=str(req.name),
            version=str(version),
            collection=collection,
            backend=backend,
            duration_ms=duration_ms,
        )
        self.events.append(event)

    def record_miss(
        self,
        req: Requirement,
        version: Version,
        reason: str,
        artifact_type: typing.Literal["wheel", "sdist"] = "wheel",
    ) -> None:
        """Append a ``miss`` event.

        A miss has no collection, so that field is left empty; ``reason``
        is stored in the event's ``backend`` field.
        """
        event = CacheEvent(
            timestamp=time.monotonic(),
            action="miss",
            artifact_type=artifact_type,
            package=str(req.name),
            version=str(version),
            collection="",
            backend=reason,
        )
        self.events.append(event)

    def record_store(
        self,
        req: Requirement,
        version: Version,
        collection: str,
        artifact_type: typing.Literal["wheel", "sdist"] = "wheel",
    ) -> None:
        """Append a ``store`` event; the backend is always recorded as ``local``."""
        event = CacheEvent(
            timestamp=time.monotonic(),
            action="store",
            artifact_type=artifact_type,
            package=str(req.name),
            version=str(version),
            collection=collection,
            backend="local",
        )
        self.events.append(event)

    @property
    def hits(self) -> int:
        """Number of recorded ``hit`` events."""
        return self._count("hit")

    @property
    def misses(self) -> int:
        """Number of recorded ``miss`` events."""
        return self._count("miss")

    @property
    def stores(self) -> int:
        """Number of recorded ``store`` events."""
        return self._count("store")

    @property
    def hit_rate(self) -> float:
        """Hits divided by hits plus misses; ``0.0`` when nothing was looked up."""
        hit_count = self.hits
        lookups = hit_count + self.misses
        return hit_count / lookups if lookups else 0.0

    def summary(self) -> dict[str, typing.Any]:
        """Return a structured summary suitable for JSON serialization."""
        per_collection: dict[str, int] = {}
        per_backend: dict[str, int] = {}
        for evt in self.events:
            if evt.action != "hit":
                continue
            per_collection.setdefault(evt.collection, 0)
            per_collection[evt.collection] += 1
            per_backend.setdefault(evt.backend, 0)
            per_backend[evt.backend] += 1

        hit_section = {
            "total": self.hits,
            "by_collection": per_collection,
            "by_backend": per_backend,
        }
        return {
            "hits": hit_section,
            "misses": self.misses,
            "stores": self.stores,
            "hit_rate": round(self.hit_rate, 4),
        }
"""Human-readable identifier (e.g., 'local:default', 'remote:https://...').""" + ... + + @property + def writable(self) -> bool: + """Whether this backend supports store operations.""" + ... + + def scan(self) -> dict[WheelCacheKey, ArtifactInfo]: + """Bulk index at startup. Local backends return full inventory; + remote backends fetch the top-level package list only and return empty. + """ + ... + + def lookup(self, key: WheelCacheKey) -> ArtifactInfo | None: + """Find a specific artifact by key. + + For local backends, checks the in-memory index. + For remote backends, lazily fetches the project page on first access. + """ + ... + + def fetch( + self, key: WheelCacheKey, info: ArtifactInfo, dest: pathlib.Path + ) -> pathlib.Path: + """Retrieve artifact to a local path. + + For local backends, returns the existing path (no-op). + For remote backends, downloads the file to ``dest``. + """ + ... + + def store(self, key: WheelCacheKey, artifact: pathlib.Path) -> ArtifactInfo: + """Store a newly built artifact. Only valid if ``writable`` is True.""" + ... + + +# --------------------------------------------------------------------------- +# Local Directory Backend +# --------------------------------------------------------------------------- + + +class LocalDirectoryBackend: + """Cache backend backed by a local filesystem directory. + + Scans at startup to populate an in-memory index from existing wheel files. + New stores are reflected immediately in the index. 
def scan(self) -> dict[WheelCacheKey, ArtifactInfo]:
    """Rebuild the in-memory index from the ``*.whl`` files in the directory.

    The existing index is always cleared first, so entries from an earlier
    scan never survive. A wheel whose filename cannot be parsed, or that
    cannot be inspected on disk, is skipped with a debug log entry.

    Returns:
        A shallow copy of the rebuilt index. The result is empty when the
        directory does not exist. Callers may mutate the returned dict
        without affecting this backend.
    """
    self._index.clear()
    if not self._directory.exists():
        # Return a copy here too, so callers never hold the live index.
        return dict(self._index)

    for wheel_file in self._directory.glob("*.whl"):
        try:
            name, version, build_tag, _ = parse_wheel_filename(wheel_file.name)
            key = WheelCacheKey(
                package=name,
                version=version,
                build_tag=build_tag,
            )
            info = ArtifactInfo(
                filename=wheel_file.name,
                url_or_path=str(wheel_file.resolve()),
                size_bytes=wheel_file.stat().st_size,
            )
            self._index[key] = info
        except Exception:
            # Best effort: one bad file must not abort the whole scan.
            logger.debug("skipping unparseable wheel file: %s", wheel_file.name)
    logger.debug("scanned %d wheels in %s", len(self._index), self._directory)
    return dict(self._index)
def scan(self) -> dict[WheelCacheKey, ArtifactInfo]:
    """Fetch the top-level index to learn which packages the server offers.

    The package list is only used by ``lookup`` to skip packages the server
    does not have. ``_fetch_package_list`` returns an empty set when the
    request fails, so an empty result is stored as ``None`` ("unknown")
    rather than as "the server has nothing". Otherwise one failed index
    request would turn every later lookup into a miss. The cost is that a
    genuinely empty index leads to per-package page requests.

    Returns:
        Always an empty dict; remote artifacts are resolved lazily.
    """
    packages = self._fetch_package_list()
    self._available_packages = packages or None
    logger.debug(
        "remote %s has %d packages available",
        self._server_url,
        len(packages) if packages else 0,
    )
    return {}
def fetch(
    self, key: WheelCacheKey, info: ArtifactInfo, dest: pathlib.Path
) -> pathlib.Path:
    """Download the wheel described by ``info`` into the directory ``dest``.

    A file already present under the target name is reused when ``info``
    has no sha256, or when its digest matches. A mismatching file is
    deleted and downloaded again. The download is written to a ``.tmp``
    sibling and renamed into place only after the optional digest check.

    Args:
        key: Cache key of the requested wheel (not used by this method).
        info: Artifact metadata; ``url_or_path`` is the download URL.
        dest: Directory that receives the wheel; created if missing.

    Returns:
        Path of the wheel inside ``dest``.

    Raises:
        ValueError: The downloaded content does not match ``info.sha256``.
        Exception: Whatever the HTTP session raises for a failed request.
    """
    dest.mkdir(parents=True, exist_ok=True)
    target = dest / info.filename

    if target.exists():
        if not info.sha256:
            return target
        existing_digest = hashlib.sha256()
        with open(target, "rb") as f:
            while chunk := f.read(1024 * 1024):
                existing_digest.update(chunk)
        if existing_digest.hexdigest() == info.sha256:
            return target
        logger.warning(
            "existing %s has wrong sha256, re-downloading", info.filename
        )
        target.unlink()

    url = info.url_or_path
    logger.info(
        "downloading cached wheel %s from %s", info.filename, self._server_url
    )
    hasher = hashlib.sha256() if info.sha256 else None
    tmp_target = target.with_suffix(target.suffix + ".tmp")
    try:
        # Close the streamed response on every exit path, so a failed
        # write does not leave the connection checked out.
        # NOTE(review): assumes `session` is a requests-style session whose
        # responses are context managers -- confirm.
        with session.get(url, stream=True) as resp:
            resp.raise_for_status()
            with open(tmp_target, "wb") as f:
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if not chunk:
                        continue
                    if hasher is not None:
                        hasher.update(chunk)
                    f.write(chunk)
        if hasher is not None and hasher.hexdigest() != info.sha256:
            raise ValueError(
                f"sha256 mismatch for {info.filename}: "
                f"expected {info.sha256}, got {hasher.hexdigest()}"
            )
        tmp_target.replace(target)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a partial file behind.
        tmp_target.unlink(missing_ok=True)
        raise
    return target
ArtifactInfo: + """Not supported for remote backends.""" + raise NotImplementedError("Remote backends are read-only") + + def _fetch_package_list(self) -> set[NormalizedName]: + """Fetch the top-level /simple/ index and extract package names.""" + url = f"{self._server_url}/" + try: + resp = session.get(url) + resp.raise_for_status() + except Exception as err: + logger.warning("failed to fetch remote index %s: %s", url, err) + return set() + + return self._parse_index_page(resp.text) + + def _fetch_project_page(self, package: NormalizedName) -> list[ArtifactInfo]: + """Fetch a project's page and extract wheel artifact info.""" + url = f"{self._server_url}/{package}/" + try: + resp = session.get(url) + resp.raise_for_status() + except Exception as err: + logger.debug("failed to fetch project page %s: %s", url, err) + return [] + + return self._parse_project_page(resp.text, url) + + @staticmethod + def _parse_index_page(html: str) -> set[NormalizedName]: + """Extract package names from a PEP 503 index page.""" + names: set[NormalizedName] = set() + for match in re.finditer(r'([^<]+)', html): + name = match.group(1).strip().rstrip("/") + if name: + names.add(canonicalize_name(name)) + return names + + @staticmethod + def _parse_project_page(html: str, base_url: str) -> list[ArtifactInfo]: + """Extract wheel artifact info from a PEP 503 project page.""" + artifacts: list[ArtifactInfo] = [] + for match in re.finditer(r']*>([^<]+)', html): + href = match.group(1) + filename = match.group(2).strip() + if not filename.endswith(".whl"): + continue + + # Reject filenames with path components to prevent directory traversal + safe_filename = pathlib.PurePosixPath(filename).name + if safe_filename != filename: + logger.warning( + "skipping remote artifact with unsafe filename %r", filename + ) + continue + filename = safe_filename + + # Resolve relative URLs + if href.startswith("http://") or href.startswith("https://"): + url = href + elif href.startswith("/"): + parsed = 
urlparse(base_url) + url = f"{parsed.scheme}://{parsed.netloc}{href}" + else: + url = base_url.rstrip("/") + "/" + href + + # Strip hash fragment for the URL but extract sha256 if present + sha256 = None + if "#" in url: + url_part, fragment = url.rsplit("#", 1) + if fragment.startswith("sha256="): + sha256 = fragment[7:] + url = url_part + + artifacts.append( + ArtifactInfo( + filename=filename, + url_or_path=url, + sha256=sha256, + ) + ) + return artifacts + + +# --------------------------------------------------------------------------- +# Collection and Store Router +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class CacheCollection: + """A named group of artifacts with one or more backends.""" + + name: str + backends: list[CacheBackend] + store_backend: LocalDirectoryBackend + + def scan_all(self) -> None: + """Scan all backends in this collection.""" + for backend in self.backends: + backend.scan() + + +class StoreRouter: + """Determines which collection receives a newly built artifact. + + Routing priority: + 1. Explicit per-package override (from overrides.yaml) + 2. Listed in the variant's requirements file (variant_packages) + 3. 
def route(self, req: Requirement) -> str:
    """Pick the collection that should receive a wheel built for ``req``.

    An explicit override wins, then membership in the variant package set,
    and finally the default collection.
    """
    package = canonicalize_name(req.name)
    if package in self._overrides:
        return self._overrides[package]
    return (
        self._active_variant
        if package in self._variant_packages
        else self._default_collection
    )
+ """ + for name in self._search_order: + if name not in self._collections: + logger.warning("collection %r in search order but not configured", name) + continue + self._collections[name].scan_all() + + def lookup_wheel( + self, + req: Requirement, + version: Version, + build_tag: BuildTag, + ) -> CacheResult: + """Search collections in priority order for a matching wheel. + + Returns the first hit found. On a remote hit, the wheel is + downloaded to the collection's local store backend. + """ + if self._force: + self._stats.record_miss(req, version, "forced") + return CacheResult(hit=False) + + key = WheelCacheKey( + package=canonicalize_name(req.name), + version=version, + build_tag=build_tag, + ) + + for collection_name in self._search_order: + collection = self._collections.get(collection_name) + if collection is None: + continue + + for backend in collection.backends: + t0 = time.monotonic() + info = backend.lookup(key) + if info is None: + continue + + # Hit -- fetch the artifact to a local path + try: + local_path = backend.fetch( + key, info, collection.store_backend.directory + ) + # Register in the local store index so subsequent lookups + # find it locally without hitting the remote again + if backend is not collection.store_backend: + collection.store_backend.store(key, local_path) + except Exception as err: + logger.warning( + "cache hit for %s==%s in %s/%s could not be fetched: %s", + req.name, + version, + collection_name, + backend.name, + err, + ) + continue + duration_ms = (time.monotonic() - t0) * 1000 + was_downloaded = not backend.writable + + self._stats.record_hit( + req, + version, + collection_name, + backend.name, + duration_ms=duration_ms, + ) + logger.info( + "cache hit for %s==%s in %s/%s", + req.name, + version, + collection_name, + backend.name, + ) + return CacheResult( + hit=True, + path=local_path, + collection=collection_name, + backend_name=backend.name, + build_tag=build_tag, + was_downloaded=was_downloaded, + ) + + 
self._stats.record_miss(req, version, "not_found") + logger.debug("cache miss for %s==%s", req.name, version) + return CacheResult(hit=False) + + def lookup_sdist( + self, + req: Requirement, + version: Version, + ) -> CacheResult: + """Search for a cached sdist across collections. + + Uses the same search order as wheel lookups. Sdist keys do not + include build tags. + """ + if self._force: + self._stats.record_miss(req, version, "forced", artifact_type="sdist") + return CacheResult(hit=False) + + # Sdist lookup reuses wheel key matching against .tar.gz/.zip files + # For now, delegate to a simple filename-based check in local backends + # TODO: extend backends with sdist-specific scan/lookup + self._stats.record_miss(req, version, "not_implemented", artifact_type="sdist") + return CacheResult(hit=False) + + def store_wheel( + self, + req: Requirement, + version: Version, + build_tag: BuildTag, + wheel_path: pathlib.Path, + ) -> pathlib.Path: + """Route and store a newly built wheel in the appropriate collection.""" + collection_name = self._store_routing.route(req) + collection = self._collections.get(collection_name) + if collection is None: + raise ValueError( + f"store routing returned unknown collection {collection_name!r} " + f"for {req.name}" + ) + + key = WheelCacheKey( + package=canonicalize_name(req.name), + version=version, + build_tag=build_tag, + ) + + info = collection.store_backend.store(key, wheel_path) + self._stats.record_store(req, version, collection_name) + logger.info("stored %s in collection %r", info.filename, collection_name) + return pathlib.Path(info.url_or_path) + + @property + def stats(self) -> CacheStats: + return self._stats + + @property + def collections(self) -> dict[str, CacheCollection]: + return self._collections + + @property + def search_order(self) -> list[str]: + return list(self._search_order) diff --git a/src/fromager/commands/bootstrap.py b/src/fromager/commands/bootstrap.py index 9baeb36c..e63795dd 100644 --- 
a/src/fromager/commands/bootstrap.py +++ b/src/fromager/commands/bootstrap.py @@ -22,6 +22,7 @@ ) from ..log import requirement_ctxvar from .build import build_parallel +from .cache_cmd import _build_cache_manager from .graph import find_why, show_explain_duplicates # Map child_name==child_version to list of (parent_name==parent_version, Requirement) @@ -117,6 +118,12 @@ def _get_requirements_from_args( default=None, help="Reject package versions published more than this many days ago.", ) +@click.option( + "--use-cache-manager/--no-cache-manager", + "use_cache_manager", + default=False, + help="Enable the new unified CacheManager for hierarchical cache lookups.", +) @click.argument("toplevel", nargs=-1) @click.pass_obj def bootstrap( @@ -129,6 +136,7 @@ def bootstrap( test_mode: bool, multiple_versions: bool, max_release_age: int | None, + use_cache_manager: bool, toplevel: list[str], ) -> None: """Compute and build the dependencies of a set of requirements recursively @@ -187,6 +195,17 @@ def bootstrap( if pre_built: logger.info("treating %s as pre-built wheels", sorted(pre_built)) + if use_cache_manager: + cache_mgr = _build_cache_manager( + wkctx, cache_url=cache_wheel_server_url, toplevel_reqs=to_build + ) + wkctx.cache = cache_mgr + logger.info( + "cache manager enabled with %d collection(s): %s", + len(cache_mgr.collections), + list(cache_mgr.collections.keys()), + ) + server.start_wheel_server(wkctx) with progress.progress_context(total=len(to_build * 2)) as progressbar: diff --git a/src/fromager/commands/cache_cmd.py b/src/fromager/commands/cache_cmd.py new file mode 100644 index 00000000..72c55c23 --- /dev/null +++ b/src/fromager/commands/cache_cmd.py @@ -0,0 +1,427 @@ +"""CLI commands for cache management and observability.""" + +import json +import logging +import pathlib + +import click +import rich +import rich.box +from packaging.requirements import Requirement +from packaging.utils import canonicalize_name +from rich.table import Table + +from 
def _build_cache_manager(
    wkctx: context.WorkContext,
    cache_url: str | None = None,
    toplevel_reqs: list[Requirement] | None = None,
) -> CacheManager:
    """Construct a CacheManager from the WorkContext configuration.

    If the context already has a cache configured, return it unchanged.
    Otherwise build one from the standard filesystem layout:

    - A shared "default" collection backed by the downloads directory (also
      its store target), the prebuilt directory and, when ``cache_url`` is
      given, a remote PEP 503 backend.
    - When the active variant is not "cpu" and top-level requirements are
      provided, a variant collection under ``wheels_repo/variants/<variant>``.
      It is searched before "default" and receives wheels built for the
      listed packages.

    Variant routing is enabled only when the variant collection exists.
    Otherwise every wheel is routed to "default", because the router would
    name a collection the manager does not know and ``store_wheel`` would
    raise ``ValueError``.

    Args:
        wkctx: The work context providing local paths and variant info.
        cache_url: Optional URL to a remote PEP 503 cache server.
        toplevel_reqs: Top-level requirements from the variant's requirements
            file. These define which packages belong to the variant collection.

    Returns:
        The context's existing manager, or a new one that has already been
        initialized (all backends scanned).
    """
    if wkctx.cache is not None:
        return wkctx.cache

    # Shared (default) collection: downloads + prebuilt + optional remote
    shared_backend = LocalDirectoryBackend(
        wkctx.wheels_downloads, backend_name="local:downloads"
    )
    prebuilt_backend = LocalDirectoryBackend(
        wkctx.wheels_prebuilt, backend_name="local:prebuilt"
    )

    shared_backends: list[CacheBackend] = [shared_backend, prebuilt_backend]

    if cache_url:
        remote_backend = RemotePEP503Backend(
            server_url=cache_url,
            download_dir=wkctx.wheels_downloads,
            backend_name=f"remote:{cache_url}",
        )
        shared_backends.append(remote_backend)

    default_collection = CacheCollection(
        name="default",
        backends=shared_backends,
        store_backend=shared_backend,
    )

    collections: dict[str, CacheCollection] = {"default": default_collection}
    search_order: list[str] = ["default"]

    # Variant-specific collection for packages listed in the requirements file
    variant_packages = {canonicalize_name(r.name) for r in (toplevel_reqs or [])}
    use_variant_collection = bool(variant_packages) and wkctx.variant != "cpu"

    if use_variant_collection:
        variant_dir = wkctx.wheels_repo / "variants" / wkctx.variant
        variant_dir.mkdir(parents=True, exist_ok=True)
        variant_backend = LocalDirectoryBackend(
            variant_dir, backend_name=f"local:{wkctx.variant}"
        )

        variant_collection = CacheCollection(
            name=wkctx.variant,
            backends=[variant_backend],
            store_backend=variant_backend,
        )
        collections[wkctx.variant] = variant_collection
        search_order = [wkctx.variant, "default"]

    # Hand the listed packages to the router only when their collection was
    # actually created. Without it, route() must fall back to "default".
    router = StoreRouter(
        overrides={},
        variant_packages=variant_packages if use_variant_collection else set(),
        active_variant=wkctx.variant,
    )

    manager = CacheManager(
        collections=collections,
        search_order=search_order,
        store_routing=router,
    )
    manager.initialize()
    return manager
collection.", +) +@click.option( + "--format", + "output_format", + type=click.Choice(["table", "json"], case_sensitive=False), + default="table", + help="Output format (default: table).", +) +@click.pass_obj +def cache_list( + wkctx: context.WorkContext, + collection: str | None, + output_format: str, +) -> None: + """List all cached wheel artifacts.""" + manager = _build_cache_manager(wkctx) + + entries = [] + for coll_name, coll in manager.collections.items(): + if collection and coll_name != collection: + continue + for backend in coll.backends: + if not hasattr(backend, "_index"): + continue + for key, info in backend._index.items(): + entries.append( + { + "collection": coll_name, + "backend": backend.name, + "package": str(key.package), + "version": str(key.version), + "build_tag": ( + f"{key.build_tag[0]}{key.build_tag[1]}" + if key.build_tag + else "" + ), + "filename": info.filename, + "size_bytes": info.size_bytes or 0, + } + ) + + entries.sort(key=lambda e: (e["collection"], e["package"], e["version"])) + + if output_format == "json": + click.echo(json.dumps(entries, indent=2)) + return + + if not entries: + click.echo("No cached wheels found.") + return + + table = Table(title="Cached Wheels", box=rich.box.SIMPLE) + table.add_column("Collection", no_wrap=True) + table.add_column("Package", no_wrap=True) + table.add_column("Version", no_wrap=True) + table.add_column("Build Tag", no_wrap=True) + table.add_column("Size", justify="right", no_wrap=True) + table.add_column("Backend", no_wrap=True) + + for entry in entries: + size = _format_size(entry["size_bytes"]) + table.add_row( + entry["collection"], + entry["package"], + entry["version"], + entry["build_tag"], + size, + entry["backend"], + ) + + console = rich.get_console() + console.print(table) + console.print(f"\nTotal: {len(entries)} wheel(s)") + + +@cache.command(name="stats") +@click.pass_obj +def cache_stats(wkctx: context.WorkContext) -> None: + """Show cache statistics (hit/miss rates from last 
run).""" + manager = _build_cache_manager(wkctx) + + table = Table(title="Cache Statistics", box=rich.box.SIMPLE) + table.add_column("Metric", no_wrap=True) + table.add_column("Value", justify="right", no_wrap=True) + + summary = manager.stats.summary() + + table.add_row("Total lookups", str(summary["hits"]["total"] + summary["misses"])) + table.add_row("Hits", str(summary["hits"]["total"])) + table.add_row("Misses", str(summary["misses"])) + table.add_row("Hit rate", f"{summary['hit_rate']:.1%}") + table.add_row("Stores", str(summary["stores"])) + + if summary["hits"]["by_collection"]: + table.add_section() + for coll, count in summary["hits"]["by_collection"].items(): + table.add_row(f" Hits from {coll}", str(count)) + + # Show per-collection inventory counts + table.add_section() + total_wheels = 0 + total_size = 0 + for coll_name, coll in manager.collections.items(): + coll_count = 0 + coll_size = 0 + for backend in coll.backends: + if hasattr(backend, "_index"): + coll_count += len(backend._index) + coll_size += sum( + (info.size_bytes or 0) for info in backend._index.values() + ) + table.add_row(f" {coll_name} wheels", str(coll_count)) + table.add_row(f" {coll_name} size", _format_size(coll_size)) + total_wheels += coll_count + total_size += coll_size + + table.add_section() + table.add_row("Total wheels on disk", str(total_wheels)) + table.add_row("Total size on disk", _format_size(total_size)) + + console = rich.get_console() + console.print(table) + + +@cache.command() +@click.option( + "--remove-missing", + is_flag=True, + default=False, + help="Remove index entries for files that no longer exist on disk.", +) +@click.pass_obj +def verify(wkctx: context.WorkContext, remove_missing: bool) -> None: + """Verify cache integrity: check that indexed files exist on disk.""" + manager = _build_cache_manager(wkctx) + + missing = [] + checked = 0 + + for coll_name, coll in manager.collections.items(): + for backend in coll.backends: + if not isinstance(backend, 
LocalDirectoryBackend): + continue + for key, info in list(backend._index.items()): + checked += 1 + file_path = pathlib.Path(info.url_or_path) + if not file_path.exists(): + missing.append( + { + "collection": coll_name, + "backend": backend.name, + "key": str(key), + "path": str(file_path), + } + ) + if remove_missing: + del backend._index[key] + + if not missing: + click.echo(f"All {checked} cached artifacts verified OK.") + return + + click.echo(f"Found {len(missing)} missing artifact(s) out of {checked} checked:") + for m in missing: + action = " [removed from index]" if remove_missing else "" + click.echo(f" {m['collection']}/{m['key']}: {m['path']}{action}") + + +@cache.command() +@click.argument("packages", nargs=-1) +@click.option( + "--all", + "invalidate_all", + is_flag=True, + default=False, + help="Invalidate the entire cache.", +) +@click.option( + "--collection", + default=None, + help="Only invalidate within this collection.", +) +@click.pass_obj +def invalidate( + wkctx: context.WorkContext, + packages: tuple[str, ...], + invalidate_all: bool, + collection: str | None, +) -> None: + """Invalidate (remove) cached wheels for specific packages. + + Pass package names as arguments, or use --all to clear everything. 
+ """ + if not packages and not invalidate_all: + raise click.UsageError("Specify package names or use --all.") + + manager = _build_cache_manager(wkctx) + removed = 0 + + target_packages = {canonicalize_name(p) for p in packages} if packages else None + + for coll_name, coll in manager.collections.items(): + if collection and coll_name != collection: + continue + for backend in coll.backends: + if not isinstance(backend, LocalDirectoryBackend): + continue + keys_to_remove = [] + for key, info in backend._index.items(): + if target_packages and key.package not in target_packages: + continue + keys_to_remove.append((key, info)) + + for key, info in keys_to_remove: + file_path = pathlib.Path(info.url_or_path) + if file_path.exists(): + file_path.unlink() + logger.info("removed %s", file_path) + del backend._index[key] + removed += 1 + + click.echo(f"Invalidated {removed} cached artifact(s).") + + +@cache.command() +@click.option( + "--dry-run", + is_flag=True, + default=False, + help="Show what would be removed without actually deleting.", +) +@click.option( + "--keep-latest", + type=int, + default=1, + help="Keep this many build tags per package+version (default: 1).", +) +@click.pass_obj +def gc( + wkctx: context.WorkContext, + dry_run: bool, + keep_latest: int, +) -> None: + """Garbage-collect old builds, keeping only the latest build tags. + + For each package+version, removes all but the --keep-latest most + recent builds (highest build tag number). 
+ """ + manager = _build_cache_manager(wkctx) + removed = 0 + freed_bytes = 0 + + for _coll_name, coll in manager.collections.items(): + for backend in coll.backends: + if not isinstance(backend, LocalDirectoryBackend): + continue + + # Group by (package, version) + groups: dict[tuple, list] = {} + for key, info in backend._index.items(): + group_key = (key.package, key.version) + groups.setdefault(group_key, []).append((key, info)) + + for _group_key, entries in groups.items(): + if len(entries) <= keep_latest: + continue + + # Sort by build tag number descending + entries.sort( + key=lambda e: e[0].build_tag[0] if e[0].build_tag else 0, + reverse=True, + ) + to_remove = entries[keep_latest:] + + for key, info in to_remove: + file_path = pathlib.Path(info.url_or_path) + size = info.size_bytes or 0 + if dry_run: + click.echo( + f" would remove: {info.filename} ({_format_size(size)})" + ) + else: + if file_path.exists(): + file_path.unlink() + del backend._index[key] + logger.info("gc removed %s", file_path) + removed += 1 + freed_bytes += size + + verb = "Would remove" if dry_run else "Removed" + click.echo(f"{verb} {removed} old build(s), freeing {_format_size(freed_bytes)}.") + + +def _format_size(size_bytes: int) -> str: + """Format bytes as human-readable size.""" + if size_bytes == 0: + return "0 B" + for unit in ("B", "KB", "MB", "GB"): + if abs(size_bytes) < 1024: + return f"{size_bytes:.1f} {unit}" + size_bytes /= 1024 # type: ignore[assignment] + return f"{size_bytes:.1f} TB" diff --git a/src/fromager/context.py b/src/fromager/context.py index 58866aed..45c45a57 100644 --- a/src/fromager/context.py +++ b/src/fromager/context.py @@ -22,7 +22,7 @@ ) if typing.TYPE_CHECKING: - from . import build_environment, candidate + from . 
import build_environment, cache, candidate logger = logging.getLogger(__name__) @@ -98,6 +98,8 @@ def __init__( self.cooldown: candidate.Cooldown | None = cooldown self._max_release_age: datetime.timedelta | None = max_release_age + self._cache: cache.CacheManager | None = None + @property def max_release_age(self) -> datetime.timedelta | None: return self._max_release_age @@ -108,6 +110,15 @@ def set_max_release_age(self, days: int) -> None: raise ValueError(f"max_release_age must be positive, got {days}") self._max_release_age = datetime.timedelta(days=days) + @property + def cache(self) -> cache.CacheManager | None: + """The cache manager, if configured.""" + return self._cache + + @cache.setter + def cache(self, value: cache.CacheManager) -> None: + self._cache = value + def enable_parallel_builds(self) -> None: self._parallel_builds = True diff --git a/src/fromager/requirements_file.py b/src/fromager/requirements_file.py index a588d20a..02f9f4f0 100644 --- a/src/fromager/requirements_file.py +++ b/src/fromager/requirements_file.py @@ -34,6 +34,7 @@ class SourceType(StrEnum): SDIST = "sdist" OVERRIDE = "override" GIT = "git" + CACHED = "cached" def parse_requirements_file( diff --git a/tests/test_bootstrapper_iterative.py b/tests/test_bootstrapper_iterative.py index 27692dc9..6417a990 100644 --- a/tests/test_bootstrapper_iterative.py +++ b/tests/test_bootstrapper_iterative.py @@ -1126,6 +1126,347 @@ def test_constraint_logged_when_present( assert "matches constraint" in caplog.text +class TestPhasePrepareSourceCacheManager: + """Tests for _phase_prepare_source using the CacheManager short-circuit path. + + When a CacheManager is configured on WorkContext and provides a cache hit, + PREPARE_SOURCE should skip directly to PROCESS_INSTALL_DEPS without + downloading source, creating a build environment, or resolving build deps. 
+ """ + + def test_cache_hit_short_circuits_to_process_install_deps( + self, tmp_context: WorkContext + ) -> None: + """Cache hit via CacheManager skips all build phases.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + # Set up a CacheManager with a wheel in the default collection + wheels_dir = tmp_context.wheels_downloads + wheel_file = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + wheel_file.write_bytes(b"fake wheel") + + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + item = _make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object( + bt, + "_find_cached_wheel_via_manager", + return_value=(wheel_file, tmp_context.work_dir / "testpkg-1.0"), + ), + patch.object( + bt, + "_create_unpack_dir", + return_value=tmp_context.work_dir / "testpkg-1.0", + ), + patch.object(bt, "_download_source") as mock_dl_src, + patch.object(bt, "_prepare_source") as mock_prep, + patch.object(bt, "_create_build_env") as mock_create_env, + patch("fromager.bootstrapper.server.update_wheel_mirror") as mock_mirror, + ): + result = bt._phase_prepare_source(item) + + # Short-circuit: jumped to PROCESS_INSTALL_DEPS + assert item.phase == BootstrapPhase.PROCESS_INSTALL_DEPS + assert item.build_result is not None + assert item.build_result.wheel_filename == wheel_file + assert item.build_result.source_type == SourceType.CACHED + assert item.build_result.build_env is None + assert 
item.build_result.sdist_filename is None + + # Nothing from the build path was called + mock_dl_src.assert_not_called() + mock_prep.assert_not_called() + mock_create_env.assert_not_called() + + # Wheel mirror was updated so cached wheel is indexed for build deps + mock_mirror.assert_called_once_with(tmp_context) + + # Only the continuation item, no build dep items + assert len(result) == 1 + assert result[0] is item + + def test_cache_hit_updates_wheel_mirror_for_build_dep_resolution( + self, tmp_context: WorkContext + ) -> None: + """Cache hit calls update_wheel_mirror so downloaded wheels are available + to subsequent build environments via the internal wheel server.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + wheel_file = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + wheel_file.write_bytes(b"fake wheel") + + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + item = _make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object( + bt, + "_find_cached_wheel_via_manager", + return_value=(wheel_file, tmp_context.work_dir / "testpkg-1.0"), + ), + patch.object( + bt, + "_create_unpack_dir", + return_value=tmp_context.work_dir / "testpkg-1.0", + ), + patch("fromager.bootstrapper.server.update_wheel_mirror") as mock_mirror, + ): + bt._phase_prepare_source(item) + + # update_wheel_mirror called before returning, ensuring the 
wheel is + # symlinked into the simple/ index for uv pip install to find it + mock_mirror.assert_called_once_with(tmp_context) + + def test_cache_miss_with_manager_falls_through_to_build( + self, tmp_context: WorkContext + ) -> None: + """Cache miss via CacheManager proceeds to normal build path.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + # Empty cache — no wheels + wheels_dir = tmp_context.wheels_downloads + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + item = _make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + sdist_root = tmp_context.work_dir / "testpkg-1.0" / "testpkg-1.0" + mock_env = Mock() + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object( + bt, "_download_source", return_value=tmp_context.work_dir / "src.tar.gz" + ) as mock_dl_src, + patch.object(bt, "_prepare_source", return_value=sdist_root) as mock_prep, + patch.object(bt, "_create_build_env", return_value=mock_env), + patch( + "fromager.dependencies.get_build_system_dependencies", + return_value=set(), + ), + patch.object(bt, "_create_unresolved_work_items", return_value=[]), + ): + bt._phase_prepare_source(item) + + # Normal build path — advances to PREPARE_BUILD + assert item.phase == BootstrapPhase.PREPARE_BUILD + assert item.build_env is mock_env + mock_dl_src.assert_called_once() + mock_prep.assert_called_once() + + def test_cache_hit_records_stats(self, tmp_context: WorkContext) -> None: + """Cache hit via CacheManager records 
a hit event in stats.""" + + from unittest.mock import MagicMock + + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + wheel_file = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + wheel_file.write_bytes(b"fake wheel") + + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + item = _make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + mock_pbi = MagicMock() + mock_pbi.build_tag.return_value = (1, "") + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object(tmp_context, "package_build_info", return_value=mock_pbi), + patch.object( + bt, + "_unpack_metadata_from_wheel", + return_value=tmp_context.work_dir / "testpkg-1.0", + ), + patch.object( + bt, + "_create_unpack_dir", + return_value=tmp_context.work_dir / "testpkg-1.0", + ), + patch("fromager.bootstrapper.server.update_wheel_mirror"), + ): + bt._phase_prepare_source(item) + + assert item.phase == BootstrapPhase.PROCESS_INSTALL_DEPS + # Stats recorded the hit through the real lookup_wheel path + summary = cache_mgr.stats.summary() + assert summary["hits"]["total"] == 1 + assert summary["hits"]["by_collection"]["default"] == 1 + + def test_no_cache_manager_uses_legacy_path(self, tmp_context: WorkContext) -> None: + """Without CacheManager, legacy path is used (no short-circuit).""" + # Explicitly no cache manager + assert tmp_context.cache is None + + bt = bootstrapper.Bootstrapper(tmp_context) + item = 
_make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + unpacked = tmp_context.work_dir / "testpkg-1.0" + unpacked.mkdir(parents=True) + cached_wheel = tmp_context.work_dir / "testpkg-1.0-py3-none-any.whl" + mock_env = Mock() + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object( + bt, "_find_cached_wheel", return_value=(cached_wheel, unpacked) + ), + patch.object(bt, "_download_source") as mock_dl_src, + patch.object(bt, "_prepare_source") as mock_prep, + patch.object(bt, "_create_build_env", return_value=mock_env), + patch( + "fromager.dependencies.get_build_system_dependencies", + return_value=set(), + ), + patch.object(bt, "_create_unresolved_work_items", return_value=[]), + ): + bt._phase_prepare_source(item) + + # Legacy path: cached wheel found but no short-circuit — goes to PREPARE_BUILD + assert item.phase == BootstrapPhase.PREPARE_BUILD + assert item.cached_wheel_filename == cached_wheel + mock_dl_src.assert_not_called() + mock_prep.assert_not_called() + + def test_force_mode_skips_cache(self, tmp_context: WorkContext) -> None: + """--force flag causes CacheManager to return miss, triggering full build.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + # Put a wheel in the cache + wheel_file = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + wheel_file.write_bytes(b"fake wheel") + + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + force=True, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + item = 
_make_build_item(phase=BootstrapPhase.PREPARE_SOURCE) + + sdist_root = tmp_context.work_dir / "testpkg-1.0" / "testpkg-1.0" + mock_env = Mock() + + with ( + patch.object(tmp_context.constraints, "get_constraint", return_value=None), + patch.object( + bt, "_download_source", return_value=tmp_context.work_dir / "src.tar.gz" + ) as mock_dl_src, + patch.object(bt, "_prepare_source", return_value=sdist_root), + patch.object(bt, "_create_build_env", return_value=mock_env), + patch( + "fromager.dependencies.get_build_system_dependencies", + return_value=set(), + ), + patch.object(bt, "_create_unresolved_work_items", return_value=[]), + ): + bt._phase_prepare_source(item) + + # Force mode: goes through full build path despite wheel existing + assert item.phase == BootstrapPhase.PREPARE_BUILD + mock_dl_src.assert_called_once() + + class TestPhasePrepareBuild: """Tests for _phase_prepare_build: dep installation and extraction.""" @@ -1455,6 +1796,187 @@ def test_returns_single_item_at_process_install_deps( assert result[0] is item assert item.phase == BootstrapPhase.PROCESS_INSTALL_DEPS + def test_store_wheel_called_for_freshly_built_wheel( + self, tmp_context: WorkContext + ) -> None: + """When CacheManager is active and a wheel is built (not cached), + store_wheel routes it to the appropriate collection.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + variant_dir = tmp_context.wheels_repo / "variants" / "gaudi-ubi9" + variant_dir.mkdir(parents=True) + + shared_backend = LocalDirectoryBackend(wheels_dir, backend_name="local:shared") + variant_backend = LocalDirectoryBackend( + variant_dir, backend_name="local:gaudi-ubi9" + ) + + default_coll = CacheCollection( + name="default", backends=[shared_backend], store_backend=shared_backend + ) + variant_coll = CacheCollection( + name="gaudi-ubi9", + backends=[variant_backend], + store_backend=variant_backend, + ) + + 
router = StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("testpkg")}, + active_variant="gaudi-ubi9", + ) + cache_mgr = CacheManager( + collections={"default": default_coll, "gaudi-ubi9": variant_coll}, + search_order=["gaudi-ubi9", "default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + mock_env = Mock() + sdist_root = tmp_context.work_dir / "testpkg-1.0" / "testpkg-1.0" + sdist_root.parent.mkdir(parents=True, exist_ok=True) + item = _make_build_item( + phase=BootstrapPhase.BUILD, + build_env=mock_env, + sdist_root_dir=sdist_root, + ) + + built_wheel = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + built_wheel.write_bytes(b"fake built wheel") + + with ( + patch.object(bt, "_do_build", return_value=(built_wheel, None)), + patch("fromager.sources.get_source_type", return_value=SourceType.SDIST), + ): + bt._phase_build(item) + + # Wheel was copied to the variant collection directory + assert (variant_dir / "testpkg-1.0-1-py3-none-any.whl").exists() + # Original in downloads preserved for internal wheel server + assert built_wheel.exists() + + def test_store_wheel_routes_unlisted_to_default( + self, tmp_context: WorkContext + ) -> None: + """Unlisted packages (not accelerated) are stored in the default collection.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + variant_dir = tmp_context.wheels_repo / "variants" / "gaudi-ubi9" + variant_dir.mkdir(parents=True) + + shared_backend = LocalDirectoryBackend(wheels_dir, backend_name="local:shared") + variant_backend = LocalDirectoryBackend( + variant_dir, backend_name="local:gaudi-ubi9" + ) + + default_coll = CacheCollection( + name="default", backends=[shared_backend], store_backend=shared_backend + ) + variant_coll = CacheCollection( + name="gaudi-ubi9", + backends=[variant_backend], + 
store_backend=variant_backend, + ) + + # testpkg is NOT in accelerated_packages + router = StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("torch")}, + active_variant="gaudi-ubi9", + ) + cache_mgr = CacheManager( + collections={"default": default_coll, "gaudi-ubi9": variant_coll}, + search_order=["gaudi-ubi9", "default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + mock_env = Mock() + sdist_root = tmp_context.work_dir / "testpkg-1.0" / "testpkg-1.0" + sdist_root.parent.mkdir(parents=True, exist_ok=True) + item = _make_build_item( + phase=BootstrapPhase.BUILD, + build_env=mock_env, + sdist_root_dir=sdist_root, + ) + + built_wheel = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + built_wheel.write_bytes(b"fake built wheel") + + with ( + patch.object(bt, "_do_build", return_value=(built_wheel, None)), + patch("fromager.sources.get_source_type", return_value=SourceType.SDIST), + ): + bt._phase_build(item) + + # Wheel stays in downloads (default collection dir == downloads) + assert built_wheel.exists() + # NOT copied to variant dir + assert not (variant_dir / "testpkg-1.0-1-py3-none-any.whl").exists() + + def test_store_wheel_skipped_when_cached(self, tmp_context: WorkContext) -> None: + """When a cached wheel is used (not freshly built), store_wheel is not called.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + shared_backend = LocalDirectoryBackend(wheels_dir, backend_name="local:shared") + default_coll = CacheCollection( + name="default", backends=[shared_backend], store_backend=shared_backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": default_coll}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + 
tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + mock_env = Mock() + sdist_root = tmp_context.work_dir / "testpkg-1.0" / "testpkg-1.0" + sdist_root.parent.mkdir(parents=True, exist_ok=True) + cached_wheel = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + cached_wheel.write_bytes(b"cached wheel") + item = _make_build_item( + phase=BootstrapPhase.BUILD, + build_env=mock_env, + sdist_root_dir=sdist_root, + cached_wheel_filename=cached_wheel, + ) + + with ( + patch.object(bt, "_do_build", return_value=(cached_wheel, None)), + patch("fromager.sources.get_source_type", return_value=SourceType.SDIST), + patch.object(cache_mgr, "store_wheel") as mock_store, + ): + bt._phase_build(item) + + mock_store.assert_not_called() + class TestPhaseProcessInstallDeps: """Tests for _phase_process_install_deps: hooks, dep extraction, error modes.""" @@ -1672,3 +2194,155 @@ def test_build_order_called_with_correct_args( prebuilt=True, constraint=constraint, ) + + +class TestFindCachedWheelDispatch: + """Tests for _find_cached_wheel dispatch to manager vs legacy.""" + + def test_dispatches_to_manager_when_configured( + self, tmp_context: WorkContext + ) -> None: + """With CacheManager set, _find_cached_wheel calls via_manager.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + req = Requirement("testpkg") + version = Version("1.0") + + with ( + patch.object( + bt, 
"_find_cached_wheel_via_manager", return_value=(None, None) + ) as mock_mgr, + patch.object(bt, "_find_cached_wheel_legacy") as mock_legacy, + ): + bt._find_cached_wheel(req, version) + + mock_mgr.assert_called_once_with(req, version) + mock_legacy.assert_not_called() + + def test_dispatches_to_legacy_when_no_manager( + self, tmp_context: WorkContext + ) -> None: + """Without CacheManager, _find_cached_wheel calls legacy.""" + assert tmp_context.cache is None + + bt = bootstrapper.Bootstrapper(tmp_context) + req = Requirement("testpkg") + version = Version("1.0") + + with ( + patch.object(bt, "_find_cached_wheel_via_manager") as mock_mgr, + patch.object( + bt, "_find_cached_wheel_legacy", return_value=(None, None) + ) as mock_legacy, + ): + bt._find_cached_wheel(req, version) + + mock_legacy.assert_called_once_with(req, version) + mock_mgr.assert_not_called() + + def test_via_manager_returns_hit_path(self, tmp_context: WorkContext) -> None: + """_find_cached_wheel_via_manager returns path on cache hit.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + wheels_dir = tmp_context.wheels_downloads + wheel_file = wheels_dir / "testpkg-1.0-1-py3-none-any.whl" + wheel_file.write_bytes(b"fake wheel content") + + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + req = Requirement("testpkg") + version = Version("1.0") + + # build_tag must be a BuildTag tuple matching what parse_wheel_filename produces + mock_pbi = Mock() + mock_pbi.build_tag.return_value = (1, "") + 
metadata_dir = tmp_context.work_dir / "testpkg-1.0.dist-info" + + with ( + patch.object(tmp_context, "package_build_info", return_value=mock_pbi), + patch.object(bt, "_unpack_metadata_from_wheel", return_value=metadata_dir), + ): + wheel_path, meta_dir = bt._find_cached_wheel_via_manager(req, version) + + assert wheel_path == wheel_file + assert meta_dir == metadata_dir + + def test_via_manager_returns_none_on_miss(self, tmp_context: WorkContext) -> None: + """_find_cached_wheel_via_manager returns (None, None) on cache miss.""" + from fromager.cache import ( + CacheCollection, + CacheManager, + LocalDirectoryBackend, + StoreRouter, + ) + + # Empty cache + wheels_dir = tmp_context.wheels_downloads + backend = LocalDirectoryBackend(wheels_dir, backend_name="local:default") + collection = CacheCollection( + name="default", backends=[backend], store_backend=backend + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + cache_mgr = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + cache_mgr.initialize() + tmp_context.cache = cache_mgr + + bt = bootstrapper.Bootstrapper(tmp_context) + req = Requirement("testpkg") + version = Version("1.0") + + mock_pbi = Mock() + mock_pbi.build_tag.return_value = (1, "") + + with patch.object(tmp_context, "package_build_info", return_value=mock_pbi): + wheel_path, meta_dir = bt._find_cached_wheel_via_manager(req, version) + + assert wheel_path is None + assert meta_dir is None diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 00000000..201d8ca8 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,1276 @@ +"""Unit tests for the cache subsystem.""" + +import pathlib + +import pytest +import requests_mock +from packaging.requirements import Requirement +from packaging.utils import canonicalize_name +from packaging.version import Version + +from fromager.cache import ( + CacheCollection, + CacheManager, + 
CacheResult, + CacheStats, + LocalDirectoryBackend, + RemotePEP503Backend, + SdistCacheKey, + StoreRouter, + WheelCacheKey, +) + +# --------------------------------------------------------------------------- +# WheelCacheKey tests +# --------------------------------------------------------------------------- + + +class TestWheelCacheKey: + def test_creation(self) -> None: + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert key.package == "numpy" + assert key.version == Version("1.26.4") + assert key.build_tag == (2, "") + + def test_equality(self) -> None: + key1 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + key2 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert key1 == key2 + + def test_inequality_version(self) -> None: + key1 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + key2 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.5"), + build_tag=(2, ""), + ) + assert key1 != key2 + + def test_inequality_build_tag(self) -> None: + key1 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + key2 = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(3, ""), + ) + assert key1 != key2 + + def test_hashable(self) -> None: + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + d: dict[WheelCacheKey, str] = {key: "found"} + assert d[key] == "found" + + def test_str_with_build_tag(self) -> None: + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert str(key) == "numpy==1.26.4-2" + + def test_str_without_build_tag(self) -> None: + key = WheelCacheKey( + 
package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(), + ) + assert str(key) == "numpy==1.26.4" + + def test_name_normalization(self) -> None: + key1 = WheelCacheKey( + package=canonicalize_name("Flask-RESTful"), + version=Version("0.3.10"), + build_tag=(), + ) + key2 = WheelCacheKey( + package=canonicalize_name("flask_restful"), + version=Version("0.3.10"), + build_tag=(), + ) + assert key1 == key2 + + +class TestSdistCacheKey: + def test_creation(self) -> None: + key = SdistCacheKey( + package=canonicalize_name("requests"), + version=Version("2.31.0"), + ) + assert key.package == "requests" + assert str(key) == "requests==2.31.0" + + +# --------------------------------------------------------------------------- +# LocalDirectoryBackend tests +# --------------------------------------------------------------------------- + + +def _create_wheel_file(directory: pathlib.Path, filename: str) -> pathlib.Path: + """Create a fake wheel file for testing.""" + directory.mkdir(parents=True, exist_ok=True) + wheel_path = directory / filename + wheel_path.write_bytes(b"fake wheel content") + return wheel_path + + +class TestLocalDirectoryBackend: + def test_name(self, tmp_path: pathlib.Path) -> None: + backend = LocalDirectoryBackend( + tmp_path / "wheels", backend_name="local:default" + ) + assert backend.name == "local:default" + + def test_writable(self, tmp_path: pathlib.Path) -> None: + backend = LocalDirectoryBackend(tmp_path / "wheels") + assert backend.writable is True + + def test_scan_empty_directory(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + backend = LocalDirectoryBackend(wheels_dir) + result = backend.scan() + assert result == {} + + def test_scan_nonexistent_directory(self, tmp_path: pathlib.Path) -> None: + backend = LocalDirectoryBackend(tmp_path / "nonexistent") + result = backend.scan() + assert result == {} + + def test_scan_finds_wheels(self, tmp_path: pathlib.Path) -> None: + 
wheels_dir = tmp_path / "wheels" + _create_wheel_file(wheels_dir, "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl") + backend = LocalDirectoryBackend(wheels_dir) + result = backend.scan() + + expected_key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert expected_key in result + assert ( + result[expected_key].filename + == "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + ) + + def test_scan_skips_non_wheel_files(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + (wheels_dir / "readme.txt").write_text("not a wheel") + (wheels_dir / "numpy-1.26.4.tar.gz").write_bytes(b"sdist") + backend = LocalDirectoryBackend(wheels_dir) + result = backend.scan() + assert result == {} + + def test_scan_skips_unparseable_wheels(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + (wheels_dir / "totally-invalid-name.whl").write_bytes(b"bad") + backend = LocalDirectoryBackend(wheels_dir) + result = backend.scan() + assert result == {} + + def test_lookup_hit(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + _create_wheel_file(wheels_dir, "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl") + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + result = backend.lookup(key) + assert result is not None + assert result.filename == "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + + def test_lookup_miss(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert backend.lookup(key) is None + + def test_lookup_evicts_deleted_file(self, tmp_path: pathlib.Path) -> None: + 
wheels_dir = tmp_path / "wheels" + whl = _create_wheel_file( + wheels_dir, "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + ) + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + # File exists initially + assert backend.lookup(key) is not None + # Delete it + whl.unlink() + # Now lookup should return None and evict from index + assert backend.lookup(key) is None + + def test_fetch_returns_local_path(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + whl = _create_wheel_file( + wheels_dir, "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + ) + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + info = backend.lookup(key) + assert info is not None + result = backend.fetch(key, info, tmp_path / "dest") + assert result == whl.resolve() + + def test_store_copies_file(self, tmp_path: pathlib.Path) -> None: + """Store copies the wheel to the collection directory, preserving the original.""" + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + # Create a wheel in a "build" directory + build_dir = tmp_path / "build" + whl = _create_wheel_file(build_dir, "requests-2.31.0-1-py3-none-any.whl") + + key = WheelCacheKey( + package=canonicalize_name("requests"), + version=Version("2.31.0"), + build_tag=(1, ""), + ) + info = backend.store(key, whl) + + assert info.filename == "requests-2.31.0-1-py3-none-any.whl" + assert (wheels_dir / "requests-2.31.0-1-py3-none-any.whl").exists() + assert whl.exists() # Original preserved for internal wheel server + + def test_store_updates_index(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + wheels_dir.mkdir() + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + 
build_dir = tmp_path / "build" + whl = _create_wheel_file(build_dir, "requests-2.31.0-1-py3-none-any.whl") + + key = WheelCacheKey( + package=canonicalize_name("requests"), + version=Version("2.31.0"), + build_tag=(1, ""), + ) + backend.store(key, whl) + + # Should be findable via lookup now + result = backend.lookup(key) + assert result is not None + assert result.filename == "requests-2.31.0-1-py3-none-any.whl" + + def test_store_no_move_if_already_exists(self, tmp_path: pathlib.Path) -> None: + wheels_dir = tmp_path / "wheels" + existing = _create_wheel_file(wheels_dir, "requests-2.31.0-1-py3-none-any.whl") + backend = LocalDirectoryBackend(wheels_dir) + backend.scan() + + key = WheelCacheKey( + package=canonicalize_name("requests"), + version=Version("2.31.0"), + build_tag=(1, ""), + ) + # Store with same filename that already exists + info = backend.store(key, existing) + assert info.filename == "requests-2.31.0-1-py3-none-any.whl" + assert existing.exists() + + +# --------------------------------------------------------------------------- +# RemotePEP503Backend tests +# --------------------------------------------------------------------------- + + +class TestRemotePEP503Backend: + def test_name_default(self) -> None: + backend = RemotePEP503Backend( + server_url="https://cache.example.com/simple", + download_dir=pathlib.Path("/tmp/downloads"), + ) + assert backend.name == "remote:https://cache.example.com/simple" + + def test_not_writable(self) -> None: + backend = RemotePEP503Backend( + server_url="https://cache.example.com/simple", + download_dir=pathlib.Path("/tmp/downloads"), + ) + assert backend.writable is False + + def test_store_raises(self) -> None: + backend = RemotePEP503Backend( + server_url="https://cache.example.com/simple", + download_dir=pathlib.Path("/tmp/downloads"), + ) + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + with pytest.raises(NotImplementedError): + 
backend.store(key, pathlib.Path("/fake.whl")) + + def test_parse_index_page(self) -> None: + html = """ + + + numpy + requests + Flask + + """ + result = RemotePEP503Backend._parse_index_page(html) + assert canonicalize_name("numpy") in result + assert canonicalize_name("requests") in result + assert canonicalize_name("flask") in result + + def test_parse_project_page(self) -> None: + html = """ + + + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + numpy-1.26.4.tar.gz + numpy-1.25.0-cp312-cp312-linux_x86_64.whl + + """ + result = RemotePEP503Backend._parse_project_page( + html, "https://cache.test/simple/numpy/" + ) + assert len(result) == 2 # Only .whl files + assert result[0].filename == "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + assert result[0].sha256 == "abc123" + assert result[1].filename == "numpy-1.25.0-cp312-cp312-linux_x86_64.whl" + assert "cache.test" in result[1].url_or_path + + def test_parse_project_page_absolute_url(self) -> None: + html = 'numpy-1.26.4-cp312-cp312-linux_x86_64.whl' + result = RemotePEP503Backend._parse_project_page( + html, "https://cache.test/simple/numpy/" + ) + assert len(result) == 1 + assert ( + result[0].url_or_path + == "https://other.test/numpy-1.26.4-cp312-cp312-linux_x86_64.whl" + ) + + def test_lookup_short_circuits_unknown_package(self) -> None: + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + # Simulate scan() having populated the available packages set + backend._available_packages = {canonicalize_name("requests")} + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + assert backend.lookup(key) is None + + def test_scan_populates_available_packages( + self, requests_mock: requests_mock.Mocker + ) -> None: + """scan() fetches the index page and populates available_packages.""" + index_html = """ + + numpy + torch + + """ + requests_mock.get("https://cache.test/simple/", 
text=index_html) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + result = backend.scan() + + assert result == {} + assert backend._available_packages is not None + assert canonicalize_name("numpy") in backend._available_packages + assert canonicalize_name("torch") in backend._available_packages + assert len(backend._available_packages) == 2 + + def test_scan_handles_network_error( + self, requests_mock: requests_mock.Mocker + ) -> None: + """scan() gracefully handles a network error.""" + import requests + + requests_mock.get( + "https://cache.test/simple/", exc=requests.ConnectionError("timeout") + ) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + result = backend.scan() + + assert result == {} + assert backend._available_packages == set() + + def test_lookup_fetches_project_page_lazily( + self, requests_mock: requests_mock.Mocker + ) -> None: + """lookup() fetches the project page on first access and caches it.""" + project_html = """ + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + backend._available_packages = {canonicalize_name("numpy")} + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + info = backend.lookup(key) + + assert info is not None + assert info.filename == "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + assert info.sha256 == "abc" + # Project page is now cached + assert canonicalize_name("numpy") in backend._project_cache + + def test_lookup_returns_none_for_unmatched_version( + self, requests_mock: requests_mock.Mocker + ) -> None: + """lookup() returns None when version doesn't match any wheel.""" + project_html = """ + 
numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + backend._available_packages = {canonicalize_name("numpy")} + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("2.0.0"), + build_tag=(1, ""), + ) + assert backend.lookup(key) is None + + def test_lookup_caches_project_page( + self, requests_mock: requests_mock.Mocker + ) -> None: + """Second lookup() for same package does not fetch again.""" + project_html = """ + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=pathlib.Path("/tmp"), + ) + backend._available_packages = {canonicalize_name("numpy")} + + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + backend.lookup(key) + backend.lookup(key) + + # Only one request made to the project page + assert requests_mock.call_count == 1 + + def test_fetch_downloads_wheel( + self, tmp_path: pathlib.Path, requests_mock: requests_mock.Mocker + ) -> None: + """fetch() downloads wheel content to the destination directory.""" + wheel_content = b"PK\x03\x04fake wheel archive content" + requests_mock.get( + "https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + content=wheel_content, + ) + + from fromager.cache import ArtifactInfo + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + info = ArtifactInfo( + filename="numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + 
url_or_path="https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + ) + + dest = tmp_path / "dest" + result_path = backend.fetch(key, info, dest) + + assert result_path.exists() + assert result_path.name == "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + assert result_path.read_bytes() == wheel_content + + def test_fetch_skips_existing_file(self, tmp_path: pathlib.Path) -> None: + """fetch() returns existing path without downloading if file exists.""" + from fromager.cache import ArtifactInfo + + dest = tmp_path / "dest" + dest.mkdir() + existing = dest / "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + existing.write_bytes(b"existing content") + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + info = ArtifactInfo( + filename="numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + url_or_path="https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + ) + + result_path = backend.fetch(key, info, dest) + + assert result_path == existing + assert result_path.read_bytes() == b"existing content" + + def test_full_scan_lookup_fetch_flow( + self, + tmp_path: pathlib.Path, + requests_mock: requests_mock.Mocker, + ) -> None: + """End-to-end: scan -> lookup -> fetch for a remote backend.""" + wheel_content = b"wheel bytes" + wheel_sha256 = ( + "67c0d8f7de19e30c2d5891030a0b37cbfcdd240852b53055c0b28290ad52290b" + ) + index_html = 'numpy' + project_html = f""" + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + + requests_mock.get("https://cache.test/simple/", text=index_html) + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + requests_mock.get( + "https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + content=wheel_content, + ) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + 
download_dir=tmp_path / "downloads", + ) + + # Step 1: scan + backend.scan() + assert backend._available_packages is not None + assert canonicalize_name("numpy") in backend._available_packages + + # Step 2: lookup + key = WheelCacheKey( + package=canonicalize_name("numpy"), + version=Version("1.26.4"), + build_tag=(2, ""), + ) + info = backend.lookup(key) + assert info is not None + assert info.sha256 == wheel_sha256 + + # Step 3: fetch + dest = tmp_path / "local-cache" + result_path = backend.fetch(key, info, dest) + assert result_path.exists() + assert result_path.read_bytes() == wheel_content + + def test_fetch_rejects_sha256_mismatch( + self, + tmp_path: pathlib.Path, + requests_mock: requests_mock.Mocker, + ) -> None: + """Fetch raises ValueError and removes file on sha256 mismatch.""" + project_html = """ + bad-1.0-1-py3-none-any.whl + """ + requests_mock.get("https://cache.test/simple/", text='bad') + requests_mock.get("https://cache.test/simple/bad/", text=project_html) + requests_mock.get( + "https://cache.test/files/bad-1.0-1-py3-none-any.whl", + content=b"tampered content", + ) + + backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + backend.scan() + + key = WheelCacheKey( + package=canonicalize_name("bad"), + version=Version("1.0"), + build_tag=(1, ""), + ) + info = backend.lookup(key) + assert info is not None + + import pytest + + dest = tmp_path / "local-cache" + with pytest.raises(ValueError, match="sha256 mismatch"): + backend.fetch(key, info, dest) + assert not (dest / "bad-1.0-1-py3-none-any.whl").exists() + + def test_parse_project_page_rejects_path_traversal(self) -> None: + """Filenames with path components are rejected.""" + html = """ + ../../etc/evil.whl + good-1.0-1-py3-none-any.whl + """ + artifacts = RemotePEP503Backend._parse_project_page( + html, "https://cache.test/simple/pkg/" + ) + assert len(artifacts) == 1 + assert artifacts[0].filename == 
"good-1.0-1-py3-none-any.whl" + + +# --------------------------------------------------------------------------- +# StoreRouter tests +# --------------------------------------------------------------------------- + + +class TestStoreRouter: + def test_override_wins(self) -> None: + router = StoreRouter( + overrides={canonicalize_name("torch"): "cuda"}, + accelerated_packages=set(), + active_variant="cuda", + ) + req = Requirement("torch>=2.0") + assert router.route(req) == "cuda" + + def test_accelerated_package(self) -> None: + router = StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("flash-attn")}, + active_variant="cuda", + ) + req = Requirement("flash-attn>=2.0") + assert router.route(req) == "cuda" + + def test_default_fallback(self) -> None: + router = StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("torch")}, + active_variant="cuda", + ) + req = Requirement("requests>=2.0") + assert router.route(req) == "default" + + def test_override_takes_priority_over_accelerated(self) -> None: + router = StoreRouter( + overrides={canonicalize_name("numpy"): "default"}, + accelerated_packages={canonicalize_name("numpy")}, + active_variant="cuda", + ) + req = Requirement("numpy>=1.0") + assert router.route(req) == "default" + + def test_custom_default_collection(self) -> None: + router = StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="rocm", + default_collection="base", + ) + req = Requirement("six") + assert router.route(req) == "base" + + +# --------------------------------------------------------------------------- +# CacheManager tests +# --------------------------------------------------------------------------- + + +def _make_collection(tmp_path: pathlib.Path, name: str) -> CacheCollection: + """Create a CacheCollection with a single local backend.""" + wheels_dir = tmp_path / f"wheels-{name}" + wheels_dir.mkdir(parents=True, exist_ok=True) + backend = LocalDirectoryBackend(wheels_dir, 
backend_name=f"local:{name}") + return CacheCollection( + name=name, + backends=[backend], + store_backend=backend, + ) + + +class TestCacheManager: + def test_lookup_miss_empty_cache(self, tmp_path: pathlib.Path) -> None: + default = _make_collection(tmp_path, "default") + manager = CacheManager( + collections={"default": default}, + search_order=["default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="cpu", + ), + ) + manager.initialize() + + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + assert result.hit is False + assert result.miss is True + + def test_lookup_hit_local(self, tmp_path: pathlib.Path) -> None: + default = _make_collection(tmp_path, "default") + _create_wheel_file( + tmp_path / "wheels-default", + "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + ) + manager = CacheManager( + collections={"default": default}, + search_order=["default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="cpu", + ), + ) + manager.initialize() + + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + assert result.hit is True + assert result.collection == "default" + assert result.backend_name == "local:default" + assert result.path is not None + + def test_lookup_respects_search_order(self, tmp_path: pathlib.Path) -> None: + """CUDA collection is searched first, default second.""" + cuda = _make_collection(tmp_path, "cuda") + default = _make_collection(tmp_path, "default") + + # Put the wheel only in default + _create_wheel_file( + tmp_path / "wheels-default", + "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + ) + + manager = CacheManager( + collections={"cuda": cuda, "default": default}, + search_order=["cuda", "default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="cuda", + ), + ) + manager.initialize() + + result = 
manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + assert result.hit is True + assert result.collection == "default" + + def test_lookup_variant_collection_takes_priority( + self, tmp_path: pathlib.Path + ) -> None: + """When wheel exists in both, variant collection wins.""" + cuda = _make_collection(tmp_path, "cuda") + default = _make_collection(tmp_path, "default") + + _create_wheel_file( + tmp_path / "wheels-cuda", + "torch-2.10.0-7-cp312-cp312-linux_x86_64.whl", + ) + _create_wheel_file( + tmp_path / "wheels-default", + "torch-2.10.0-7-cp312-cp312-linux_x86_64.whl", + ) + + manager = CacheManager( + collections={"cuda": cuda, "default": default}, + search_order=["cuda", "default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="cuda", + ), + ) + manager.initialize() + + result = manager.lookup_wheel( + Requirement("torch"), + Version("2.10.0"), + build_tag=(7, ""), + ) + assert result.hit is True + assert result.collection == "cuda" + + def test_force_skips_lookup(self, tmp_path: pathlib.Path) -> None: + default = _make_collection(tmp_path, "default") + _create_wheel_file( + tmp_path / "wheels-default", + "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + ) + manager = CacheManager( + collections={"default": default}, + search_order=["default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="cpu", + ), + force=True, + ) + manager.initialize() + + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + assert result.hit is False + + def test_store_routes_to_correct_collection(self, tmp_path: pathlib.Path) -> None: + cuda = _make_collection(tmp_path, "cuda") + default = _make_collection(tmp_path, "default") + + build_dir = tmp_path / "build" + whl = _create_wheel_file( + build_dir, "torch-2.10.0-7-cp312-cp312-linux_x86_64.whl" + ) + + manager = CacheManager( + collections={"cuda": cuda, 
"default": default}, + search_order=["cuda", "default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("torch")}, + active_variant="cuda", + ), + ) + manager.initialize() + + result_path = manager.store_wheel( + Requirement("torch"), + Version("2.10.0"), + build_tag=(7, ""), + wheel_path=whl, + ) + + # Should be stored in cuda collection + assert "wheels-cuda" in str(result_path) + assert ( + tmp_path / "wheels-cuda" / "torch-2.10.0-7-cp312-cp312-linux_x86_64.whl" + ).exists() + + def test_store_default_collection(self, tmp_path: pathlib.Path) -> None: + cuda = _make_collection(tmp_path, "cuda") + default = _make_collection(tmp_path, "default") + + build_dir = tmp_path / "build" + whl = _create_wheel_file(build_dir, "requests-2.31.0-1-py3-none-any.whl") + + manager = CacheManager( + collections={"cuda": cuda, "default": default}, + search_order=["cuda", "default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("torch")}, + active_variant="cuda", + ), + ) + manager.initialize() + + result_path = manager.store_wheel( + Requirement("requests"), + Version("2.31.0"), + build_tag=(1, ""), + wheel_path=whl, + ) + + assert "wheels-default" in str(result_path) + + def test_store_then_lookup(self, tmp_path: pathlib.Path) -> None: + """A stored wheel is immediately findable via lookup.""" + default = _make_collection(tmp_path, "default") + + build_dir = tmp_path / "build" + whl = _create_wheel_file(build_dir, "requests-2.31.0-1-py3-none-any.whl") + + manager = CacheManager( + collections={"default": default}, + search_order=["default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages=set(), + active_variant="cpu", + ), + ) + manager.initialize() + + manager.store_wheel( + Requirement("requests"), + Version("2.31.0"), + build_tag=(1, ""), + wheel_path=whl, + ) + + result = manager.lookup_wheel( + Requirement("requests"), + Version("2.31.0"), + build_tag=(1, ""), + ) + assert 
result.hit is True + + def test_store_unknown_collection_raises(self, tmp_path: pathlib.Path) -> None: + default = _make_collection(tmp_path, "default") + build_dir = tmp_path / "build" + whl = _create_wheel_file( + build_dir, "torch-2.10.0-7-cp312-cp312-linux_x86_64.whl" + ) + + manager = CacheManager( + collections={"default": default}, + search_order=["default"], + store_routing=StoreRouter( + overrides={}, + accelerated_packages={canonicalize_name("torch")}, + active_variant="cuda", # No "cuda" collection configured + ), + ) + manager.initialize() + + with pytest.raises(ValueError, match="unknown collection"): + manager.store_wheel( + Requirement("torch"), + Version("2.10.0"), + build_tag=(7, ""), + wheel_path=whl, + ) + + +# --------------------------------------------------------------------------- +# CacheStats tests +# --------------------------------------------------------------------------- + + +class TestCacheStats: + def test_empty_stats(self) -> None: + stats = CacheStats() + assert stats.hits == 0 + assert stats.misses == 0 + assert stats.stores == 0 + assert stats.hit_rate == 0.0 + + def test_record_hit(self) -> None: + stats = CacheStats() + stats.record_hit(Requirement("numpy"), Version("1.26.4"), "default", "local") + assert stats.hits == 1 + assert stats.misses == 0 + + def test_record_miss(self) -> None: + stats = CacheStats() + stats.record_miss(Requirement("numpy"), Version("1.26.4"), "not_found") + assert stats.misses == 1 + assert stats.hits == 0 + + def test_hit_rate(self) -> None: + stats = CacheStats() + stats.record_hit(Requirement("numpy"), Version("1.26.4"), "default", "local") + stats.record_hit(Requirement("requests"), Version("2.31.0"), "default", "local") + stats.record_miss(Requirement("torch"), Version("2.10.0"), "not_found") + assert stats.hit_rate == pytest.approx(2 / 3) + + def test_summary(self) -> None: + stats = CacheStats() + stats.record_hit( + Requirement("numpy"), Version("1.26.4"), "default", "local:default" + ) + 
stats.record_miss(Requirement("torch"), Version("2.10.0"), "not_found") + stats.record_store(Requirement("torch"), Version("2.10.0"), "cuda") + + summary = stats.summary() + assert summary["hits"]["total"] == 1 + assert summary["hits"]["by_collection"]["default"] == 1 + assert summary["misses"] == 1 + assert summary["stores"] == 1 + + +# --------------------------------------------------------------------------- +# CacheResult tests +# --------------------------------------------------------------------------- + + +class TestCacheResult: + def test_miss_property(self) -> None: + result = CacheResult(hit=False) + assert result.miss is True + assert result.hit is False + + def test_hit_result(self) -> None: + result = CacheResult( + hit=True, + path=pathlib.Path("/some/wheel.whl"), + collection="default", + backend_name="local:default", + ) + assert result.miss is False + assert result.path == pathlib.Path("/some/wheel.whl") + + +# --------------------------------------------------------------------------- +# CacheManager + RemotePEP503Backend integration tests +# --------------------------------------------------------------------------- + + +class TestCacheManagerRemoteIntegration: + """Integration tests verifying CacheManager with remote backends.""" + + def test_remote_hit_downloads_to_local_store( + self, + tmp_path: pathlib.Path, + requests_mock: requests_mock.Mocker, + ) -> None: + """CacheManager lookup via remote backend downloads wheel locally.""" + + # Set up remote backend with a wheel available + wheel_content = b"remote wheel content" + wheel_sha = "afb823df34d54af96bcc9a759d34c85fc14f30840bf45377ef911e68be9569df" + project_html = f""" + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + requests_mock.get( + "https://cache.test/simple/", text='numpy' + ) + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + requests_mock.get( + "https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + content=wheel_content, + ) + 
+ local_dir = tmp_path / "local-wheels" + local_dir.mkdir() + local_backend = LocalDirectoryBackend(local_dir, backend_name="local:default") + + remote_backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + + collection = CacheCollection( + name="default", + backends=[local_backend, remote_backend], + store_backend=local_backend, + ) + router = StoreRouter( + overrides={}, accelerated_packages=set(), active_variant="cpu" + ) + manager = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + manager.initialize() + + # Lookup should miss local, hit remote, and download + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + + assert result.hit is True + assert result.path is not None + assert result.path.exists() + assert result.path.read_bytes() == wheel_content + assert result.backend_name == "remote:https://cache.test/simple" + assert result.was_downloaded is True + + def test_local_hit_takes_priority_over_remote( + self, + tmp_path: pathlib.Path, + requests_mock: requests_mock.Mocker, + ) -> None: + """Local backend hit means remote is never consulted.""" + local_dir = tmp_path / "local-wheels" + local_dir.mkdir() + # Put a wheel in the local cache + wheel_file = local_dir / "numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl" + wheel_file.write_bytes(b"local wheel") + + local_backend = LocalDirectoryBackend(local_dir, backend_name="local:default") + remote_backend = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + + # Don't register any requests_mock responses — remote should not be hit + requests_mock.get( + "https://cache.test/simple/", text='numpy' + ) + + collection = CacheCollection( + name="default", + backends=[local_backend, remote_backend], + store_backend=local_backend, + ) + router = StoreRouter( + overrides={}, 
accelerated_packages=set(), active_variant="cpu" + ) + manager = CacheManager( + collections={"default": collection}, + search_order=["default"], + store_routing=router, + ) + manager.initialize() + + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + + assert result.hit is True + assert result.path == wheel_file + assert result.backend_name == "local:default" + assert result.was_downloaded is False + # Remote project page never fetched + assert not any("simple/numpy/" in h.url for h in requests_mock.request_history) + + def test_hierarchical_search_across_collections_with_remote( + self, + tmp_path: pathlib.Path, + requests_mock: requests_mock.Mocker, + ) -> None: + """CacheManager searches variant collection first, then falls through to default.""" + # CUDA collection: empty (no local, remote has nothing matching) + cuda_local = tmp_path / "cuda-wheels" + cuda_local.mkdir() + cuda_backend = LocalDirectoryBackend(cuda_local, backend_name="local:cuda") + + # Default collection: has numpy via remote + default_local = tmp_path / "default-wheels" + default_local.mkdir() + default_backend = LocalDirectoryBackend( + default_local, backend_name="local:default" + ) + + project_html = """ + numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl + """ + requests_mock.get( + "https://cache.test/simple/", text='numpy' + ) + requests_mock.get("https://cache.test/simple/numpy/", text=project_html) + requests_mock.get( + "https://cache.test/files/numpy-1.26.4-2-cp312-cp312-linux_x86_64.whl", + content=b"remote numpy", + ) + + remote_default = RemotePEP503Backend( + server_url="https://cache.test/simple", + download_dir=tmp_path / "downloads", + ) + + cuda_collection = CacheCollection( + name="cuda", + backends=[cuda_backend], + store_backend=cuda_backend, + ) + default_collection = CacheCollection( + name="default", + backends=[default_backend, remote_default], + store_backend=default_backend, + ) + + router = StoreRouter( + overrides={}, 
+ accelerated_packages={canonicalize_name("torch")}, + active_variant="cuda", + ) + manager = CacheManager( + collections={"cuda": cuda_collection, "default": default_collection}, + search_order=["cuda", "default"], + store_routing=router, + ) + manager.initialize() + + # numpy is not in cuda, should be found in default via remote + result = manager.lookup_wheel( + Requirement("numpy"), + Version("1.26.4"), + build_tag=(2, ""), + ) + + assert result.hit is True + assert result.collection == "default" + assert result.path is not None + assert result.path.read_bytes() == b"remote numpy"