Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
373 changes: 372 additions & 1 deletion src/fromager/commands/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
import itertools
import json
import logging
import math
import pathlib
import sys
import typing

import click
import rich
import rich.box
from packaging.requirements import Requirement
from packaging.utils import canonicalize_name
from packaging.utils import NormalizedName, canonicalize_name
from packaging.version import Version
from rich.table import Table

from fromager import clickext, context
from fromager.commands import bootstrap
Expand Down Expand Up @@ -784,3 +788,370 @@ def n2s(nodes: typing.Iterable[DependencyNode]) -> str:
topo.done(*nodes_to_build)

print(f"\nBuilding {len(graph)} packages in {rounds} rounds.")


def _get_collection_name(graph_path: str) -> str:
"""Derive collection name from file path stem."""
return pathlib.Path(graph_path).stem


def _get_collection_packages(graph_path: str) -> set[NormalizedName]:
    """Load the dependency graph at *graph_path* and return every canonical
    package name it contains, excluding the synthetic ROOT node."""
    graph = DependencyGraph.from_file(graph_path)
    names: set[NormalizedName] = set()
    for node in graph.get_all_nodes():
        if node.canonicalized_name != ROOT:
            names.add(node.canonicalized_name)
    return names


def _find_shared_packages(
collections: dict[str, set[NormalizedName]],
min_collections: int,
display_names: dict[str, str] | None = None,
) -> list[dict[str, typing.Any]]:
"""Find packages in >= min_collections collections, sorted by count desc then name asc."""
all_packages: set[NormalizedName] = set().union(*collections.values())
results: list[dict[str, typing.Any]] = []
for pkg in all_packages:
containing = [
display_names.get(key, key) if display_names else key
for key, pkgs in collections.items()
if pkg in pkgs
]
if len(containing) >= min_collections:
results.append(
{
"package": pkg,
"collections": sorted(containing),
"count": len(containing),
}
)
results.sort(key=lambda x: (-x["count"], x["package"]))
return results


def _compute_collection_impact(
collections: dict[str, set[NormalizedName]],
base_package_names: set[NormalizedName],
display_names: dict[str, str] | None = None,
) -> list[dict[str, typing.Any]]:
"""For each collection, compute how many packages remain after removing base packages.

Each entry includes per-remaining-package cross-collection counts.
Sorted by remaining package count descending, then collection name ascending.
"""
all_packages: set[NormalizedName] = set().union(*collections.values())
pkg_counts: dict[NormalizedName, int] = {
pkg: sum(1 for pkgs in collections.values() if pkg in pkgs)
for pkg in all_packages
}

result = []
for key, pkgs in collections.items():
coll_name = display_names.get(key, key) if display_names else key
base_pkgs = pkgs & base_package_names
remaining_pkgs = pkgs - base_package_names
remaining_detail = sorted(
[
{"package": pkg, "collection_count": pkg_counts[pkg]}
for pkg in remaining_pkgs
],
key=lambda x: (
-typing.cast(int, x["collection_count"]),
typing.cast(str, x["package"]),
),
)
result.append(
{
"collection": coll_name,
"total_packages": len(pkgs),
"base_packages": len(base_pkgs),
"remaining_packages": len(remaining_pkgs),
"reduction_percentage": (
round(len(base_pkgs) / len(pkgs) * 100, 1) if pkgs else 0.0
),
"remaining": remaining_detail,
}
)
result.sort(
key=lambda x: (
-typing.cast(int, x["remaining_packages"]),
typing.cast(str, x["collection"]),
)
)
return result


