diff --git a/src/specify_cli/__init__.py b/src/specify_cli/__init__.py index d5f5aba2d5..99d313e123 100644 --- a/src/specify_cli/__init__.py +++ b/src/specify_cli/__init__.py @@ -2576,7 +2576,7 @@ def preset_list(): @preset_app.command("add") def preset_add( preset_id: str = typer.Argument(None, help="Preset ID to install from catalog"), - from_url: str = typer.Option(None, "--from", help="Install from a URL (ZIP file)"), + from_url: str = typer.Option(None, "--from", help="Install from a URL (ZIP or .tar.gz/.tgz archive)"), dev: str = typer.Option(None, "--dev", help="Install from local directory (development mode)"), priority: int = typer.Option(10, "--priority", help="Resolution priority (lower = higher precedence, default 10)"), ): @@ -2629,17 +2629,24 @@ def preset_add( import urllib.request import urllib.error import tempfile + from .extensions import _detect_archive_format as _det_fmt with tempfile.TemporaryDirectory() as tmpdir: - zip_path = Path(tmpdir) / "preset.zip" + archive_fmt = _det_fmt(from_url) try: with urllib.request.urlopen(from_url, timeout=60) as response: - zip_path.write_bytes(response.read()) + if not archive_fmt: + content_type = response.headers.get("Content-Type", "") + archive_fmt = _det_fmt(from_url, content_type) + archive_data = response.read() except urllib.error.URLError as e: console.print(f"[red]Error:[/red] Failed to download: {e}") raise typer.Exit(1) - manifest = manager.install_from_zip(zip_path, speckit_version, priority) + suffix = ".tar.gz" if archive_fmt == "tar.gz" else ".zip" + archive_path = Path(tmpdir) / f"preset{suffix}" + archive_path.write_bytes(archive_data) + manifest = manager.install_from_zip(archive_path, speckit_version, priority) console.print(f"[green]✓[/green] Preset '{manifest.name}' v{manifest.version} installed (priority {priority})") @@ -3573,7 +3580,7 @@ def catalog_remove( def extension_add( extension: str = typer.Argument(help="Extension name or path"), dev: bool = typer.Option(False, "--dev", help="Install from local directory"), - from_url: Optional[str] = typer.Option(None, "--from", help="Install from custom URL"), + from_url: Optional[str] = typer.Option(None, "--from", help="Install from custom URL (ZIP or .tar.gz/.tgz archive)"), priority: int = typer.Option(10, "--priority", help="Resolution priority (lower = higher precedence, default 10)"), ): """Install an extension.""" @@ -3612,10 +3619,11 @@ def extension_add( manifest = manager.install_from_directory(source_path, speckit_version, priority=priority) elif from_url: - # Install from URL (ZIP file) + # Install from URL (ZIP or tar.gz archive) import urllib.request import urllib.error from urllib.parse import urlparse + from .extensions import _detect_archive_format # Validate URL parsed = urlparse(from_url) @@ -3631,25 +3639,32 @@ def extension_add( console.print("Only install extensions from sources you trust.\n") console.print(f"Downloading from {from_url}...") - # Download ZIP to temp location + # Download archive to temp location; detect format from URL or Content-Type. download_dir = project_root / ".specify" / "extensions" / ".cache" / "downloads" download_dir.mkdir(parents=True, exist_ok=True) - zip_path = download_dir / f"{extension}-url-download.zip" + archive_fmt = _detect_archive_format(from_url) + archive_path = None try: with urllib.request.urlopen(from_url, timeout=60) as response: - zip_data = response.read() - zip_path.write_bytes(zip_data) + if not archive_fmt: + content_type = response.headers.get("Content-Type", "") + archive_fmt = _detect_archive_format(from_url, content_type) + archive_data = response.read() - # Install from downloaded ZIP - manifest = manager.install_from_zip(zip_path, speckit_version, priority=priority) + suffix = ".tar.gz" if archive_fmt == "tar.gz" else ".zip" + archive_path = download_dir / f"{extension}-url-download{suffix}" + archive_path.write_bytes(archive_data) + + # Install from downloaded archive + manifest = manager.install_from_zip(archive_path, speckit_version, priority=priority) except urllib.error.URLError as e: console.print(f"[red]Error:[/red] Failed to download from {from_url}: {e}") raise typer.Exit(1) finally: - # Clean up downloaded ZIP - if zip_path.exists(): - zip_path.unlink() + # Clean up the downloaded archive + if archive_path is not None and archive_path.exists(): + archive_path.unlink() else: # Try bundled extensions first (shipped with spec-kit) @@ -4303,27 +4318,47 @@ def extension_update( # 5. Download new version zip_path = catalog.download_extension(extension_id) try: - # 6. Validate extension ID from ZIP BEFORE modifying installation - # Handle both root-level and nested extension.yml (GitHub auto-generated ZIPs) - with zipfile.ZipFile(zip_path, "r") as zf: - import yaml - manifest_data = None - namelist = zf.namelist() - - # First try root-level extension.yml - if "extension.yml" in namelist: - with zf.open("extension.yml") as f: - manifest_data = yaml.safe_load(f) or {} - else: - # Look for extension.yml in a single top-level subdirectory - # (e.g., "repo-name-branch/extension.yml") - manifest_paths = [n for n in namelist if n.endswith("/extension.yml") and n.count("/") == 1] - if len(manifest_paths) == 1: - with zf.open(manifest_paths[0]) as f: + # 6. Validate extension ID from archive BEFORE modifying installation + # Handle both root-level and nested extension.yml (GitHub auto-generated archives) + from .extensions import _detect_archive_format + import tarfile + archive_fmt = _detect_archive_format(str(zip_path)) + import yaml + manifest_data = None + + if archive_fmt == "tar.gz": + with tarfile.open(zip_path, "r:gz") as tf: + # First try root-level extension.yml + try: + m = tf.getmember("extension.yml") + f = tf.extractfile(m) + if f is not None: + manifest_data = yaml.safe_load(f.read()) or {} + except KeyError: + # Look for extension.yml in a single top-level subdirectory + members = [m for m in tf.getmembers() if m.name.endswith("/extension.yml") and m.name.count("/") == 1] + if len(members) == 1: + f = tf.extractfile(members[0]) + if f is not None: + manifest_data = yaml.safe_load(f.read()) or {} + else: + with zipfile.ZipFile(zip_path, "r") as zf: + namelist = zf.namelist() + + # First try root-level extension.yml + if "extension.yml" in namelist: + with zf.open("extension.yml") as f: manifest_data = yaml.safe_load(f) or {} + else: + # Look for extension.yml in a single top-level subdirectory + # (e.g., "repo-name-branch/extension.yml") + manifest_paths = [n for n in namelist if n.endswith("/extension.yml") and n.count("/") == 1] + if len(manifest_paths) == 1: + with zf.open(manifest_paths[0]) as f: + manifest_data = yaml.safe_load(f) or {} - if manifest_data is None: - raise ValueError("Downloaded extension archive is missing 'extension.yml'") + if manifest_data is None: + raise ValueError("Downloaded extension archive is missing 'extension.yml'") zip_extension_id = manifest_data.get("extension", {}).get("id") if zip_extension_id != extension_id: @@ -4875,6 +4910,57 @@ def workflow_list(): console.print() +def _extract_workflow_yml(archive_path: Path, archive_fmt: str) -> bytes: + """Extract ``workflow.yml`` from a ZIP or ``.tar.gz`` archive. + + Searches the archive root and a single nested top-level subdirectory + (e.g., ``repo-name-1.0/workflow.yml``). + + Args: + archive_path: Path to the downloaded archive. + archive_fmt: ``"zip"`` or ``"tar.gz"``. + + Returns: + Raw bytes of the ``workflow.yml`` file. + + Raises: + ValueError: If no ``workflow.yml`` is found in the archive. + """ + import tarfile + + if archive_fmt == "tar.gz": + with tarfile.open(archive_path, "r:gz") as tf: + # Try root-level first. + try: + f = tf.extractfile(tf.getmember("workflow.yml")) + if f is not None: + return f.read() + except KeyError: + pass + # Look in a single top-level subdirectory. + candidates = [ + m for m in tf.getmembers() + if m.name.endswith("/workflow.yml") and m.name.count("/") == 1 + ] + if len(candidates) == 1: + f = tf.extractfile(candidates[0]) + if f is not None: + return f.read() + else: + with zipfile.ZipFile(archive_path, "r") as zf: + namelist = zf.namelist() + if "workflow.yml" in namelist: + return zf.read("workflow.yml") + candidates = [ + n for n in namelist + if n.endswith("/workflow.yml") and n.count("/") == 1 + ] + if len(candidates) == 1: + return zf.read(candidates[0]) + + raise ValueError("No workflow.yml found in the downloaded archive") + + @workflow_app.command("add") def workflow_add( source: str = typer.Argument(..., help="Workflow ID, URL, or local path"), @@ -4928,6 +5014,7 @@ def _validate_and_install_local(yaml_path: Path, source_label: str) -> None: from ipaddress import ip_address from urllib.parse import urlparse from urllib.request import urlopen # noqa: S310 + from .extensions import _detect_archive_format parsed_src = urlparse(source) src_host = parsed_src.hostname or "" @@ -4958,18 +5045,51 @@ def _validate_and_install_local(yaml_path: Path, source_label: str) -> None: if final_parsed.scheme != "https" and not (final_parsed.scheme == "http" and final_lb): console.print(f"[red]Error:[/red] URL redirected to non-HTTPS: {final_url}") raise typer.Exit(1) + + # Detect archive format from the final URL or Content-Type header. + archive_fmt = _detect_archive_format(final_url) + if not archive_fmt: + content_type = resp.headers.get("Content-Type", "") + archive_fmt = _detect_archive_format(final_url, content_type) + + raw_data = resp.read() + except typer.Exit: + raise + except Exception as exc: + console.print(f"[red]Error:[/red] Failed to download workflow: {exc}") + raise typer.Exit(1) + + tmp_path = None + try: + if archive_fmt in ("tar.gz", "zip"): + # Extract workflow.yml from the archive. + suffix = ".tar.gz" if archive_fmt == "tar.gz" else ".zip" + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as arc_tmp: + arc_tmp.write(raw_data) + arc_tmp_path = Path(arc_tmp.name) + try: + wf_yaml = _extract_workflow_yml(arc_tmp_path, archive_fmt) + with tempfile.NamedTemporaryFile(suffix=".yml", delete=False) as tmp: + tmp.write(wf_yaml) + tmp_path = Path(tmp.name) + finally: + arc_tmp_path.unlink(missing_ok=True) + else: + # Treat as a plain YAML file (existing behaviour). with tempfile.NamedTemporaryFile(suffix=".yml", delete=False) as tmp: - tmp.write(resp.read()) + tmp.write(raw_data) tmp_path = Path(tmp.name) except typer.Exit: raise except Exception as exc: - console.print(f"[red]Error:[/red] Failed to download workflow: {exc}") + console.print(f"[red]Error:[/red] Failed to process downloaded workflow: {exc}") raise typer.Exit(1) + try: _validate_and_install_local(tmp_path, source) finally: - tmp_path.unlink(missing_ok=True) + if tmp_path is not None: + tmp_path.unlink(missing_ok=True) return # Try as a local file/directory @@ -4978,6 +5098,26 @@ def _validate_and_install_local(yaml_path: Path, source_label: str) -> None: if source_path.is_file() and source_path.suffix in (".yml", ".yaml"): _validate_and_install_local(source_path, str(source_path)) return + elif source_path.is_file() and ( + source.endswith(".tar.gz") or source.endswith(".tgz") or source.endswith(".zip") + ): + # Local archive file containing workflow.yml + from .extensions import _detect_archive_format + local_fmt = _detect_archive_format(source) + try: + wf_yaml = _extract_workflow_yml(source_path, local_fmt) + except (ValueError, Exception) as exc: + console.print(f"[red]Error:[/red] Failed to extract workflow from archive: {exc}") + raise typer.Exit(1) + import tempfile + with tempfile.NamedTemporaryFile(suffix=".yml", delete=False) as tmp: + tmp.write(wf_yaml) + tmp_local = Path(tmp.name) + try: + _validate_and_install_local(tmp_local, str(source_path)) + finally: + tmp_local.unlink(missing_ok=True) + return elif source_path.is_dir(): wf_file = source_path / "workflow.yml" if not wf_file.exists(): @@ -5041,6 +5181,7 @@ def _validate_and_install_local(yaml_path: Path, source_label: str) -> None: try: from urllib.request import urlopen # noqa: S310 — URL comes from catalog + from .extensions import _detect_archive_format workflow_dir.mkdir(parents=True, exist_ok=True) with urlopen(workflow_url, timeout=30) as response: # noqa: S310 @@ -5063,7 +5204,30 @@ def _validate_and_install_local(yaml_path: Path, source_label: str) -> None: f"[red]Error:[/red] Workflow '{source}' redirected to non-HTTPS URL: {final_url}" ) raise typer.Exit(1) - workflow_file.write_bytes(response.read()) + + # Detect archive format from the final URL or Content-Type header. + cat_archive_fmt = _detect_archive_format(final_url) + if not cat_archive_fmt: + cat_ct = response.headers.get("Content-Type", "") + cat_archive_fmt = _detect_archive_format(final_url, cat_ct) + + raw_response = response.read() + + if cat_archive_fmt in ("tar.gz", "zip"): + # Download URL points to an archive — extract workflow.yml from it. + suffix = ".tar.gz" if cat_archive_fmt == "tar.gz" else ".zip" + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as arc_f: + arc_f.write(raw_response) + arc_tmp = Path(arc_f.name) + try: + wf_yaml_bytes = _extract_workflow_yml(arc_tmp, cat_archive_fmt) + finally: + arc_tmp.unlink(missing_ok=True) + workflow_file.write_bytes(wf_yaml_bytes) + else: + workflow_file.write_bytes(raw_response) + except typer.Exit: + raise except Exception as exc: if workflow_dir.exists(): import shutil diff --git a/src/specify_cli/extensions.py b/src/specify_cli/extensions.py index a419ebf1d2..f28c02e9b7 100644 --- a/src/specify_cli/extensions.py +++ b/src/specify_cli/extensions.py @@ -9,6 +9,8 @@ import json import hashlib import os +import sys +import tarfile import tempfile import zipfile import shutil @@ -106,6 +108,115 @@ def normalize_priority(value: Any, default: int = 10) -> int: return priority if priority >= 1 else default +def _detect_archive_format(url: str, content_type: str = "") -> str: + """Detect archive format from URL path extension or Content-Type header. + + Args: + url: URL or file path to inspect. + content_type: Optional ``Content-Type`` header value from the HTTP response. + + Returns: + ``"zip"`` for ZIP archives, ``"tar.gz"`` for gzipped tarballs, or ``""`` + when the format cannot be determined. + """ + # Strip query-string / fragment before examining the path extension. + url_path = url.split("?")[0].split("#")[0].lower() + if url_path.endswith(".zip"): + return "zip" + if url_path.endswith(".tar.gz") or url_path.endswith(".tgz"): + return "tar.gz" + + # Fall back to Content-Type header inspection. + ct = content_type.lower() + if "application/zip" in ct or "application/x-zip" in ct: + return "zip" + if any( + t in ct + for t in ( + "application/gzip", + "application/x-gzip", + "application/x-tar+gzip", + ) + ): + return "tar.gz" + + return "" + + +def _safe_extract_tarball( + archive_path: Path, + dest_dir: Path, + error_class: "type[Exception]" = Exception, +) -> None: + """Safely extract a ``.tar.gz`` or ``.tgz`` archive into *dest_dir*. + + All members are validated before extraction to prevent *tar slip* + (path traversal) attacks. Symlinks, hard links, and special files + (devices, FIFOs, etc.) are rejected. + + On Python 3.12 and later the ``"data"`` extraction filter is applied + for an additional layer of OS-level protection. On earlier versions + the explicit member list (containing only pre-validated regular files + and directories) is passed to ``extractall()`` — since all symlinks are + already rejected in the validation phase, no archive-introduced symlink + can be followed during extraction. + + Args: + archive_path: Path to the ``.tar.gz``/``.tgz`` archive. + dest_dir: Destination directory (must already exist). + error_class: Exception class to raise on unsafe entries. + + Raises: + error_class: If any member is unsafe or the archive cannot be read. + """ + dest_resolved = dest_dir.resolve() + + with tarfile.open(archive_path, "r:gz") as tf: + members = tf.getmembers() + safe_members = [] + + # Validate every member before extracting anything. + for member in members: + # Reject absolute paths and any path component that is "..". + if os.path.isabs(member.name) or any( + part == ".." for part in member.name.replace("\\", "/").split("/") + ): + raise error_class( + f"Unsafe path in tar archive: {member.name} (potential path traversal)" + ) + + # Confirm the resolved path stays inside dest_dir. + member_path = (dest_dir / member.name).resolve() + try: + member_path.relative_to(dest_resolved) + except ValueError: + raise error_class( + f"Unsafe path in tar archive: {member.name} (potential path traversal)" + ) + + # Reject symlinks and hard links. + if member.issym() or member.islnk(): + raise error_class( + f"Symlinks are not allowed in archive: {member.name}" + ) + + # Only allow regular files and directories. + if not (member.isreg() or member.isdir()): + raise error_class( + f"Non-regular file in archive: {member.name}" + ) + + safe_members.append(member) + + # Extract — use the "data" filter on Python 3.12+ for extra hardening. + # On older versions pass only the pre-validated members so that no + # unvetted entry (added concurrently or via a race) slips through. + if sys.version_info >= (3, 12): + tf.extractall(dest_dir, filter="data") # type: ignore[call-arg] + else: + tf.extractall(dest_dir, members=safe_members) # noqa: S202 — validated above + + @dataclass class CatalogEntry: """Represents a single catalog entry in the catalog stack.""" @@ -1202,10 +1313,10 @@ def install_from_zip( speckit_version: str, priority: int = 10, ) -> ExtensionManifest: - """Install extension from ZIP file. + """Install extension from a ZIP or ``.tar.gz``/``.tgz`` archive. Args: - zip_path: Path to extension ZIP file + zip_path: Path to the extension archive (ZIP or gzipped tarball). speckit_version: Current spec-kit version priority: Resolution priority (lower = higher precedence, default 10) @@ -1213,7 +1324,8 @@ def install_from_zip( Installed extension manifest Raises: - ValidationError: If manifest is invalid or priority is invalid + ValidationError: If manifest is invalid, the archive is unsafe, or + priority is invalid CompatibilityError: If extension is incompatible """ # Validate priority early @@ -1223,21 +1335,27 @@ def install_from_zip( with tempfile.TemporaryDirectory() as tmpdir: temp_path = Path(tmpdir) - # Extract ZIP safely (prevent Zip Slip attack) - with zipfile.ZipFile(zip_path, 'r') as zf: - # Validate all paths first before extracting anything - temp_path_resolved = temp_path.resolve() - for member in zf.namelist(): - member_path = (temp_path / member).resolve() - # Use is_relative_to for safe path containment check - try: - member_path.relative_to(temp_path_resolved) - except ValueError: - raise ValidationError( - f"Unsafe path in ZIP archive: {member} (potential path traversal)" - ) - # Only extract after all paths are validated - zf.extractall(temp_path) + archive_fmt = _detect_archive_format(str(zip_path)) + + if archive_fmt == "tar.gz": + # Extract tarball safely (prevent tar slip attack) + _safe_extract_tarball(zip_path, temp_path, ValidationError) + else: + # Extract ZIP safely (prevent Zip Slip attack) + with zipfile.ZipFile(zip_path, 'r') as zf: + # Validate all paths first before extracting anything + temp_path_resolved = temp_path.resolve() + for member in zf.namelist(): + member_path = (temp_path / member).resolve() + # Use is_relative_to for safe path containment check + try: + member_path.relative_to(temp_path_resolved) + except ValueError: + raise ValidationError( + f"Unsafe path in ZIP archive: {member} (potential path traversal)" + ) + # Only extract after all paths are validated + zf.extractall(temp_path) # Find extension directory (may be nested) extension_dir = temp_path @@ -1251,7 +1369,7 @@ def install_from_zip( manifest_path = extension_dir / "extension.yml" if not manifest_path.exists(): - raise ValidationError("No extension.yml found in ZIP file") + raise ValidationError("No extension.yml found in archive") # Install from extracted directory return self.install_from_directory(extension_dir, speckit_version, priority=priority) @@ -1965,14 +2083,18 @@ def get_extension_info(self, extension_id: str) -> Optional[Dict[str, Any]]: return None def download_extension(self, extension_id: str, target_dir: Optional[Path] = None) -> Path: - """Download extension ZIP from catalog. + """Download extension archive from catalog. + + Supports both ZIP (``.zip``) and gzipped tarball (``.tar.gz``/``.tgz``) + archives. The format is detected from the download URL's path extension; + when ambiguous the ``Content-Type`` header is used as a fallback. Args: extension_id: ID of the extension to download - target_dir: Directory to save ZIP file (defaults to temp directory) + target_dir: Directory to save the archive (defaults to cache directory) Returns: - Path to downloaded ZIP file + Path to downloaded archive file Raises: ExtensionError: If extension not found or download fails @@ -2011,21 +2133,35 @@ def download_extension(self, extension_id: str, target_dir: Optional[Path] = Non target_dir.mkdir(parents=True, exist_ok=True) version = ext_info.get("version", "unknown") - zip_filename = f"{extension_id}-{version}.zip" - zip_path = target_dir / zip_filename - # Download the ZIP file + # Detect archive format from URL; resolve via Content-Type when needed. + archive_fmt = _detect_archive_format(download_url) + + # Download the archive try: with self._open_url(download_url, timeout=60) as response: - zip_data = response.read() - - zip_path.write_bytes(zip_data) - return zip_path + if not archive_fmt: + content_type = response.headers.get("Content-Type", "") + archive_fmt = _detect_archive_format(download_url, content_type) + archive_data = response.read() except urllib.error.URLError as e: raise ExtensionError(f"Failed to download extension from {download_url}: {e}") except IOError as e: - raise ExtensionError(f"Failed to save extension ZIP: {e}") + raise ExtensionError(f"Failed to save extension archive: {e}") + + # Choose file extension based on detected format. + if archive_fmt == "tar.gz": + archive_filename = f"{extension_id}-{version}.tar.gz" + else: + archive_filename = f"{extension_id}-{version}.zip" + + archive_path = target_dir / archive_filename + try: + archive_path.write_bytes(archive_data) + except IOError as e: + raise ExtensionError(f"Failed to save extension archive: {e}") + return archive_path def clear_cache(self): """Clear the catalog cache (both legacy and URL-hash-based files).""" diff --git a/src/specify_cli/presets.py b/src/specify_cli/presets.py index 27054a77fc..0d597423e1 100644 --- a/src/specify_cli/presets.py +++ b/src/specify_cli/presets.py @@ -27,7 +27,7 @@ from packaging import version as pkg_version from packaging.specifiers import SpecifierSet, InvalidSpecifier -from .extensions import ExtensionRegistry, normalize_priority +from .extensions import ExtensionRegistry, normalize_priority, _detect_archive_format, _safe_extract_tarball def _substitute_core_template( @@ -1604,10 +1604,10 @@ def install_from_zip( speckit_version: str, priority: int = 10, ) -> PresetManifest: - """Install preset from ZIP file. + """Install preset from a ZIP or ``.tar.gz``/``.tgz`` archive. Args: - zip_path: Path to preset ZIP file + zip_path: Path to the preset archive (ZIP or gzipped tarball). speckit_version: Current spec-kit version priority: Resolution priority (lower = higher precedence, default 10) @@ -1615,7 +1615,8 @@ def install_from_zip( Installed preset manifest Raises: - PresetValidationError: If manifest is invalid or priority is invalid + PresetValidationError: If manifest is invalid, the archive is unsafe, + or priority is invalid PresetCompatibilityError: If pack is incompatible """ # Validate priority early @@ -1625,18 +1626,24 @@ def install_from_zip( with tempfile.TemporaryDirectory() as tmpdir: temp_path = Path(tmpdir) - with zipfile.ZipFile(zip_path, 'r') as zf: - temp_path_resolved = temp_path.resolve() - for member in zf.namelist(): - member_path = (temp_path / member).resolve() - try: - member_path.relative_to(temp_path_resolved) - except ValueError: - raise PresetValidationError( - f"Unsafe path in ZIP archive: {member} " - "(potential path traversal)" - ) - zf.extractall(temp_path) + archive_fmt = _detect_archive_format(str(zip_path)) + + if archive_fmt == "tar.gz": + # Extract tarball safely (prevent tar slip attack) + _safe_extract_tarball(zip_path, temp_path, PresetValidationError) + else: + with zipfile.ZipFile(zip_path, 'r') as zf: + temp_path_resolved = temp_path.resolve() + for member in zf.namelist(): + member_path = (temp_path / member).resolve() + try: + member_path.relative_to(temp_path_resolved) + except ValueError: + raise PresetValidationError( + f"Unsafe path in ZIP archive: {member} " + "(potential path traversal)" + ) + zf.extractall(temp_path) pack_dir = temp_path manifest_path = pack_dir / "preset.yml" @@ -1649,7 +1656,7 @@ def install_from_zip( if not manifest_path.exists(): raise PresetValidationError( - "No preset.yml found in ZIP file" + "No preset.yml found in archive" ) return self.install_from_directory(pack_dir, speckit_version, priority) @@ -2242,14 +2249,18 @@ def get_pack_info( def download_pack( self, pack_id: str, target_dir: Optional[Path] = None ) -> Path: - """Download preset ZIP from catalog. + """Download preset archive from catalog. + + Supports both ZIP (``.zip``) and gzipped tarball (``.tar.gz``/``.tgz``) + archives. The format is detected from the download URL's path extension; + when ambiguous the ``Content-Type`` header is used as a fallback. Args: pack_id: ID of the preset to download - target_dir: Directory to save ZIP file (defaults to cache directory) + target_dir: Directory to save the archive (defaults to cache directory) Returns: - Path to downloaded ZIP file + Path to downloaded archive file Raises: PresetError: If pack not found or download fails @@ -2301,22 +2312,36 @@ def download_pack( target_dir.mkdir(parents=True, exist_ok=True) version = pack_info.get("version", "unknown") - zip_filename = f"{pack_id}-{version}.zip" - zip_path = target_dir / zip_filename + + # Detect archive format from URL; resolve via Content-Type when needed. + archive_fmt = _detect_archive_format(download_url) try: with self._open_url(download_url, timeout=60) as response: - zip_data = response.read() - - zip_path.write_bytes(zip_data) - return zip_path + if not archive_fmt: + content_type = response.headers.get("Content-Type", "") + archive_fmt = _detect_archive_format(download_url, content_type) + archive_data = response.read() except urllib.error.URLError as e: raise PresetError( f"Failed to download preset from {download_url}: {e}" ) except IOError as e: - raise PresetError(f"Failed to save preset ZIP: {e}") + raise PresetError(f"Failed to save preset archive: {e}") + + # Choose file extension based on detected format. + if archive_fmt == "tar.gz": + archive_filename = f"{pack_id}-{version}.tar.gz" + else: + archive_filename = f"{pack_id}-{version}.zip" + + archive_path = target_dir / archive_filename + try: + archive_path.write_bytes(archive_data) + except IOError as e: + raise PresetError(f"Failed to save preset archive: {e}") + return archive_path def clear_cache(self): """Clear all catalog cache files, including per-URL hashed caches.""" diff --git a/tests/test_extensions.py b/tests/test_extensions.py index c5be0ab4f3..6310140070 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -178,6 +178,47 @@ def test_custom_default(self): assert normalize_priority("invalid", default=1) == 1 +# ===== _detect_archive_format Tests ===== + +class TestDetectArchiveFormat: + """Test the _detect_archive_format helper.""" + + def _fmt(self, url, ct=""): + from specify_cli.extensions import _detect_archive_format + return _detect_archive_format(url, ct) + + def test_zip_url_extension(self): + assert self._fmt("https://example.com/ext-1.0.0.zip") == "zip" + + def test_tar_gz_url_extension(self): + assert self._fmt("https://example.com/ext-1.0.0.tar.gz") == "tar.gz" + + def test_tgz_url_extension(self): + assert self._fmt("https://example.com/ext-1.0.0.tgz") == "tar.gz" + + def test_zip_uppercase_url_extension(self): + assert self._fmt("https://example.com/ext.ZIP") == "zip" + + def test_tar_gz_with_query_string(self): + assert self._fmt("https://example.com/ext.tar.gz?token=abc") == "tar.gz" + + def test_zip_content_type_fallback(self): + assert self._fmt("https://example.com/download", "application/zip") == "zip" + + def test_gzip_content_type_fallback(self): + assert self._fmt("https://example.com/download", "application/gzip") == "tar.gz" + + def test_x_gzip_content_type_fallback(self): + assert self._fmt("https://example.com/download", "application/x-gzip") == "tar.gz" + + def test_unknown_returns_empty_string(self): + assert self._fmt("https://example.com/workflow.yml") == "" + + def test_url_extension_takes_precedence_over_content_type(self): + # URL says .zip — content-type claiming gzip should not override. + assert self._fmt("https://example.com/ext.zip", "application/gzip") == "zip" + + # ===== ExtensionManifest Tests ===== class TestExtensionManifest: @@ -1013,6 +1054,97 @@ def test_config_backup_on_remove(self, extension_dir, project_dir): assert backup_file.read_text() == "test: config" +# ===== install_from_zip Tarball Tests ===== + +class TestInstallFromTarball: + """Tests for install_from_zip accepting .tar.gz/.tgz archives.""" + + def _make_tarball(self, dest: Path, extension_dir: Path, nested: bool = False) -> None: + """Create a minimal .tar.gz archive from *extension_dir*.""" + import tarfile + with tarfile.open(dest, "w:gz") as tf: + for file_path in extension_dir.rglob("*"): + if file_path.is_file(): + arcname = file_path.relative_to(extension_dir) + if nested: + arcname = Path("test-ext-v1.0.0") / arcname + tf.add(file_path, arcname=str(arcname)) + + def test_install_from_tar_gz(self, extension_dir, project_dir, temp_dir): + """install_from_zip should accept a .tar.gz archive.""" + archive = temp_dir / "test-ext-1.0.0.tar.gz" + self._make_tarball(archive, extension_dir) + + manager = ExtensionManager(project_dir) + manifest = manager.install_from_zip(archive, "0.1.0") + assert manifest.id == "test-ext" + assert manager.registry.is_installed("test-ext") + + def test_install_from_tgz(self, extension_dir, project_dir, temp_dir): + """install_from_zip should accept a .tgz archive.""" + archive = temp_dir / "test-ext-1.0.0.tgz" + self._make_tarball(archive, extension_dir) + + manager = ExtensionManager(project_dir) + manifest = manager.install_from_zip(archive, "0.1.0") + assert manifest.id == "test-ext" + assert manager.registry.is_installed("test-ext") + + def test_install_from_tar_gz_nested(self, extension_dir, project_dir, temp_dir): + """install_from_zip should handle a single nested directory inside the tarball.""" + archive = temp_dir / "test-ext-nested.tar.gz" + self._make_tarball(archive, extension_dir, nested=True) + + manager = ExtensionManager(project_dir) + manifest = manager.install_from_zip(archive, "0.1.0") + assert manifest.id == "test-ext" + assert manager.registry.is_installed("test-ext") + + def test_install_from_tar_gz_no_manifest(self, project_dir, temp_dir): + """install_from_zip raises ValidationError when tarball has no extension.yml.""" + import tarfile + import io + archive = temp_dir / "bad.tar.gz" + with tarfile.open(archive, "w:gz") as tf: + data = b"no manifest here" + info = tarfile.TarInfo(name="readme.txt") + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + + manager = ExtensionManager(project_dir) + with pytest.raises(ValidationError, match="No extension.yml found"): + manager.install_from_zip(archive, "0.1.0") + + def test_install_from_tar_gz_rejects_path_traversal(self, project_dir, temp_dir): + """install_from_zip must reject tarballs with path traversal entries.""" + import tarfile + import io + archive = temp_dir / "evil.tar.gz" + with tarfile.open(archive, "w:gz") as tf: + info = tarfile.TarInfo(name="../../evil.txt") + data = b"evil" + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + + manager = ExtensionManager(project_dir) + with pytest.raises(ValidationError, match="Unsafe path"): + manager.install_from_zip(archive, "0.1.0") + + def test_install_from_tar_gz_rejects_symlinks(self, project_dir, temp_dir): + """install_from_zip must reject tarballs containing symlinks.""" + import tarfile + archive = temp_dir / "symlink.tar.gz" + with tarfile.open(archive, "w:gz") as tf: + info = tarfile.TarInfo(name="link") + info.type = tarfile.SYMTYPE + info.linkname = "/etc/passwd" + tf.addfile(info) + + manager = ExtensionManager(project_dir) + with pytest.raises(ValidationError, match="Symlinks"): + manager.install_from_zip(archive, "0.1.0") + + # ===== CommandRegistrar Tests ===== class TestCommandRegistrar: diff --git a/tests/test_presets.py b/tests/test_presets.py index 4b167ed9be..f6f0a5a086 100644 --- a/tests/test_presets.py +++ b/tests/test_presets.py @@ -649,6 +649,76 @@ def test_install_from_zip_no_manifest(self, project_dir, temp_dir): with pytest.raises(PresetValidationError, match="No preset.yml found"): manager.install_from_zip(zip_path, "0.1.5") + def _make_tarball(self, dest, pack_dir, nested=False): + import tarfile + with tarfile.open(dest, "w:gz") as tf: + for file_path in pack_dir.rglob("*"): + if file_path.is_file(): + arcname = file_path.relative_to(pack_dir) + if nested: + arcname = Path("test-pack-v1.0.0") / arcname + tf.add(file_path, arcname=str(arcname)) + + def test_install_from_tar_gz(self, project_dir, pack_dir, temp_dir): + """Test installing a preset from a .tar.gz archive.""" + archive = temp_dir / "test-pack-1.0.tar.gz" + self._make_tarball(archive, pack_dir) + + manager = PresetManager(project_dir) + manifest = manager.install_from_zip(archive, "0.1.5") + assert manifest.id == "test-pack" + assert manager.registry.is_installed("test-pack") + + def test_install_from_tgz(self, project_dir, pack_dir, temp_dir): + """Test installing a preset from a .tgz archive.""" + archive = temp_dir / "test-pack-1.0.tgz" + self._make_tarball(archive, pack_dir) + + manager = PresetManager(project_dir) + manifest = manager.install_from_zip(archive, "0.1.5") + assert manifest.id == "test-pack" + assert manager.registry.is_installed("test-pack") + + def test_install_from_tar_gz_nested(self, project_dir, pack_dir, temp_dir): + """Test installing a preset from a .tar.gz archive with a single nested directory.""" + archive = temp_dir / "test-pack-nested.tar.gz" + self._make_tarball(archive, pack_dir, nested=True) + + manager = PresetManager(project_dir) + manifest = manager.install_from_zip(archive, "0.1.5") + assert manifest.id == "test-pack" + assert manager.registry.is_installed("test-pack") + + def test_install_from_tar_gz_no_manifest(self, project_dir, temp_dir): + """Test installing a preset from a .tar.gz without preset.yml raises error.""" + import tarfile + import io + archive = temp_dir / "bad.tar.gz" + with tarfile.open(archive, "w:gz") as tf: + data = b"no manifest here" + info = tarfile.TarInfo(name="readme.txt") + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + + manager = PresetManager(project_dir) + with pytest.raises(PresetValidationError, match="No preset.yml found"): + manager.install_from_zip(archive, "0.1.5") + + def test_install_from_tar_gz_rejects_path_traversal(self, project_dir, temp_dir): + """install_from_zip must reject tarballs with path traversal entries.""" + import tarfile + import io + archive = temp_dir / "evil.tar.gz" + with tarfile.open(archive, "w:gz") as tf: + info = tarfile.TarInfo(name="../../evil.txt") + data = b"evil" + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + + manager = PresetManager(project_dir) + with pytest.raises(PresetValidationError, match="Unsafe path"): + manager.install_from_zip(archive, "0.1.5") + def test_remove(self, project_dir, pack_dir): """Test removing a preset.""" manager = PresetManager(project_dir)