diff --git a/src/fromager/packagesettings/_models.py b/src/fromager/packagesettings/_models.py index 4e982e18..2a5f4557 100644 --- a/src/fromager/packagesettings/_models.py +++ b/src/fromager/packagesettings/_models.py @@ -15,7 +15,7 @@ from pydantic import AnyUrl, Field from pydantic_core import core_schema -# from ._resolver import SourceResolver +from ._resolver import SourceResolver from ._typedefs import ( MODEL_CONFIG, BuildDirectory, @@ -324,9 +324,11 @@ class VariantInfo(pydantic.BaseModel): pre_built: bool = False """Use pre-built wheel from index server?""" - # TODO - # source: SourceResolver | None - # """Source resolver and downloader""" + source: SourceResolver | None = None + """Source resolver and downloader + + .. versionadded:: 0.86 + """ class GitOptions(pydantic.BaseModel): @@ -336,6 +338,7 @@ class GitOptions(pydantic.BaseModel): submodules: False submodule_paths: [] + remove_dot_git: False """ model_config = MODEL_CONFIG @@ -358,6 +361,18 @@ class GitOptions(pydantic.BaseModel): - ["vendor/lib1", "vendor/lib2"] """ + remove_dot_git: bool = False + """Remove ``.git`` directory after cloning? + + When True, the ``.git`` directory is removed from the cloned source + tree so it does not end up in the built sdist. Defaults to False + to preserve backward compatibility with existing ``req.url`` git + clones that rely on ``.git`` for version detection (e.g. via + setuptools-scm). + + .. versionadded:: 0.85 + """ + _DictStrAny = dict[str, typing.Any] @@ -452,9 +467,11 @@ class PackageSettings(pydantic.BaseModel): project_override: ProjectOverride = Field(default_factory=ProjectOverride) """Patch project settings""" - # TODO - # source: SourceResolver | None - # """Source resolver and downloader""" + source: SourceResolver | None = None + """Source resolver and downloader + + .. versionadded:: 0.86 + """ variants: Mapping[Variant, VariantInfo] = Field(default_factory=dict) """Variant configuration""" diff --git a/src/fromager/packagesettings/_pbi.py b/src/fromager/packagesettings/_pbi.py index ff31cd0c..ef269bc3 100644 --- a/src/fromager/packagesettings/_pbi.py +++ b/src/fromager/packagesettings/_pbi.py @@ -26,6 +26,7 @@ if typing.TYPE_CHECKING: from .. import build_environment + from ._resolver import SourceResolver from ._settings import Settings logger = logging.getLogger(__name__) @@ -176,6 +177,18 @@ def pre_built(self) -> bool: return vi.pre_built return False + @property + def source_resolver(self) -> SourceResolver | None: + """Effective source resolver for this package and variant. + + Returns the variant-level ``source`` override if set, otherwise + the package-level ``source``, or ``None`` when neither is configured. + """ + vi = self._ps.variants.get(self._variant) + if vi is not None and vi.source is not None: + return vi.source + return self._ps.source + @property def wheel_server_url(self) -> str | None: """Alternative package index for pre-build wheel""" diff --git a/src/fromager/packagesettings/_resolver.py b/src/fromager/packagesettings/_resolver.py index a7ae076c..a9f0a395 100644 --- a/src/fromager/packagesettings/_resolver.py +++ b/src/fromager/packagesettings/_resolver.py @@ -34,7 +34,9 @@ class AbstractResolver(pydantic.BaseModel): provider: str def resolver_provider( - self, ctx: context.WorkContext, req_type: requirements_file.RequirementType + self, + ctx: context.WorkContext, + req_type: requirements_file.RequirementType | None = None, ) -> resolver.BaseProvider: raise NotImplementedError @@ -88,7 +90,9 @@ class PyPISDistResolver(AbstractPyPIResolver): build_sdist: typing.ClassVar[BuildSDist | None] = BuildSDist.tarball def resolver_provider( - self, ctx: context.WorkContext, req_type: requirements_file.RequirementType + self, + ctx: context.WorkContext, + req_type: requirements_file.RequirementType | None = None, ) -> resolver.PyPIProvider: return resolver.PyPIProvider( include_sdists=True, @@ -123,7 +127,9 @@ class PyPIPrebuiltResolver(AbstractPyPIResolver): build_sdist: typing.ClassVar[BuildSDist | None] = None def resolver_provider( - self, ctx: context.WorkContext, req_type: requirements_file.RequirementType + self, + ctx: context.WorkContext, + req_type: requirements_file.RequirementType | None = None, ) -> resolver.PyPIProvider: return resolver.PyPIProvider( include_sdists=False, @@ -177,7 +183,9 @@ def validate_download_url(cls, value: pydantic.HttpUrl) -> pydantic.HttpUrl: return value def resolver_provider( - self, ctx: context.WorkContext, req_type: requirements_file.RequirementType + self, + ctx: context.WorkContext, + req_type: requirements_file.RequirementType | None = None, ) -> resolver.PyPIProvider: return resolver.PyPIProvider( include_sdists=True, @@ -245,7 +253,9 @@ def validate_tag(cls, value: str) -> str: return value def resolver_provider( - self, ctx: context.WorkContext, req_type: requirements_file.RequirementType + self, + ctx: context.WorkContext, + req_type: requirements_file.RequirementType | None = None, ) -> resolver.PyPIProvider: download_url = f"git+{self.clone_url}@refs/tags/{self.tag}" return resolver.PyPIProvider( @@ -342,7 +352,7 @@ def _github_provider( self, *, ctx: context.WorkContext, - req_type: requirements_file.RequirementType, + req_type: requirements_file.RequirementType | None = None, override_download_url: str | None = None, ) -> resolver.GitHubTagProvider: if self.project_url.host != "github.com": @@ -366,7 +376,7 @@ def _gitlab_provider( self, *, ctx: context.WorkContext, - req_type: requirements_file.RequirementType, + req_type: requirements_file.RequirementType | None = None, override_download_url: str | None = None, ) -> resolver.GitLabTagProvider: if not self.project_url.path: @@ -398,7 +408,9 @@ class GitHubTagDownloadResolver(AbstractGitSourceResolver): provider: typing.Literal["github-tag-download"] def resolver_provider( - self, ctx: context.WorkContext, req_type: requirements_file.RequirementType + self, + ctx: context.WorkContext, + req_type: requirements_file.RequirementType | None = None, ) -> resolver.GitHubTagProvider: return self._github_provider( ctx=ctx, @@ -423,7 +435,9 @@ class GitHubTagCloneResolver(AbstractGitSourceResolver): provider: typing.Literal["github-tag-git"] def resolver_provider( - self, ctx: context.WorkContext, req_type: requirements_file.RequirementType + self, + ctx: context.WorkContext, + req_type: requirements_file.RequirementType | None = None, ) -> resolver.GitHubTagProvider: return self._github_provider( ctx=ctx, @@ -448,7 +462,9 @@ class GitLabTagDownloadResolver(AbstractGitSourceResolver): provider: typing.Literal["gitlab-tag-download"] def resolver_provider( - self, ctx: context.WorkContext, req_type: requirements_file.RequirementType + self, + ctx: context.WorkContext, + req_type: requirements_file.RequirementType | None = None, ) -> resolver.GitLabTagProvider: return self._gitlab_provider( ctx=ctx, @@ -473,7 +489,9 @@ class GitLabTagCloneResolver(AbstractGitSourceResolver): provider: typing.Literal["gitlab-tag-git"] def resolver_provider( - self, ctx: context.WorkContext, req_type: requirements_file.RequirementType + self, + ctx: context.WorkContext, + req_type: requirements_file.RequirementType | None = None, ) -> resolver.GitLabTagProvider: return self._gitlab_provider( ctx=ctx, @@ -488,7 +506,9 @@ class NotAvailableResolver(AbstractResolver): provider: typing.Literal["not-available"] def resolver_provider( - self, ctx: context.WorkContext, req_type: requirements_file.RequirementType + self, + ctx: context.WorkContext, + req_type: requirements_file.RequirementType | None = None, ) -> resolver.BaseProvider: raise ValueError("package is not available") @@ -499,7 +519,9 @@ class HookResolver(AbstractResolver): provider: typing.Literal["hook"] def resolver_provider( - self, ctx: context.WorkContext, req_type: requirements_file.RequirementType + self, + ctx: context.WorkContext, + req_type: requirements_file.RequirementType | None = None, ) -> resolver.BaseProvider: # TODO raise NotImplementedError("Hook resolver needs a hook") diff --git a/src/fromager/resolver.py b/src/fromager/resolver.py index d6e113d1..9012c116 100644 --- a/src/fromager/resolver.py +++ b/src/fromager/resolver.py @@ -90,19 +90,28 @@ def resolve( Returns (url, version) tuple for the highest matching version. """ - # Create the (reusable) resolver. - provider = overrides.find_and_invoke( - req.name, - "get_resolver_provider", - default_resolver_provider, - ctx=ctx, - req=req, - include_sdists=include_sdists, - include_wheels=include_wheels, - sdist_server_url=sdist_server_url, - req_type=req_type, - ignore_platform=ignore_platform, - ) + pbi = ctx.package_build_info(req) + source = pbi.source_resolver + if source is not None and source.provider != "hook": + logger.info( + "%s: using source resolver provider %r", + req.name, + source.provider, + ) + provider = source.resolver_provider(ctx, req_type) + else: + provider = overrides.find_and_invoke( + req.name, + "get_resolver_provider", + default_resolver_provider, + ctx=ctx, + req=req, + include_sdists=include_sdists, + include_wheels=include_wheels, + sdist_server_url=sdist_server_url, + req_type=req_type, + ignore_platform=ignore_platform, + ) provider.cooldown = resolve_package_cooldown(ctx, req, req_type=req_type) max_age_cutoff = _compute_max_age_cutoff(ctx) results = find_all_matching_from_provider( @@ -119,14 +128,8 @@ def default_resolver_provider( include_wheels: bool, req_type: RequirementType | None = None, ignore_platform: bool = False, -) -> ( - PyPIProvider - | GenericProvider - | GitHubTagProvider - | GitLabTagProvider - | VersionMapProvider -): - """Lookup resolver provider to resolve package versions""" +) -> BaseProvider: + """Lookup resolver provider to resolve package versions.""" return PyPIProvider( include_sdists=include_sdists, include_wheels=include_wheels, diff --git a/src/fromager/sources.py b/src/fromager/sources.py index bd308adb..6511dd2a 100644 --- a/src/fromager/sources.py +++ b/src/fromager/sources.py @@ -105,16 +105,28 @@ def download_source( ) return download_path - source_path = overrides.find_and_invoke( - req.name, - "download_source", - default_download_source, - ctx=ctx, - req=req, - version=version, - download_url=download_url, - sdists_downloads_dir=ctx.sdists_downloads, - ) + pbi = ctx.package_build_info(req) + source = pbi.source_resolver + + if source is not None and source.provider != "hook": + source_path = default_download_source( + ctx=ctx, + req=req, + version=version, + download_url=download_url, + sdists_downloads_dir=ctx.sdists_downloads, + ) + else: + source_path = overrides.find_and_invoke( + req.name, + "download_source", + default_download_source, + ctx=ctx, + req=req, + version=version, + download_url=download_url, + sdists_downloads_dir=ctx.sdists_downloads, + ) if not isinstance(source_path, pathlib.Path): raise ValueError( @@ -136,23 +148,27 @@ def get_source_provider( (sdist/wheel inclusion, platform matching, server URL override). """ pbi = ctx.package_build_info(req) - override_sdist_server_url = pbi.resolver_sdist_server_url(sdist_server_url) + source = pbi.source_resolver - provider = typing.cast( - resolver.BaseProvider, - overrides.find_and_invoke( - req.name, - "get_resolver_provider", - resolver.default_resolver_provider, - ctx=ctx, - req=req, - include_sdists=pbi.resolver_include_sdists, - include_wheels=pbi.resolver_include_wheels, - sdist_server_url=override_sdist_server_url, - req_type=req_type, - ignore_platform=pbi.resolver_ignore_platform, - ), - ) + if source is not None and source.provider != "hook": + provider = source.resolver_provider(ctx, req_type) + else: + override_sdist_server_url = pbi.resolver_sdist_server_url(sdist_server_url) + provider = typing.cast( + resolver.BaseProvider, + overrides.find_and_invoke( + req.name, + "get_resolver_provider", + resolver.default_resolver_provider, + ctx=ctx, + req=req, + include_sdists=pbi.resolver_include_sdists, + include_wheels=pbi.resolver_include_wheels, + sdist_server_url=override_sdist_server_url, + req_type=req_type, + ignore_platform=pbi.resolver_ignore_platform, + ), + ) provider.cooldown = resolver.resolve_package_cooldown(ctx, req, req_type=req_type) return provider @@ -199,6 +215,30 @@ def resolve_source( raise +def _is_git_url(url: str) -> bool: + """Return True if *url* is a VCS-style ``git+https://`` URL.""" + return url.startswith("git+https://") + + +def _parse_git_url(url: str) -> tuple[str, str | None]: + """Split a VCS URL into clone URL and optional ref. + + ``git+https://host/repo@ref`` -> ``(https://host/repo, ref)`` + ``git+https://host/repo`` -> ``(https://host/repo, None)`` + """ + clone_url = url + if clone_url.startswith("git+"): + clone_url = clone_url[len("git+") :] + + parsed = urlparse(clone_url) + ref: str | None = None + if "@" in parsed.path: + new_path, _, ref = parsed.path.rpartition("@") + clone_url = parsed._replace(path=new_path).geturl() + + return clone_url, ref + + def default_download_source( ctx: context.WorkContext, req: Requirement, @@ -208,10 +248,29 @@ def default_download_source( ) -> pathlib.Path: "Download the requirement and return the name of the output path." pbi = ctx.package_build_info(req) - destination_filename = pbi.download_source_destination_filename(version=version) url = pbi.download_source_url(version=version, default=download_url) if url is None: raise ValueError(f"Could not determine download URL for {req}") + + if _is_git_url(url): + clone_url, ref = _parse_git_url(url) + download_path = ctx.work_dir / f"{req.name}-{version}" / f"{req.name}-{version}" + download_path.mkdir(parents=True, exist_ok=True) + gitutils.git_clone_fast( + output_dir=download_path, + repo_url=clone_url, + ref=ref or "HEAD", + ) + if pbi.git_options.remove_dot_git: + for dot_git in download_path.rglob(".git"): + logger.info("removing %s", dot_git) + if dot_git.is_dir(): + shutil.rmtree(dot_git) + else: + dot_git.unlink() + return download_path + + destination_filename = pbi.download_source_destination_filename(version=version) if destination_filename is None: url_filename = resolver.extract_filename_from_url(url) if url_filename.endswith(".zip"): @@ -239,6 +298,11 @@ def download_git_source( destination_dir: pathlib.Path, ref: str | None = None, ) -> None: + """Clone a git repository into *destination_dir*. + + Applies ``git_options`` from the package settings (submodules, + ``remove_dot_git``). + """ if url_to_clone.startswith("git+"): url_to_clone = url_to_clone[len("git+") :] @@ -265,6 +329,14 @@ def download_git_source( ref=ref, ) + if git_opts.remove_dot_git: + for dot_git in destination_dir.rglob(".git"): + logger.info("removing %s", dot_git) + if dot_git.is_dir(): + shutil.rmtree(dot_git) + else: + dot_git.unlink() + # Helper method to check whether .zip /.tar / .tgz is able to extract and check its content. # It will throw exception if any other file is encountered. Eg: index.html diff --git a/tests/test_packagesettings.py b/tests/test_packagesettings.py index f65c1d67..781fbb5e 100644 --- a/tests/test_packagesettings.py +++ b/tests/test_packagesettings.py @@ -71,6 +71,7 @@ "git_options": { "submodules": False, "submodule_paths": [], + "remove_dot_git": False, }, "name": "test-pkg", "has_config": True, @@ -88,6 +89,7 @@ "use_pypi_org_metadata": True, "min_release_age": None, }, + "source": None, "variants": { "cpu": { "annotations": { @@ -96,6 +98,7 @@ "env": {"EGG": "spam ${EGG}", "EGG_AGAIN": "$EGG"}, "wheel_server_url": "https://wheel.test/simple", "pre_built": False, + "source": None, }, "rocm": { "annotations": { @@ -104,12 +107,14 @@ "env": {"SPAM": ""}, "wheel_server_url": None, "pre_built": True, + "source": None, }, "cuda": { "annotations": None, "env": {}, "wheel_server_url": None, "pre_built": False, + "source": None, }, }, } @@ -134,6 +139,7 @@ "git_options": { "submodules": False, "submodule_paths": [], + "remove_dot_git": False, }, "has_config": True, "purl": None, @@ -150,6 +156,7 @@ "use_pypi_org_metadata": None, "min_release_age": None, }, + "source": None, "variants": {}, } @@ -175,6 +182,7 @@ "git_options": { "submodules": False, "submodule_paths": [], + "remove_dot_git": False, }, "has_config": True, "purl": None, @@ -191,12 +199,14 @@ "use_pypi_org_metadata": None, "min_release_age": None, }, + "source": None, "variants": { "cpu": { "annotations": None, "env": {}, "pre_built": True, "wheel_server_url": None, + "source": None, }, }, } diff --git a/tests/test_source_resolver_wiring.py b/tests/test_source_resolver_wiring.py new file mode 100644 index 00000000..c20e8d38 --- /dev/null +++ b/tests/test_source_resolver_wiring.py @@ -0,0 +1,374 @@ +"""Tests for wiring source resolver configuration into the runtime. + +Covers: +- ``PackageBuildInfo.source_resolver`` property +- ``default_resolver_provider()`` dispatching to source resolver +- ``default_download_source()`` handling ``git+`` URLs +- ``_is_git_url`` and ``_parse_git_url`` helpers +- ``download_git_source()`` honouring ``remove_dot_git`` +""" + +from __future__ import annotations + +import pathlib +from unittest.mock import MagicMock, patch + +from packaging.requirements import Requirement +from packaging.version import Version + +from fromager import context, packagesettings, resolver, sources +from fromager.requirements_file import RequirementType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_context_with_source( + tmp_path: pathlib.Path, + source_yaml: dict, + *, + variant: str = "cpu", + variant_source_yaml: dict | None = None, + git_options: dict | None = None, +) -> context.WorkContext: + """Create a ``WorkContext`` whose test-pkg has a ``source`` field.""" + pkg_config: dict = {"source": source_yaml} + if variant_source_yaml is not None: + pkg_config["variants"] = {variant: {"source": variant_source_yaml}} + if git_options is not None: + pkg_config["git_options"] = git_options + + settings_file = packagesettings.SettingsFile() + ps = packagesettings.PackageSettings.from_mapping( + "test-pkg", + pkg_config, + source="test", + has_config=True, + ) + settings = packagesettings.Settings( + settings=settings_file, + package_settings=[ps], + patches_dir=tmp_path / "patches", + variant=variant, + max_jobs=None, + ) + ctx = context.WorkContext( + active_settings=settings, + patches_dir=tmp_path / "patches", + sdists_repo=tmp_path / "sdists-repo", + wheels_repo=tmp_path / "wheels-repo", + work_dir=tmp_path / "work-dir", + ) + ctx.setup() + return ctx + + +# --------------------------------------------------------------------------- +# _is_git_url / _parse_git_url +# --------------------------------------------------------------------------- + + +class TestGitUrlHelpers: + def test_is_git_url_https(self) -> None: + assert sources._is_git_url("git+https://github.test/org/repo@v1.0") is True + + def test_is_git_url_ssh_not_supported(self) -> None: + assert sources._is_git_url("git+ssh://git@github.test/org/repo@v1.0") is False + + def test_is_git_url_plain_https(self) -> None: + assert sources._is_git_url("https://pypi.test/simple") is False + + def test_is_git_url_tarball(self) -> None: + assert sources._is_git_url("https://pkg.test/pkg-1.0.tar.gz") is False + + def test_parse_git_url_with_ref(self) -> None: + url = "git+https://github.test/org/repo.git@refs/tags/v1.0" + clone_url, ref = sources._parse_git_url(url) + assert clone_url == "https://github.test/org/repo.git" + assert ref == "refs/tags/v1.0" + + def test_parse_git_url_simple_ref(self) -> None: + url = "git+https://github.test/org/repo@v2.0" + clone_url, ref = sources._parse_git_url(url) + assert clone_url == "https://github.test/org/repo" + assert ref == "v2.0" + + def test_parse_git_url_no_ref(self) -> None: + url = "git+https://github.test/org/repo.git" + clone_url, ref = sources._parse_git_url(url) + assert clone_url == "https://github.test/org/repo.git" + assert ref is None + + +# --------------------------------------------------------------------------- +# PackageBuildInfo.source_resolver +# --------------------------------------------------------------------------- + + +class TestSourceResolverProperty: + def test_no_source_returns_none(self, tmp_context: context.WorkContext) -> None: + req = Requirement("some-pkg") + pbi = tmp_context.package_build_info(req) + assert pbi.source_resolver is None + + def test_package_level_source(self, tmp_path: pathlib.Path) -> None: + ctx = _make_context_with_source( + tmp_path, + {"provider": "pypi-sdist", "index_url": "https://pypi.test/simple"}, + ) + req = Requirement("test-pkg") + pbi = ctx.package_build_info(req) + assert pbi.source_resolver is not None + assert pbi.source_resolver.provider == "pypi-sdist" + + def test_variant_overrides_package(self, tmp_path: pathlib.Path) -> None: + ctx = _make_context_with_source( + tmp_path, + {"provider": "pypi-sdist"}, + variant_source_yaml={ + "provider": "pypi-prebuilt", + "index_url": "https://wheels.test/simple", + }, + ) + req = Requirement("test-pkg") + pbi = ctx.package_build_info(req) + sr = pbi.source_resolver + assert sr is not None + assert sr.provider == "pypi-prebuilt" + + def test_variant_without_source_falls_back_to_package( + self, tmp_path: pathlib.Path + ) -> None: + ctx = _make_context_with_source( + tmp_path, + {"provider": "pypi-sdist"}, + ) + req = Requirement("test-pkg") + pbi = ctx.package_build_info(req) + assert pbi.source_resolver is not None + assert pbi.source_resolver.provider == "pypi-sdist" + + +# --------------------------------------------------------------------------- +# Source resolver dispatch (source config takes priority over plugins) +# --------------------------------------------------------------------------- + + +class TestSourceResolverDispatch: + def test_source_config_produces_correct_provider( + self, tmp_path: pathlib.Path + ) -> None: + ctx = _make_context_with_source( + tmp_path, + { + "provider": "pypi-sdist", + "index_url": "https://custom.test/simple", + }, + ) + req = Requirement("test-pkg") + provider = sources.get_source_provider( + ctx=ctx, + req=req, + sdist_server_url="https://pypi.test/simple/", + req_type=RequirementType.INSTALL, + ) + assert isinstance(provider, resolver.PyPIProvider) + assert provider.sdist_server_url == "https://custom.test/simple" + + def test_falls_back_to_pypi_when_no_source( + self, tmp_context: context.WorkContext + ) -> None: + req = Requirement("unknown-pkg") + provider = sources.get_source_provider( + ctx=tmp_context, + req=req, + sdist_server_url="https://pypi.test/simple/", + req_type=RequirementType.INSTALL, + ) + assert isinstance(provider, resolver.PyPIProvider) + assert provider.sdist_server_url == "https://pypi.test/simple/" + + def test_github_tag_resolver_produces_github_provider( + self, tmp_path: pathlib.Path + ) -> None: + ctx = _make_context_with_source( + tmp_path, + { + "provider": "github-tag-download", + "project_url": "https://github.com/python-wheel-build/fromager", + }, + ) + req = Requirement("test-pkg") + provider = sources.get_source_provider( + ctx=ctx, + req=req, + sdist_server_url="https://pypi.test/simple/", + req_type=RequirementType.INSTALL, + ) + assert isinstance(provider, resolver.GitHubTagProvider) + assert provider.organization == "python-wheel-build" + assert provider.repo == "fromager" + + +# --------------------------------------------------------------------------- +# default_download_source with git URLs +# --------------------------------------------------------------------------- + + +class TestDefaultDownloadSourceGitUrl: + @patch("fromager.sources.gitutils.git_clone_fast") + def test_routes_git_url_to_git_clone_fast( + self, + mock_clone_fast: MagicMock, + tmp_context: context.WorkContext, + ) -> None: + req = Requirement("test-pkg==1.0") + version = Version("1.0") + git_url = "git+https://github.test/org/repo.git@refs/tags/v1.0" + + result = sources.default_download_source( + tmp_context, + req, + version, + git_url, + tmp_context.sdists_downloads, + ) + + mock_clone_fast.assert_called_once_with( + output_dir=result, + repo_url="https://github.test/org/repo.git", + ref="refs/tags/v1.0", + ) + assert result.name == "test-pkg-1.0" + + @patch("fromager.sources._download_source_check") + def test_non_git_url_downloads_tarball( + self, + mock_check: MagicMock, + tmp_context: context.WorkContext, + ) -> None: + req = Requirement("test-pkg==1.0") + version = Version("1.0") + tarball_url = "https://packages.test/test-pkg-1.0.tar.gz" + mock_check.return_value = pathlib.Path("test-pkg-1.0.tar.gz") + + sources.default_download_source( + tmp_context, + req, + version, + tarball_url, + tmp_context.sdists_downloads, + ) + + mock_check.assert_called_once() + assert mock_check.call_args[1]["url"] == tarball_url + + +# --------------------------------------------------------------------------- +# download_git_source + remove_dot_git +# --------------------------------------------------------------------------- + + +class TestDownloadGitSourceRemoveDotGit: + @patch("fromager.sources.gitutils.git_clone") + def test_keeps_dot_git_by_default( + self, + mock_git_clone: MagicMock, + tmp_path: pathlib.Path, + tmp_context: context.WorkContext, + ) -> None: + dest = tmp_path / "source" + dest.mkdir() + dot_git = dest / ".git" + dot_git.mkdir() + (dot_git / "HEAD").write_text("ref: refs/heads/main\n") + + req = Requirement("test-pkg") + sources.download_git_source( + ctx=tmp_context, + req=req, + url_to_clone="https://github.test/org/repo.git", + destination_dir=dest, + ref="v1.0", + ) + + mock_git_clone.assert_called_once() + assert dot_git.exists() + + @patch("fromager.sources.gitutils.git_clone") + def test_removes_dot_git_when_enabled( + self, + mock_git_clone: MagicMock, + tmp_path: pathlib.Path, + ) -> None: + ctx = _make_context_with_source( + tmp_path, + {"provider": "pypi-sdist"}, + git_options={"remove_dot_git": True}, + ) + dest = tmp_path / "source" + dest.mkdir() + dot_git = dest / ".git" + dot_git.mkdir() + (dot_git / "HEAD").write_text("ref: refs/heads/main\n") + + req = Requirement("test-pkg") + sources.download_git_source( + ctx=ctx, + req=req, + url_to_clone="https://github.test/org/repo.git", + destination_dir=dest, + ref="v1.0", + ) + + mock_git_clone.assert_called_once() + assert not dot_git.exists() + + @patch("fromager.sources.gitutils.git_clone") + def test_keeps_dot_git_when_disabled( + self, + mock_git_clone: MagicMock, + tmp_path: pathlib.Path, + ) -> None: + ctx = _make_context_with_source( + tmp_path, + {"provider": "pypi-sdist"}, + git_options={"remove_dot_git": False}, + ) + dest = tmp_path / "source" + dest.mkdir() + dot_git = dest / ".git" + dot_git.mkdir() + (dot_git / "HEAD").write_text("ref: refs/heads/main\n") + + req = Requirement("test-pkg") + sources.download_git_source( + ctx=ctx, + req=req, + url_to_clone="https://github.test/org/repo.git", + destination_dir=dest, + ref="v1.0", + ) + + mock_git_clone.assert_called_once() + assert dot_git.exists() + + +# --------------------------------------------------------------------------- +# GitOptions.remove_dot_git field +# --------------------------------------------------------------------------- + + +class TestGitOptionsRemoveDotGit: + def test_default_is_false(self) -> None: + opts = packagesettings.GitOptions() + assert opts.remove_dot_git is False + + def test_can_set_true(self) -> None: + opts = packagesettings.GitOptions(remove_dot_git=True) + assert opts.remove_dot_git is True + + def test_can_set_false(self) -> None: + opts = packagesettings.GitOptions(remove_dot_git=False) + assert opts.remove_dot_git is False