def _suggest_base_table(
    candidates: list[dict[str, typing.Any]],
    total_collections: int,
    collection_names: list[str],
    min_collections: int,
    base_packages: set[NormalizedName] | None,
    total_unique_packages: int,
    impact: list[dict[str, typing.Any]],
    base_only_packages: set[NormalizedName],
) -> None:
    """Display suggest-base results as a rich table.

    Renders up to four Markdown-boxed tables on the console:

    * the candidate packages meeting the ``min_collections`` threshold,
      with an extra "In Base" column when an existing base was supplied;
    * per-collection impact figures (totals, in-base, remaining, % saved);
    * the deduplicated set of packages not covered by the proposed base;
    * packages carried forward from an existing base that are not new
      candidates (only shown when ``base_only_packages`` is non-empty).
    """
    title = (
        f"Base collection candidates "
        f"(threshold: {min_collections}/{total_collections} collections)\n"
        f"Collections: {', '.join(sorted(collection_names))}"
    )
    table = Table(title=title, box=rich.box.MARKDOWN, title_justify="left")
    table.add_column("Package", justify="left", no_wrap=True)
    table.add_column("Collections", justify="right", no_wrap=True)
    table.add_column("Coverage", justify="right", no_wrap=True)
    table.add_column("Appears In", justify="left")
    # The "In Base" column only makes sense when an existing base graph
    # was provided on the command line.
    if base_packages is not None:
        table.add_column("In Base", justify="center", no_wrap=True)

    already_in_base = 0
    new_candidates = 0
    for entry in candidates:
        pkg = entry["package"]
        count = entry["count"]
        cols = entry["collections"]
        coverage = f"{(count / total_collections) * 100:.1f}%"
        count_str = f"{count}/{total_collections}"
        appears_in = ", ".join(cols)
        if base_packages is not None:
            in_base = pkg in base_packages
            if in_base:
                already_in_base += 1
            else:
                new_candidates += 1
            table.add_row(
                pkg, count_str, coverage, appears_in, "yes" if in_base else "no"
            )
        else:
            # Without a base graph every candidate counts as new.
            new_candidates += 1
            table.add_row(pkg, count_str, coverage, appears_in)

    console = rich.get_console()
    console.print(table)
    console.print(f"\nTotal unique packages: {total_unique_packages}")
    console.print(f"Packages in >= {min_collections} collections: {len(candidates)}")
    if base_packages is not None:
        console.print(f"Already in base: {already_in_base}")
        console.print(f"New candidates: {new_candidates}")

    # Collection Impact table
    impact_table = Table(
        title="Collection Impact", box=rich.box.MARKDOWN, title_justify="left"
    )
    impact_table.add_column("Collection", justify="left", no_wrap=True)
    impact_table.add_column("Total Pkgs", justify="right", no_wrap=True)
    impact_table.add_column("In Base", justify="right", no_wrap=True)
    impact_table.add_column("Remaining", justify="right", no_wrap=True)
    impact_table.add_column("% Saved", justify="right", no_wrap=True)
    # `impact` comes pre-sorted by _compute_collection_impact
    # (descending remaining count, then collection name).
    for entry in impact:
        impact_table.add_row(
            entry["collection"],
            str(entry["total_packages"]),
            str(entry["base_packages"]),
            str(entry["remaining_packages"]),
            f"{entry['reduction_percentage']:.1f}%",
        )
    console.print(impact_table)

    # Remaining Packages table — deduplicated across all collections
    seen: set[NormalizedName] = set()
    remaining_rows: list[dict[str, typing.Any]] = []
    for entry in impact:
        for pkg_entry in entry["remaining"]:
            pkg = pkg_entry["package"]
            if pkg not in seen:
                seen.add(pkg)
                remaining_rows.append(pkg_entry)
    remaining_rows.sort(key=lambda x: (-x["collection_count"], x["package"]))

    remaining_table = Table(
        title="Remaining Packages (not in proposed base)",
        box=rich.box.MARKDOWN,
        title_justify="left",
    )
    remaining_table.add_column("Package", justify="left", no_wrap=True)
    remaining_table.add_column("Collections", justify="right", no_wrap=True)
    remaining_table.add_column("Coverage", justify="right", no_wrap=True)
    for pkg_entry in remaining_rows:
        count = pkg_entry["collection_count"]
        remaining_table.add_row(
            pkg_entry["package"],
            f"{count}/{total_collections}",
            f"{(count / total_collections) * 100:.1f}%",
        )
    console.print(remaining_table)

    if base_only_packages:
        base_only_table = Table(
            title="Existing Base Packages (carried forward, not new candidates)",
            box=rich.box.MARKDOWN,
            title_justify="left",
        )
        base_only_table.add_column("Package", justify="left", no_wrap=True)
        for pkg in sorted(base_only_packages):
            base_only_table.add_row(str(pkg))
        console.print(base_only_table)


def _suggest_base_json(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Would it be useful to group these parameters into a dataclass (ex: SuggestBaseResult )? The signature is bit wide and it might make things easier to extend later and maintain contract between json output and table output. Just asking to to hear your thoughts.

candidates: list[dict[str, typing.Any]],
total_collections: int,
collection_names: list[str],
min_collections: int,
base_packages: set[NormalizedName] | None,
base_graph: str | None,
total_unique_packages: int,
impact: list[dict[str, typing.Any]],
base_only_packages: set[NormalizedName],
) -> None:
"""Display suggest-base results as JSON."""
output: dict[str, typing.Any] = {
"metadata": {
"total_collections": total_collections,
"total_unique_packages": total_unique_packages,
"packages_meeting_threshold": len(candidates),
"collections": sorted(collection_names),
"min_collections": min_collections,
},
"candidates": [],
"collection_impact": impact,
}
if base_graph is not None:
output["metadata"]["base_graph"] = base_graph

