diff --git a/README.md b/README.md index 86ba161..e2ed14c 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ PUMA Information: | `stop` | 🚧 | Stop a running model | `puma stop ` | | `rm` | ✅ | Remove a model | `puma rm InftyAI/tiny-random-gpt2` | | `info` | ✅ | Display system-wide information | `puma info` | -| `inspect` | 🚧 | Return detailed information about a model or service | `puma inspect InftyAI/tiny-random-gpt2` | +| `inspect` | ✅ | Return detailed information about a model or service | `puma inspect InftyAI/tiny-random-gpt2` | | `version` | ✅ | Show PUMA version | `puma version` | | `help` | ✅ | Show help information | `puma help` | diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 5615aab..8535b8b 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -33,7 +33,7 @@ enum Commands { /// Display system-wide information INFO, /// Return detailed information about a model - INSPECT, + INSPECT(InspectArgs), /// Returns the version of PUMA. VERSION, } @@ -58,6 +58,12 @@ struct RmArgs { model: String, } +#[derive(Parser)] +struct InspectArgs { + /// Model name to inspect (e.g., InftyAI/tiny-random-gpt2) + model: String, +} + #[derive(Debug, Clone, Default, clap::ValueEnum)] pub enum Provider { #[default] @@ -70,7 +76,12 @@ pub async fn run(cli: Cli) { match cli.command { Commands::PS => { let mut table = Table::new(); - table.set_format(*format::consts::FORMAT_CLEAN); + table.set_format( + format::FormatBuilder::new() + .column_separator(' ') + .padding(0, 1) + .build(), + ); table.add_row(row!["NAME", "PROVIDER", "MODEL", "STATUS", "AGE"]); table.add_row(row![ "deepseek-r1", @@ -88,7 +99,12 @@ pub async fn run(cli: Cli) { let models = registry.load_models().unwrap_or_default(); let mut table = Table::new(); - table.set_format(*format::consts::FORMAT_CLEAN); + table.set_format( + format::FormatBuilder::new() + .column_separator(' ') + .padding(0, 1) + .build(), + ); table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "MODIFIED"]); 
for model in models { @@ -163,8 +179,51 @@ pub async fn run(cli: Cli) { info.display(); } - Commands::INSPECT => { - println!("Returning detailed information about model..."); + Commands::INSPECT(args) => { + let registry = ModelRegistry::new(None); + + match registry.get_model(&args.model) { + Ok(Some(model)) => { + println!("Name: {}", model.name); + println!("Kind: Model"); + + println!("Spec:"); + // Architecture section (only if info is available) + if let Some(arch) = &model.arch { + println!(" Architecture:"); + if let Some(model_type) = &arch.model_type { + println!(" Type: {}", model_type); + } + if let Some(classes) = &arch.classes { + println!(" Classes: {}", classes.join(", ")); + } + if let Some(parameters) = &arch.parameters { + println!(" Parameters: {}", parameters); + } + if let Some(context_window) = arch.context_window { + println!(" Context Window: {}", context_window); + } + } + // Registry section + println!(" Registry:"); + println!(" Provider: {}", model.provider); + println!(" Revision: {}", model.revision); + println!(" Size: {}", format_size_decimal(model.size)); + println!( + " Modified: {}", + format_time_ago(&model.modified_at) + ); + println!(" Cache Path: {}", model.cache_path); + } + Ok(None) => { + eprintln!("Model not found: {}", args.model); + std::process::exit(1); + } + Err(e) => { + eprintln!("Failed to load registry: {}", e); + std::process::exit(1); + } + } } Commands::VERSION => { diff --git a/src/downloader/huggingface.rs b/src/downloader/huggingface.rs index 1e09a32..2b44810 100644 --- a/src/downloader/huggingface.rs +++ b/src/downloader/huggingface.rs @@ -2,10 +2,11 @@ use colored::Colorize; use log::debug; use hf_hub::api::tokio::{ApiBuilder, Progress}; +use indicatif::{ProgressBar, ProgressStyle}; use crate::downloader::downloader::{DownloadError, Downloader}; use crate::downloader::progress::{DownloadProgressManager, FileProgress}; -use crate::registry::model_registry::{ModelInfo, ModelRegistry}; +use 
crate::registry::model_registry::{ModelArchitecture, ModelInfo, ModelRegistry}; use crate::utils::file::{self, format_model_name}; /// Adapter to bridge HuggingFace's Progress trait with our FileProgress @@ -62,7 +63,15 @@ impl Downloader for HuggingFaceDownloader { DownloadError::ApiError(format!("Failed to initialize Hugging Face API: {}", e)) })?; - println!("🐆 pulling manifest"); + // Create a simple spinner for manifest pulling + let manifest_spinner = ProgressBar::new_spinner(); + manifest_spinner.set_style( + ProgressStyle::default_spinner() + .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏") + .template("pulling manifest {spinner:.white}") + .unwrap(), + ); + manifest_spinner.enable_steady_tick(std::time::Duration::from_millis(80)); // Download the entire model repository using snapshot download let repo = api.model(name.to_string()); @@ -81,6 +90,10 @@ impl Downloader for HuggingFaceDownloader { } })?; + // Stop manifest spinner and print clean message + manifest_spinner.finish_and_clear(); + println!("pulling manifest"); + debug!("Model info for {}: {:?}", name, model_info); // Calculate the longest filename for proper alignment @@ -91,6 +104,8 @@ impl Downloader for HuggingFaceDownloader { .max() .unwrap_or(30); + // Add extra space for "pulling " prefix + let max_filename_len = max_filename_len + 8; // Create progress manager let progress_manager = DownloadProgressManager::new(max_filename_len); @@ -124,8 +139,9 @@ impl Downloader for HuggingFaceDownloader { debug!("File {} found in cache, showing as complete", filename); // Create progress bar for cached file (no speed display) + let display_name = format!("pulling {}", filename); let mut file_progress = - progress_manager_clone.create_cached_file_progress(&filename); + progress_manager_clone.create_cached_file_progress(&display_name); let file_size = cached_file_path.metadata().map(|m| m.len()).unwrap_or(0); file_progress.init(file_size); file_progress.update(file_size); @@ -136,7 +152,8 @@ impl Downloader for 
HuggingFaceDownloader { // File not in cache, download with progress debug!("Downloading: {}", filename); - let file_progress = progress_manager_clone.create_file_progress(&filename); + let display_name = format!("pulling {}", filename); + let file_progress = progress_manager_clone.create_file_progress(&display_name); let progress = HfProgressAdapter { progress: file_progress, }; @@ -156,37 +173,65 @@ impl Downloader for HuggingFaceDownloader { tasks.push(task); } + // Give tasks a moment to start and create their progress bars + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Show spinner at the bottom after all progress bars are created (only if not fully cached) + let spinner = if !model_totally_cached { + Some(progress_manager.create_spinner()) + } else { + None + }; + // Wait for all downloads to complete for task in tasks { task.await .map_err(|e| DownloadError::ApiError(format!("Task join error: {}", e)))??; } + // Finish spinner after downloads complete + if let Some(spinner) = &spinner { + spinner.finish_and_clear(); + } + let elapsed_time = start_time.elapsed(); // Get accumulated size from downloads let downloaded_size = progress_manager.total_downloaded_bytes(); let model_cache_path = cache_dir.join(format_model_name(name)); - // Register the model - let model_info_record = ModelInfo { - name: name.to_string(), - provider: "huggingface".to_string(), - revision: sha, - size: downloaded_size, - modified_at: chrono::Local::now().to_rfc3339(), - cache_path: model_cache_path.to_string_lossy().to_string(), - }; - + // Register the model only if not totally cached if !model_totally_cached { + // Extract architecture info from config.json + let config_path = snapshot_path.join("config.json"); + let arch = if config_path.exists() { + std::fs::read_to_string(&config_path) + .ok() + .and_then(|content| serde_json::from_str::<serde_json::Value>(&content).ok()) + .and_then(|config| ModelArchitecture::from_config(&config)) + } else { + None + }; + + let 
model_info_record = ModelInfo { + name: name.to_string(), + provider: "huggingface".to_string(), + revision: sha, + size: downloaded_size, + modified_at: chrono::Local::now().to_rfc3339(), + cache_path: model_cache_path.to_string_lossy().to_string(), + arch, + }; + let registry = ModelRegistry::new(None); registry .register_model(model_info_record) .map_err(|e| DownloadError::ApiError(format!("Failed to register model: {}", e)))?; } + // Print success message println!( - "\n{} {} {} {} {:.2?}", + "{} {} {} {} {:.2?}", "✓".green().bold(), "Successfully downloaded model".bright_white(), name.cyan().bold(), diff --git a/src/downloader/progress.rs b/src/downloader/progress.rs index 533e364..7b3ba32 100644 --- a/src/downloader/progress.rs +++ b/src/downloader/progress.rs @@ -85,6 +85,19 @@ impl DownloadProgressManager { pub fn total_downloaded_bytes(&self) -> u64 { self.total_size.load(Ordering::Relaxed) } + + /// Create a spinner progress bar (for post-download operations) + pub fn create_spinner(&self) -> ProgressBar { + let pb = self.multi_progress.add(ProgressBar::new_spinner()); + pb.set_style( + ProgressStyle::default_spinner() + .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏") + .template("{spinner} ") + .unwrap(), + ); + pb.enable_steady_tick(std::time::Duration::from_millis(80)); + pb + } } /// Tracks progress for a single file download diff --git a/src/registry/model_registry.rs b/src/registry/model_registry.rs index f52893a..13ea87c 100644 --- a/src/registry/model_registry.rs +++ b/src/registry/model_registry.rs @@ -4,6 +4,93 @@ use std::fs; use std::path::PathBuf; use crate::utils::file; +use crate::utils::format::format_parameters; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ModelArchitecture { + #[serde(skip_serializing_if = "Option::is_none")] + pub model_type: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub classes: Option<Vec<String>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub context_window: Option<u32>, + 
#[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option<String>, +} + +impl ModelArchitecture { + /// Extract model architecture from config.json + pub fn from_config(config: &serde_json::Value) -> Option<Self> { + let model_type = config + .get("model_type") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let classes = config + .get("architectures") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect::<Vec<String>>() + }) + .filter(|v| !v.is_empty()); + + let context_window = config + .get("n_positions") + .or_else(|| config.get("max_position_embeddings")) + .or_else(|| config.get("n_ctx")) + .and_then(|v| v.as_u64()) + .map(|v| v as u32); + + let parameters = Self::estimate_parameters(config); + + if model_type.is_some() + || classes.is_some() + || context_window.is_some() + || parameters.is_some() + { + Some(ModelArchitecture { + model_type, + classes, + context_window, + parameters, + }) + } else { + None + } + } + + /// Estimate model parameters from config + fn estimate_parameters(config: &serde_json::Value) -> Option<String> { + let n_layer = config + .get("n_layer") + .or_else(|| config.get("num_hidden_layers")) + .and_then(|v| v.as_u64())?; + + let n_embd = config + .get("n_embd") + .or_else(|| config.get("hidden_size")) + .and_then(|v| v.as_u64())?; + + let vocab_size = config.get("vocab_size").and_then(|v| v.as_u64())?; + + let n_positions = config + .get("n_positions") + .or_else(|| config.get("max_position_embeddings")) + .and_then(|v| v.as_u64()) + .unwrap_or(2048); + + // Rough parameter estimation for transformer models + // Each layer: ~12 * n_embd^2 (attention + FFN) + // Embeddings: vocab_size * n_embd + n_positions * n_embd + let layer_params = 12 * n_layer * n_embd * n_embd; + let embedding_params = vocab_size * n_embd + n_positions * n_embd; + let total_params = layer_params + embedding_params; + + Some(format_parameters(total_params)) + } +} #[derive(Debug, Serialize, 
Deserialize, Clone)] pub struct ModelInfo { @@ -13,6 +100,8 @@ pub struct ModelInfo { pub size: u64, pub modified_at: String, pub cache_path: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub arch: Option<ModelArchitecture>, } pub struct ModelRegistry { @@ -124,6 +213,7 @@ mod tests { size: 1000, modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), + arch: None, }; registry.register_model(model.clone()).unwrap(); @@ -145,6 +235,7 @@ size: 1000, modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), + arch: None, }; registry.register_model(model).unwrap(); @@ -166,6 +257,7 @@ size: 1000, modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), + arch: None, }; registry.register_model(model).unwrap(); @@ -200,6 +292,7 @@ size: 1000, modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), + arch: None, }; registry.register_model(model1).unwrap(); @@ -211,6 +304,7 @@ size: 2000, modified_at: "2025-01-02T00:00:00Z".to_string(), cache_path: "/tmp/test2".to_string(), + arch: None, }; registry.register_model(model2).unwrap(); @@ -238,6 +332,7 @@ size: 1000, modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: cache_dir.to_string_lossy().to_string(), + arch: None, }; registry.register_model(model).unwrap(); @@ -263,4 +358,138 @@ let result = registry.remove_model("nonexistent"); assert!(result.is_ok()); } + + #[test] + fn test_inspect_model_with_full_spec() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/gpt-model".to_string(), + provider: "huggingface".to_string(), + revision: "abc123def456".to_string(), + size: 7_000_000_000, + modified_at: "2025-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test/gpt".to_string(), + arch: Some(ModelArchitecture { + model_type: 
Some("gpt2".to_string()), + classes: Some(vec!["GPT2LMHeadModel".to_string()]), + context_window: Some(2048), + parameters: Some("7.00B".to_string()), + }), + }; + + registry.register_model(model).unwrap(); + + let retrieved = registry.get_model("test/gpt-model").unwrap(); + assert!(retrieved.is_some()); + + let model_info = retrieved.unwrap(); + assert_eq!(model_info.name, "test/gpt-model"); + assert_eq!(model_info.provider, "huggingface"); + assert_eq!(model_info.revision, "abc123def456"); + assert_eq!(model_info.size, 7_000_000_000); + + let arch = model_info.arch.unwrap(); + assert_eq!(arch.model_type, Some("gpt2".to_string())); + assert_eq!(arch.classes, Some(vec!["GPT2LMHeadModel".to_string()])); + assert_eq!(arch.context_window, Some(2048)); + assert_eq!(arch.parameters, Some("7.00B".to_string())); + } + + #[test] + fn test_inspect_model_without_spec() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/legacy-model".to_string(), + provider: "huggingface".to_string(), + revision: "legacy123".to_string(), + size: 1_000_000, + modified_at: "2024-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test/legacy".to_string(), + arch: None, + }; + + registry.register_model(model).unwrap(); + + let retrieved = registry.get_model("test/legacy-model").unwrap(); + assert!(retrieved.is_some()); + + let model_info = retrieved.unwrap(); + assert_eq!(model_info.name, "test/legacy-model"); + assert!(model_info.arch.is_none()); + } + + #[test] + fn test_model_architecture_from_config_gpt2() { + use serde_json::json; + + let config = json!({ + "model_type": "gpt2", + "architectures": ["GPT2LMHeadModel"], + "n_layer": 5, + "n_embd": 32, + "vocab_size": 1000, + "n_positions": 512 + }); + + let arch = ModelArchitecture::from_config(&config); + assert!(arch.is_some()); + + let arch = arch.unwrap(); + assert_eq!(arch.model_type, Some("gpt2".to_string())); + 
assert_eq!(arch.classes, Some(vec!["GPT2LMHeadModel".to_string()])); + assert_eq!(arch.context_window, Some(512)); + assert_eq!(arch.parameters, Some("109.82K".to_string())); + } + + #[test] + fn test_model_architecture_from_config_bert_style() { + use serde_json::json; + + let config = json!({ + "model_type": "bert", + "num_hidden_layers": 12, + "hidden_size": 768, + "vocab_size": 30000, + "max_position_embeddings": 512 + }); + + let arch = ModelArchitecture::from_config(&config); + assert!(arch.is_some()); + + let arch = arch.unwrap(); + assert_eq!(arch.model_type, Some("bert".to_string())); + assert_eq!(arch.context_window, Some(512)); + assert!(arch.parameters.unwrap().contains("M")); + } + + #[test] + fn test_model_architecture_from_config_partial() { + use serde_json::json; + + let config = json!({ + "model_type": "llama", + "n_ctx": 4096 + }); + + let arch = ModelArchitecture::from_config(&config); + assert!(arch.is_some()); + + let arch = arch.unwrap(); + assert_eq!(arch.model_type, Some("llama".to_string())); + assert_eq!(arch.context_window, Some(4096)); + assert_eq!(arch.parameters, None); + } + + #[test] + fn test_model_architecture_from_config_empty() { + use serde_json::json; + + let config = json!({}); + let arch = ModelArchitecture::from_config(&config); + assert_eq!(arch, None); + } } diff --git a/src/system/system_info.rs b/src/system/system_info.rs index 9f8cc0d..00b49cd 100644 --- a/src/system/system_info.rs +++ b/src/system/system_info.rs @@ -272,7 +272,6 @@ impl SystemInfo { } } - println!(); println!("PUMA Information:"); println!(" PUMA Version: {}", self.version); println!(" Cache Directory: {}", self.cache_dir); diff --git a/src/utils/format.rs b/src/utils/format.rs index 0e193ab..865c9b9 100644 --- a/src/utils/format.rs +++ b/src/utils/format.rs @@ -34,6 +34,23 @@ pub fn format_size_decimal(bytes: u64) -> String { } } +/// Format parameter count to human-readable format (K, M, B) +pub fn format_parameters(count: u64) -> String { + const 
K: f64 = 1_000.0; + const M: f64 = 1_000_000.0; + const B: f64 = 1_000_000_000.0; + + if count as f64 >= B { + format!("{:.2}B", count as f64 / B) + } else if count as f64 >= M { + format!("{:.2}M", count as f64 / M) + } else if count as f64 >= K { + format!("{:.2}K", count as f64 / K) + } else { + count.to_string() + } +} + /// Format RFC3339 timestamp to human-readable relative time (e.g., "2 hours ago") pub fn format_time_ago(timestamp: &str) -> String { // Try to parse as RFC3339 @@ -273,4 +290,64 @@ mod tests { // Large model (65 GB) assert_eq!(format_size_decimal(65_000_000_000), "65.00 GB"); } + + #[test] + fn test_format_parameters_raw() { + assert_eq!(format_parameters(0), "0"); + assert_eq!(format_parameters(1), "1"); + assert_eq!(format_parameters(999), "999"); + } + + #[test] + fn test_format_parameters_thousands() { + assert_eq!(format_parameters(1_000), "1.00K"); + assert_eq!(format_parameters(1_500), "1.50K"); + assert_eq!(format_parameters(10_000), "10.00K"); + assert_eq!(format_parameters(999_999), "1000.00K"); + } + + #[test] + fn test_format_parameters_millions() { + assert_eq!(format_parameters(1_000_000), "1.00M"); + assert_eq!(format_parameters(1_500_000), "1.50M"); + assert_eq!(format_parameters(7_000_000), "7.00M"); + assert_eq!(format_parameters(350_000_000), "350.00M"); + } + + #[test] + fn test_format_parameters_billions() { + assert_eq!(format_parameters(1_000_000_000), "1.00B"); + assert_eq!(format_parameters(1_500_000_000), "1.50B"); + assert_eq!(format_parameters(7_000_000_000), "7.00B"); + assert_eq!(format_parameters(175_000_000_000), "175.00B"); + } + + #[test] + fn test_format_parameters_realistic_models() { + // Tiny model (109K parameters) + assert_eq!(format_parameters(109_824), "109.82K"); + + // Small model (125M parameters) + assert_eq!(format_parameters(125_000_000), "125.00M"); + + // Medium model (7B parameters) + assert_eq!(format_parameters(7_000_000_000), "7.00B"); + + // Large model (70B parameters) + 
assert_eq!(format_parameters(70_000_000_000), "70.00B"); + + // Very large model (405B parameters) + assert_eq!(format_parameters(405_000_000_000), "405.00B"); + } + + #[test] + fn test_format_parameters_edge_cases() { + // Boundary between K and M + assert_eq!(format_parameters(999_999), "1000.00K"); + assert_eq!(format_parameters(1_000_000), "1.00M"); + + // Boundary between M and B + assert_eq!(format_parameters(999_999_999), "1000.00M"); + assert_eq!(format_parameters(1_000_000_000), "1.00B"); + } }