Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/fromager/bootstrap_requirement_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ def __init__(
self.prev_graph = prev_graph
# Session-level resolution cache to avoid re-resolving same requirements
# Key: (requirement_string, pre_built) to distinguish source vs prebuilt
# Value: list of (url, version) tuples sorted by version (highest first)
# Value: tuple of (url, version) tuples sorted by version (highest first)
# Values are stored as immutable tuples to prevent accidental corruption
# when callers modify the returned reference.
self._resolved_requirements: dict[
tuple[str, bool], list[tuple[str, Version]]
tuple[str, bool], tuple[tuple[str, Version], ...]
] = {}

def resolve(
Expand Down Expand Up @@ -106,7 +108,7 @@ def resolve(
cached_result = self.get_cached_resolution(req, pre_built)
if cached_result is not None:
logger.debug(f"resolved {req} from cache")
return cached_result if return_all_versions else [cached_result[0]]
return list(cached_result) if return_all_versions else [cached_result[0]]

# Resolve using strategies
results = self._resolve(req, req_type, parent_req, pre_built)
Expand Down Expand Up @@ -182,15 +184,17 @@ def get_cached_resolution(
self,
req: Requirement,
pre_built: bool,
) -> list[tuple[str, Version]] | None:
) -> tuple[tuple[str, Version], ...] | None:
"""Get a cached resolution result if it exists.

Returns an immutable tuple to prevent accidental cache corruption.

Args:
req: Package requirement to look up in cache
pre_built: Whether looking for prebuilt or source resolution

Returns:
List of (url, version) tuples if cached, None otherwise
Tuple of (url, version) tuples if cached, None otherwise
"""
cache_key = (str(req), pre_built)
return self._resolved_requirements.get(cache_key)
Expand All @@ -203,6 +207,9 @@ def cache_resolution(
) -> None:
"""Cache a resolution result.

The result is stored as an immutable tuple to prevent accidental
corruption when callers modify the original list.

Used by Bootstrapper to cache git URL resolutions that are
handled externally (outside this resolver).

Expand All @@ -212,7 +219,7 @@ def cache_resolution(
result: List of (url, version) tuples
"""
cache_key = (str(req), pre_built)
self._resolved_requirements[cache_key] = result
self._resolved_requirements[cache_key] = tuple(result)

def _resolve_from_graph(
self,
Expand Down
4 changes: 3 additions & 1 deletion src/fromager/bootstrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ def resolve_versions(
cached_result = self._resolver.get_cached_resolution(req, pre_built=False)
if cached_result is not None:
logger.debug(f"resolved {req} from cache")
return cached_result if return_all_versions else [cached_result[0]]
return (
list(cached_result) if return_all_versions else [cached_result[0]]
)

logger.info("resolving source via URL, ignoring any plugins")
source_url, resolved_version = self._resolve_version_from_git_url(req=req)
Expand Down
80 changes: 80 additions & 0 deletions tests/test_bootstrap_requirement_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,86 @@ def test_resolve_auto_routes_to_source(
assert version == Version("2.0")


def test_cache_resolution_stores_immutable_tuple(tmp_context: WorkContext) -> None:
"""cache_resolution() stores an immutable tuple, not the original list."""
resolver = BootstrapRequirementResolver(tmp_context)
req = Requirement("mypkg>=1.0")
original = [("https://example.com/mypkg-1.0.tar.gz", Version("1.0"))]

resolver.cache_resolution(req, pre_built=False, result=original)
cached = resolver.get_cached_resolution(req, pre_built=False)

# Cached value should be a tuple
assert isinstance(cached, tuple)

# Mutating the original list must not affect the cache
original.append(("https://example.com/mypkg-2.0.tar.gz", Version("2.0")))
cached_after = resolver.get_cached_resolution(req, pre_built=False)
assert cached_after is not None
assert len(cached_after) == 1


def test_get_cached_resolution_returns_immutable(tmp_context: WorkContext) -> None:
"""get_cached_resolution() returns a tuple that cannot be mutated."""
resolver = BootstrapRequirementResolver(tmp_context)
req = Requirement("mypkg>=1.0")

resolver.cache_resolution(
req,
pre_built=False,
result=[("https://example.com/mypkg-1.0.tar.gz", Version("1.0"))],
)
cached = resolver.get_cached_resolution(req, pre_built=False)
assert cached is not None

with pytest.raises(AttributeError):
cached.append(("https://example.com/bad.tar.gz", Version("2.0"))) # type: ignore[attr-defined, union-attr]

with pytest.raises(TypeError):
cached[0] = ("https://example.com/bad.tar.gz", Version("2.0")) # type: ignore[index]


@patch("fromager.resolver.find_all_matching_from_provider")
def test_resolve_cache_returns_independent_lists(
mock_resolve: MagicMock,
tmp_context: WorkContext,
) -> None:
"""resolve() returns independent list copies from the cache, not shared references."""
req = Requirement("mypkg>=1.0")
mock_resolve.return_value = [
("https://example.com/mypkg-2.0.tar.gz", Version("2.0")),
("https://example.com/mypkg-1.5.tar.gz", Version("1.5")),
]

resolver = BootstrapRequirementResolver(tmp_context)

# First call populates cache
results1 = resolver.resolve(
req=req,
req_type=RequirementType.INSTALL,
parent_req=None,
pre_built=False,
return_all_versions=True,
)

# Mutate the returned list
results1.append(("https://example.com/injected.tar.gz", Version("9.9")))

# Second call should return clean cached data, unaffected by mutation
results2 = resolver.resolve(
req=req,
req_type=RequirementType.INSTALL,
parent_req=None,
pre_built=False,
return_all_versions=True,
)

assert len(results2) == 2
assert results1 is not results2
# Only called once — second call used cache
mock_resolve.assert_called_once()


@patch("fromager.resolver.find_all_matching_from_provider")
def test_resolve_prebuilt_after_source_uses_separate_cache(
mock_resolve: MagicMock,
Expand Down
Loading