for entry in candidates:
pkg = entry["package"]
count = entry["count"]
cols = entry["collections"]
candidate: dict[str, typing.Any] = {
"package": pkg,
"collections": cols,
"collection_count": count,
"coverage_percentage": round((count / total_collections) * 100, 1),
}
if base_packages is not None:
candidate["in_base"] = pkg in base_packages
output["candidates"].append(candidate)

if base_only_packages:
output["base_only_packages"] = sorted(str(p) for p in base_only_packages)

json.dump(output, sys.stdout, indent=2)


def _suggest_base_impl(
    collection_graphs: tuple[str, ...],
    base_graph: str | None,
    min_collections: int | None,
    output_format: str,
) -> None:
    """Core implementation for suggest_base, testable without a click context.

    Loads each graph in *collection_graphs*, finds packages shared by at
    least *min_collections* of them (default: half of the provided graphs,
    rounded up, minimum 2), optionally folds in an existing *base_graph*,
    and renders the results as a table or JSON per *output_format*.

    Raises click.UsageError for invalid argument combinations or when fewer
    than 2 non-empty collections remain after loading.
    """
    if len(collection_graphs) < 2:
        raise click.UsageError("At least 2 collection graphs are required")
    if min_collections is None:
        # Default threshold: present in at least half of the collections,
        # but never fewer than 2.
        min_collections = max(2, math.ceil(len(collection_graphs) / 2))
    elif min_collections < 2:
        raise click.UsageError("--min-collections must be >= 2")
    if min_collections > len(collection_graphs):
        raise click.UsageError(
            f"--min-collections ({min_collections}) cannot exceed number of graphs ({len(collection_graphs)})"
        )

    # Load each collection, keyed by resolved path to avoid stem collisions
    collections: dict[str, set[NormalizedName]] = {}
    display_names: dict[str, str] = {}
    for path in collection_graphs:
        key = str(pathlib.Path(path).resolve())
        name = _get_collection_name(path)
        pkgs = _get_collection_packages(path)
        if not pkgs:
            logger.warning("Collection %s is empty, skipping", name)
            continue
        collections[key] = pkgs
        display_names[key] = name

    # Empty graphs are skipped above, so re-validate the count here: the
    # earlier check only saw the raw argument list, and e.g. 2 graphs with
    # 1 empty would otherwise proceed with a single collection.
    if len(collections) < 2:
        raise click.UsageError(
            "At least 2 non-empty collection graphs are required"
        )

    # Load base graph if provided
    base_packages: set[NormalizedName] | None = None
    if base_graph:
        base_packages = _get_collection_packages(base_graph)

    total_unique_packages = len(set().union(*collections.values()))
    candidates = _find_shared_packages(collections, min_collections, display_names)
    total = len(collections)

    candidate_names: set[NormalizedName] = {entry["package"] for entry in candidates}
    # The full proposed base includes existing base packages (all carried forward)
    proposed_base: set[NormalizedName] = (
        candidate_names | base_packages if base_packages else candidate_names
    )
    # Packages carried from the existing base that are not new candidates
    base_only_packages: set[NormalizedName] = (
        base_packages - candidate_names if base_packages else set()
    )
    impact = _compute_collection_impact(collections, proposed_base, display_names)

    if output_format == "json":
        _suggest_base_json(
            candidates,
            total,
            list(display_names.values()),
            min_collections,
            base_packages,
            base_graph,
            total_unique_packages,
            impact,
            base_only_packages,
        )
    else:
        _suggest_base_table(
            candidates,
            total,
            list(display_names.values()),
            min_collections,
            base_packages,
            total_unique_packages,
            impact,
            base_only_packages,
        )
Comment thread
coderabbitai[bot] marked this conversation as resolved.


@graph.command()
@click.option(
    "--base",
    "base_graph",
    type=str,
    default=None,
    help="Existing base collection graph to enhance",
)
@click.option(
    "--min-collections",
    type=int,
    default=None,
    help="Minimum collections a package must appear in (default: 50% of provided collections)",
)
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["table", "json"]),
    default="table",
    help="Output format (default: table)",
)
@click.argument("collection_graphs", nargs=-1, required=True)
@click.pass_obj
def suggest_base(
    wkctx: context.WorkContext,
    collection_graphs: tuple[str, ...],
    base_graph: str | None,
    min_collections: int | None,
    output_format: str,
) -> None:
    """Suggest packages for a shared base collection.

    Analyzes COLLECTION_GRAPHS (2 or more graph files) to identify packages
    appearing across multiple collections. These are candidates for factoring
    into a base collection built once and reused.
    """
    # NOTE(review): wkctx is injected by @click.pass_obj but unused here;
    # kept for signature consistency with the other graph subcommands.
    # Delegate to the click-independent implementation so the analysis can
    # be unit-tested without a CLI invocation context.
    _suggest_base_impl(collection_graphs, base_graph, min_collections, output_format)
Loading
Loading