Skip to content
212 changes: 205 additions & 7 deletions src/cli/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ pub async fn run(cli: Cli) {
.padding(0, 1)
.build(),
);
table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "MODIFIED"]);

table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "AGE"]);
for model in models {
let size_str = format_size_decimal(model.size);

Expand All @@ -116,7 +115,7 @@ pub async fn run(cli: Cli) {
&model.revision
};

let created_str = format_time_ago(&model.modified_at);
let created_str = format_time_ago(&model.created_at);

table.add_row(row![
model.name,
Expand Down Expand Up @@ -186,6 +185,9 @@ pub async fn run(cli: Cli) {
Ok(Some(model)) => {
println!("Name: {}", model.name);
println!("Kind: Model");
println!("Metadata:");
println!(" Created: {}", format_time_ago(&model.created_at));
println!(" Updated: {}", format_time_ago(&model.updated_at));

println!("Spec:");
// Architecture section (only if info is available)
Expand All @@ -209,10 +211,6 @@ pub async fn run(cli: Cli) {
println!(" Provider: {}", model.provider);
println!(" Revision: {}", model.revision);
println!(" Size: {}", format_size_decimal(model.size));
println!(
" Modified: {}",
format_time_ago(&model.modified_at)
);
println!(" Cache Path: {}", model.cache_path);
}
Ok(None) => {
Expand All @@ -231,3 +229,203 @@ pub async fn run(cli: Cli) {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::registry::model_registry::{ModelArchitecture, ModelInfo};
use tempfile::TempDir;

#[test]
fn test_ls_command_empty() {
let temp_dir = TempDir::new().unwrap();
let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));

let models = registry.load_models().unwrap_or_default();
assert_eq!(models.len(), 0);
}

#[test]
fn test_ls_command_with_models() {
let temp_dir = TempDir::new().unwrap();
let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));

let model = ModelInfo {
name: "test/model".to_string(),
provider: "huggingface".to_string(),
revision: "abc123def456".to_string(),
size: 1_000_000,
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test".to_string(),
arch: None,
};

registry.register_model(model).unwrap();

let models = registry.load_models().unwrap();
assert_eq!(models.len(), 1);
assert_eq!(models[0].name, "test/model");
assert_eq!(models[0].provider, "huggingface");
}

#[test]
fn test_inspect_command_with_metadata() {
let temp_dir = TempDir::new().unwrap();
let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));

let model = ModelInfo {
name: "test/gpt-model".to_string(),
provider: "huggingface".to_string(),
revision: "abc123def456".to_string(),
size: 7_000_000_000,
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-02T00:00:00Z".to_string(),
cache_path: "/tmp/test/gpt".to_string(),
arch: Some(ModelArchitecture {
model_type: Some("gpt2".to_string()),
classes: Some(vec!["GPT2LMHeadModel".to_string()]),
context_window: Some(2048),
parameters: Some("7.00B".to_string()),
}),
};

registry.register_model(model.clone()).unwrap();

let retrieved = registry.get_model("test/gpt-model").unwrap();
assert!(retrieved.is_some());

let model_info = retrieved.unwrap();
assert_eq!(model_info.name, "test/gpt-model");
assert_eq!(model_info.created_at, "2025-01-01T00:00:00Z");
assert_eq!(model_info.updated_at, "2025-01-02T00:00:00Z");

let arch = model_info.arch.unwrap();
assert_eq!(arch.model_type, Some("gpt2".to_string()));
assert_eq!(arch.classes, Some(vec!["GPT2LMHeadModel".to_string()]));
assert_eq!(arch.context_window, Some(2048));
assert_eq!(arch.parameters, Some("7.00B".to_string()));
}

#[test]
fn test_inspect_command_without_architecture() {
let temp_dir = TempDir::new().unwrap();
let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));

let model = ModelInfo {
name: "test/simple-model".to_string(),
provider: "huggingface".to_string(),
revision: "xyz789".to_string(),
size: 500_000,
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test/simple".to_string(),
arch: None,
};

registry.register_model(model).unwrap();

let retrieved = registry.get_model("test/simple-model").unwrap();
assert!(retrieved.is_some());

let model_info = retrieved.unwrap();
assert_eq!(model_info.name, "test/simple-model");
assert!(model_info.arch.is_none());
}

#[test]
fn test_rm_command() {
let temp_dir = TempDir::new().unwrap();
let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));

let model = ModelInfo {
name: "test/remove-model".to_string(),
provider: "huggingface".to_string(),
revision: "abc123".to_string(),
size: 1000,
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test/remove".to_string(),
arch: None,
};

registry.register_model(model).unwrap();
assert!(registry.get_model("test/remove-model").unwrap().is_some());

// Simulate RM command
let result = registry.get_model("test/remove-model");
assert!(result.is_ok());
assert!(result.unwrap().is_some());
}

#[test]
fn test_rm_command_nonexistent() {
let temp_dir = TempDir::new().unwrap();
let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));

let result = registry.get_model("nonexistent/model");
assert!(result.is_ok());
assert!(result.unwrap().is_none());
}

#[test]
fn test_revision_truncation() {
let long_revision = "abc123def456ghi789jkl012";
let short = if long_revision.len() > 8 {
&long_revision[..8]
} else {
long_revision
};
assert_eq!(short, "abc123de");

let short_revision = "abc123";
let short = if short_revision.len() > 8 {
&short_revision[..8]
} else {
short_revision
};
assert_eq!(short, "abc123");
}

#[test]
fn test_metadata_timestamps_differ() {
let temp_dir = TempDir::new().unwrap();
let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));

