diff --git a/src/modelinfo/architecture.py b/src/modelinfo/architecture.py index b398213..bef7237 100644 --- a/src/modelinfo/architecture.py +++ b/src/modelinfo/architecture.py @@ -11,6 +11,7 @@ def extract_architecture(tensors: Dict[str, Any], config: Dict[str, Any] = None) metadata = tensors.get("__metadata__", {}) gen_arch = metadata.get("general.architecture") + is_gguf = "general.architecture" in metadata or any(k.startswith("general.") for k in metadata.keys()) # 1. Attempt explicit GGUF metadata if gen_arch: @@ -68,14 +69,14 @@ def extract_architecture(tensors: Dict[str, Any], config: Dict[str, Any] = None) found_k_proj = True shape = meta.get("shape", []) if len(shape) >= 2: - kv_dim = shape[0] + kv_dim = shape[-1] if is_gguf else shape[0] if "qkv_proj.weight" in name or "c_attn.weight" in name: found_fused = True if not found_k_proj: shape = meta.get("shape", []) if len(shape) >= 2: - kv_dim = shape[0] // 3 + kv_dim = (shape[-1] if is_gguf else shape[0]) // 3 num_layers = len(layers_set) if found_fused and not found_k_proj and kv_dim > 0: diff --git a/src/modelinfo/cli.py b/src/modelinfo/cli.py index cb4be02..56c27d0 100644 --- a/src/modelinfo/cli.py +++ b/src/modelinfo/cli.py @@ -151,8 +151,15 @@ def analyze_model( is_remote = False if not os.path.exists(file_path): - if "/" in file_path or not file_path_lower.endswith((".safetensors", ".gguf", ".pt", ".bin", ".index.json")): - is_remote = True + # ponytail: prevent routing explicit local paths or typos to HF + is_local_path = ( + file_path.startswith((".", "/", "~")) + or os.path.isabs(file_path) + ) + if not is_local_path: + # Treat as remote only if it contains a slash and does not end with a model extension + if "/" in file_path and not file_path_lower.endswith((".safetensors", ".gguf", ".pt", ".bin", ".index.json")): + is_remote = True if is_remote: from modelinfo.parsers.huggingface import fetch_huggingface_repo @@ -212,8 +219,12 @@ def analyze_model( num_layers = footprint["num_layers"] arch_name = identify_architecture_name(tensors, num_layers, config) - if os.path.exists(file_path): - disk_size = os.path.getsize(file_path) + if not is_remote: + metadata = tensors.get("__metadata__", {}) + if metadata.get("is_sharded") and "disk_size" in metadata: + disk_size = metadata["disk_size"] + elif os.path.exists(file_path): + disk_size = os.path.getsize(file_path) tensor_count = len([k for k in tensors.keys() if k != "__metadata__"]) @@ -240,6 +251,10 @@ def analyze_model( def main(argv: Sequence[str] | None = None) -> int: args = parse_args(argv) + # Strip trailing slashes from paths/repos to prevent empty basenames and routing issues + if args.file: + args.file = [path.rstrip("/\\") for path in args.file if path] + gpu_name_display = None gpu_vram_gb = None gpu_count = 1 diff --git a/src/modelinfo/parsers/huggingface.py b/src/modelinfo/parsers/huggingface.py index 00f09e3..b36dd7f 100644 --- a/src/modelinfo/parsers/huggingface.py +++ b/src/modelinfo/parsers/huggingface.py @@ -212,19 +212,27 @@ def _fetch_remote_gguf_group(real_repo_id: str, gguf_files: List[Dict[str, Any]] return tensors -def _fetch_shards_concurrently(real_repo_id: str, unique_shards: List[str], timeout: float) -> Dict[str, Any]: +def _fetch_shards_concurrently(real_repo_id: str, unique_shards: List[str], timeout: float) -> Tuple[Dict[str, Any], int]: def fetch_shard(shard: str): - return shard, _fetch_safetensors_header(real_repo_id, shard, timeout=timeout) + try: + header = _fetch_safetensors_header(real_repo_id, shard, timeout=timeout) + return shard, header, None + except Exception as e: + return shard, {}, e tensors = {} + missing_shards = 0 with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, min(8, len(unique_shards)))) as executor: future_to_shard = {executor.submit(fetch_shard, shard): shard for shard in unique_shards} for future in concurrent.futures.as_completed(future_to_shard): - shard, shard_header = future.result() - for k, v in shard_header.items(): - if k != "__metadata__": - tensors[k] = v - return tensors + shard, shard_header, error = future.result() + if error is not None: + missing_shards += 1 + else: + for k, v in shard_header.items(): + if k != "__metadata__": + tensors[k] = v + return tensors, missing_shards def _fetch_remote_safetensors_sharded( @@ -253,9 +261,9 @@ def _fetch_remote_safetensors_sharded( "total_size": total_size } else: - tensors = _fetch_shards_concurrently(real_repo_id, unique_shards, timeout) + tensors, missing_shards = _fetch_shards_concurrently(real_repo_id, unique_shards, timeout) tensors["__metadata__"] = { - "missing_shards": 0, + "missing_shards": missing_shards, "total_shards": len(unique_shards), "is_sharded": True } diff --git a/src/modelinfo/parsers/safetensors.py b/src/modelinfo/parsers/safetensors.py index 5d0289f..2e7d705 100644 --- a/src/modelinfo/parsers/safetensors.py +++ b/src/modelinfo/parsers/safetensors.py @@ -31,7 +31,13 @@ def parse_safetensors_header(path: str) -> dict[str, Any]: if path.endswith(".index.json"): is_index = True elif "-of-" in base_name and path.endswith(".safetensors"): - prefix = base_name.split("-")[0] + import re + match = re.match(r"^(.*?)-\d+-of-\d+\.safetensors$", base_name) + if match: + prefix = match.group(1) + else: + # Fallback to splitting in case of non-standard shard formatting + prefix = base_name.split("-")[0] potential_index = os.path.join(dir_path, f"{prefix}.safetensors.index.json") if os.path.exists(potential_index): index_path = potential_index @@ -49,9 +55,12 @@ def parse_safetensors_header(path: str) -> dict[str, Any]: tensors = {} missing_shards = 0 total_shards = len(unique_shards) + total_size = 0 for shard in unique_shards: shard_path = os.path.join(dir_path, shard) + if os.path.exists(shard_path): + total_size += os.path.getsize(shard_path) try: shard_header = _read_single_header(shard_path) for k, v in shard_header.items(): @@ -63,7 +72,8 @@ def parse_safetensors_header(path: str) -> dict[str, Any]: tensors["__metadata__"] = { "missing_shards": missing_shards, "total_shards": total_shards, - "is_sharded": True + "is_sharded": True, + "disk_size": total_size } return tensors diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 1ef701c..94cf3ea 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -169,3 +169,46 @@ def test_vllm_capacity_simulation(): bytes_per_token = 40960 expected_capacity = math.floor(metrics["paged_kv_pool"] / bytes_per_token) assert metrics["max_serving_capacity"] == expected_capacity + + +def test_gguf_shape_guessing_fallback(): + """Verify that shape guessing logic correctly extracts kv_dim using GGUF column-major ordering (shape[-1]) when metadata has no explicit keys.""" + from modelinfo.architecture import extract_architecture + + tensors = { + "__metadata__": { + "general.architecture": "llama", + }, + "model.layers.0.self_attn.k_proj.weight": { + "shape": [4096, 1024], + "dtype": "F16" + }, + "model.layers.1.self_attn.k_proj.weight": { + "shape": [4096, 1024], + "dtype": "F16" + } + } + + num_layers, kv_dim, is_estimate = extract_architecture(tensors) + assert num_layers == 2 + assert kv_dim == 1024 + assert is_estimate is False + +def test_gguf_shape_guessing_fallback_fused(): + """Verify that fused shape guessing extracts (shape[-1] // 3) for GGUF tensors.""" + from modelinfo.architecture import extract_architecture + + tensors = { + "__metadata__": { + "general.architecture": "gpt2", + }, + "model.layers.0.self_attn.qkv_proj.weight": { + "shape": [4096, 3072], + "dtype": "F16" + } + } + + num_layers, kv_dim, is_estimate = extract_architecture(tensors) + assert num_layers == 1 + assert kv_dim == 1024 + assert is_estimate is True diff --git a/tests/test_cli.py b/tests/test_cli.py index 857a225..267b98d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -187,19 +187,7 @@ def fake_exists(path): return False def fake_fetch(repo_id, *, fetch_tensors, timeout): - tensors = { - "__metadata__": { - "general.architecture": "llama", - "llama.block_count": 32, - "llama.attention.head_count_kv": 8, - "llama.attention.key_length": 128, - "gguf_variants": [ - {"filename": "model-q4.gguf", "size": 1000000000}, - {"filename": "model-q8.gguf", "size": 2000000000} - ], - "repo_id": "org/model-gguf" - } - } + tensors, _ = _get_mock_gguf_group_data() return tensors, None, "GGUF_group", 0.0 monkeypatch.setattr(cli.os.path, "exists", fake_exists) @@ -304,4 +292,79 @@ def test_print_model_info_gguf_group_with_gpu(capsys): assert "model-q8.gguf" in out assert "Fits" in out +def test_analyze_model_local_path_routing(monkeypatch): + """Test that analyze_model treats paths starting with local prefix as local, raising an error instead of routing to Hugging Face.""" + from modelinfo.parsers import huggingface + + hf_fetched = [] + def fake_fetch(repo_id, *, fetch_tensors, timeout): + hf_fetched.append(repo_id) + return {}, None, "SafeTensors", 0.0 + + monkeypatch.setattr(huggingface, "fetch_huggingface_repo", fake_fetch) + + # Test cases that should NOT hit Hugging Face + local_paths = ["./missing.gguf", "../missing.safetensors", "/missing.bin", "~/missing.pt"] + for path in local_paths: + with pytest.raises((FileNotFoundError, ValueError, OSError)): + cli.analyze_model(path, context_override=128) + + assert len(hf_fetched) == 0, f"Hugging Face fetch was triggered for local paths: {hf_fetched}" + + # Test cases that SHOULD hit Hugging Face + remote_paths = ["meta-llama/Llama-2-7b-hf", "org/model"] + for path in remote_paths: + try: + cli.analyze_model(path, context_override=128) + except Exception: + # We don't care if calculation fails later because of empty dict from fake_fetch, + # we just care that it triggers fetch_huggingface_repo. + pass + + assert hf_fetched == remote_paths + + +def test_cli_strips_trailing_slashes_from_model_paths(monkeypatch): + captured_paths = [] + + def fake_analyze_model(file_path, *args, **kwargs): + captured_paths.append(file_path) + return { + "format_name": "GGUF", + "arch_name": "Llama", + "tensor_count": 10, + "footprint": { + "total_params": 100, + "base_memory_bytes": 200, + "kv_cache_bytes": 100, + "overhead_bytes": 50, + "total_memory_bytes": 350, + "num_layers": 1, + }, + "disk_size": 200, + "context_length": 128, + "is_default_context": True, + "tensors": {}, + "max_context": 512, + "is_lazy": False, + "gpu_count": 1, + "topology": "pcie4", + "strategy": "tp", + "is_vllm": False, + "gpu_vram_gb": 0.0, + "gpu_util": 0.9, + } + + monkeypatch.setattr(cli, "analyze_model", fake_analyze_model) + monkeypatch.setattr(cli, "print_compare_info", lambda models, max_vram, gpu_name: None) + monkeypatch.setattr(cli, "print_model_info", lambda *args, **kwargs: None) + + # Test single model path with trailing slash + cli.main(["meta-llama/Llama-2-7b-hf/"]) + assert captured_paths == ["meta-llama/Llama-2-7b-hf"] + + captured_paths.clear() + # Test multiple model paths with trailing slashes (side-by-side comparison) + cli.main(["meta-llama/Llama-2-7b-hf/", "mistralai/Mistral-7B-v0.1/"]) + assert captured_paths == ["meta-llama/Llama-2-7b-hf", "mistralai/Mistral-7B-v0.1"] diff --git a/tests/test_parsers.py b/tests/test_parsers.py index c1d0b6e..97fe244 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -201,4 +201,85 @@ def fake_make_request(url, headers=None, limit=None, timeout=10.0): huggingface.fetch_huggingface_repo("org/nonexistent-model") assert "Could not find repository on Hugging Face" in str(exc_info.value) +def test_safetensors_sharded_with_hyphens(tmp_path): + """Test safetensors parser sharded index path resolution when filename contains hyphens.""" + import struct + import json + + index_file = tmp_path / "mock-llama-3-8b.safetensors.index.json" + shard_file = tmp_path / "mock-llama-3-8b-00001-of-00002.safetensors" + + index_data = { + "weight_map": { + "model.embed_tokens.weight": "mock-llama-3-8b-00001-of-00002.safetensors" + } + } + index_file.write_text(json.dumps(index_data), encoding="utf-8") + + header_data = { + "model.embed_tokens.weight": { + "dtype": "BF16", + "shape": [32000, 4096], + "data_offsets": [0, 262144000] + } + } + header_json = json.dumps(header_data).encode("utf-8") + header_len = len(header_json) + + with open(shard_file, "wb") as f: + f.write(struct.pack("