let model = ModelInfo {
name: "test/updated-model".to_string(),
provider: "huggingface".to_string(),
revision: "v1".to_string(),
size: 1000,
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test".to_string(),
arch: None,
};

registry.register_model(model).unwrap();

// Update the model
let updated_model = ModelInfo {
name: "test/updated-model".to_string(),
provider: "huggingface".to_string(),
revision: "v2".to_string(),
size: 2000,
created_at: "2025-01-05T00:00:00Z".to_string(),
updated_at: "2025-01-05T00:00:00Z".to_string(),
cache_path: "/tmp/test".to_string(),
arch: None,
};

registry.register_model(updated_model).unwrap();

let result = registry.get_model("test/updated-model").unwrap().unwrap();
// created_at should remain the same
assert_eq!(result.created_at, "2025-01-01T00:00:00Z");
// updated_at should be new
assert_eq!(result.updated_at, "2025-01-05T00:00:00Z");
// Other fields should be updated
assert_eq!(result.revision, "v2");
assert_eq!(result.size, 2000);
}
}
4 changes: 3 additions & 1 deletion src/downloader/huggingface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,14 @@ impl Downloader for HuggingFaceDownloader {
None
};

let now = chrono::Local::now().to_rfc3339();
let model_info_record = ModelInfo {
name: name.to_string(),
provider: "huggingface".to_string(),
revision: sha,
size: downloaded_size,
modified_at: chrono::Local::now().to_rfc3339(),
created_at: now.clone(),
updated_at: now,
cache_path: model_cache_path.to_string_lossy().to_string(),
arch,
};
Expand Down
45 changes: 35 additions & 10 deletions src/registry/model_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ pub struct ModelInfo {
pub provider: String,
pub revision: String,
pub size: u64,
pub modified_at: String,
pub created_at: String,
pub updated_at: String,
pub cache_path: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub arch: Option<ModelArchitecture>,
Expand Down Expand Up @@ -148,10 +149,22 @@ impl ModelRegistry {
pub fn register_model(&self, model: ModelInfo) -> Result<(), std::io::Error> {
let mut models = self.load_models()?;

// Check if model already exists to preserve created_at
let existing_created_at = models
.iter()
.find(|m| m.name == model.name)
.map(|m| m.created_at.clone());

// Remove existing model with same name if exists
models.retain(|m| m.name != model.name);

models.push(model);
// Use existing created_at if this is an update, otherwise use the provided one
let mut final_model = model;
if let Some(created_at) = existing_created_at {
final_model.created_at = created_at;
}

models.push(final_model);
self.save_models(&models)?;

Ok(())
Expand Down Expand Up @@ -211,7 +224,8 @@ mod tests {
provider: "huggingface".to_string(),
revision: "abc123".to_string(),
size: 1000,
modified_at: "2025-01-01T00:00:00Z".to_string(),
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test".to_string(),
arch: None,
};
Expand All @@ -233,7 +247,8 @@ mod tests {
provider: "huggingface".to_string(),
revision: "abc123".to_string(),
size: 1000,
modified_at: "2025-01-01T00:00:00Z".to_string(),
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test".to_string(),
arch: None,
};
Expand All @@ -255,7 +270,8 @@ mod tests {
provider: "huggingface".to_string(),
revision: "abc123".to_string(),
size: 1000,
modified_at: "2025-01-01T00:00:00Z".to_string(),
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test".to_string(),
arch: None,
};
Expand Down Expand Up @@ -290,7 +306,8 @@ mod tests {
provider: "huggingface".to_string(),
revision: "abc123".to_string(),
size: 1000,
modified_at: "2025-01-01T00:00:00Z".to_string(),
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test".to_string(),
arch: None,
};
Expand All @@ -302,7 +319,8 @@ mod tests {
provider: "huggingface".to_string(),
revision: "def456".to_string(),
size: 2000,
modified_at: "2025-01-02T00:00:00Z".to_string(),
created_at: "2025-01-02T00:00:00Z".to_string(),
updated_at: "2025-01-02T00:00:00Z".to_string(),
cache_path: "/tmp/test2".to_string(),
arch: None,
};
Expand All @@ -313,6 +331,10 @@ mod tests {
assert_eq!(models.len(), 1);
assert_eq!(models[0].revision, "def456");
assert_eq!(models[0].size, 2000);
// created_at should be preserved from model1
assert_eq!(models[0].created_at, "2025-01-01T00:00:00Z");
// updated_at should be from model2
assert_eq!(models[0].updated_at, "2025-01-02T00:00:00Z");
}

#[test]
Expand All @@ -330,7 +352,8 @@ mod tests {
provider: "huggingface".to_string(),
revision: "abc123".to_string(),
size: 1000,
modified_at: "2025-01-01T00:00:00Z".to_string(),
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: cache_dir.to_string_lossy().to_string(),
arch: None,
};
Expand Down Expand Up @@ -369,7 +392,8 @@ mod tests {
provider: "huggingface".to_string(),
revision: "abc123def456".to_string(),
size: 7_000_000_000,
modified_at: "2025-01-01T00:00:00Z".to_string(),
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test/gpt".to_string(),
arch: Some(ModelArchitecture {
model_type: Some("gpt2".to_string()),
Expand Down Expand Up @@ -407,7 +431,8 @@ mod tests {
provider: "huggingface".to_string(),
revision: "legacy123".to_string(),
size: 1_000_000,
modified_at: "2024-01-01T00:00:00Z".to_string(),
created_at: "2024-01-01T00:00:00Z".to_string(),
updated_at: "2024-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test/legacy".to_string(),
arch: None,
};
Expand Down
Loading