diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 3ac364d..2b849b6 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -185,26 +185,46 @@ pub async fn run(cli: Cli) { Ok(Some(model)) => { println!("Name: {}", model.name); println!("Kind: Model"); - println!("Metadata:"); - println!(" Created: {}", format_time_ago(&model.created_at)); - println!(" Updated: {}", format_time_ago(&model.updated_at)); - println!("Spec:"); - // Architecture section (only if info is available) - if let Some(arch) = &model.arch { - println!(" Architecture:"); - if let Some(model_type) = &arch.model_type { - println!(" Type: {}", model_type); - } - if let Some(classes) = &arch.classes { - println!(" Classes: {}", classes.join(", ")); - } - if let Some(parameters) = &arch.parameters { - println!(" Parameters: {}", parameters); - } - if let Some(context_window) = arch.context_window { - println!(" Context Window: {}", context_window); - } + if let Some(spec) = &model.spec { + println!( + " Author: {}", + spec.author.as_deref().unwrap_or("N/A") + ); + println!( + " Task: {}", + spec.task.as_deref().unwrap_or("N/A") + ); + println!( + " License: {}", + spec.license + .as_ref() + .map(|s| s.to_uppercase()) + .unwrap_or_else(|| "N/A".to_string()) + ); + println!( + " Model Type: {}", + spec.model_type.as_deref().unwrap_or("N/A") + ); + println!( + " Parameters: {}", + spec.parameters + .map(crate::utils::format::format_parameters) + .unwrap_or_else(|| "N/A".to_string()) + ); + println!( + " Context Window: {}", + spec.context_window + .map(|w| crate::utils::format::format_parameters(w as u64)) + .unwrap_or_else(|| "N/A".to_string()) + ); + } else { + println!(" Author: N/A"); + println!(" Task: N/A"); + println!(" License: N/A"); + println!(" Model Type: N/A"); + println!(" Parameters: N/A"); + println!(" Context Window: N/A"); } // Registry section println!(" Registry:"); @@ -212,6 +232,9 @@ pub async fn run(cli: Cli) { println!(" Revision: {}", model.revision); println!(" Size: {}", format_size_decimal(model.size)); println!(" Cache Path: {}", model.cache_path); + println!("Status:"); + println!(" Created: {}", format_time_ago(&model.created_at)); + println!(" Updated: {}", format_time_ago(&model.updated_at)); } Ok(None) => { eprintln!("Model not found: {}", args.model); @@ -233,7 +256,7 @@ pub async fn run(cli: Cli) { #[cfg(test)] mod tests { use super::*; - use crate::registry::model_registry::{ModelArchitecture, ModelInfo}; + use crate::registry::model_registry::ModelInfo; use tempfile::TempDir; #[test] @@ -258,7 +281,7 @@ mod tests { created_at: "2025-01-01T00:00:00Z".to_string(), updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), - arch: None, + spec: None, }; registry.register_model(model).unwrap(); @@ -282,11 +305,13 @@ mod tests { created_at: "2025-01-01T00:00:00Z".to_string(), updated_at: "2025-01-02T00:00:00Z".to_string(), cache_path: "/tmp/test/gpt".to_string(), - arch: Some(ModelArchitecture { + spec: Some(crate::registry::model_registry::ModelSpec { model_type: Some("gpt2".to_string()), - classes: Some(vec!["GPT2LMHeadModel".to_string()]), + parameters: Some(7_000_000_000), context_window: Some(2048), - parameters: Some("7.00B".to_string()), + author: Some("test-org".to_string()), + task: Some("text-generation".to_string()), + license: Some("mit".to_string()), }), }; @@ -300,11 +325,13 @@ mod tests { assert_eq!(model_info.created_at, "2025-01-01T00:00:00Z"); assert_eq!(model_info.updated_at, "2025-01-02T00:00:00Z"); - let arch = model_info.arch.unwrap(); - assert_eq!(arch.model_type, Some("gpt2".to_string())); - assert_eq!(arch.classes, Some(vec!["GPT2LMHeadModel".to_string()])); - assert_eq!(arch.context_window, Some(2048)); - assert_eq!(arch.parameters, Some("7.00B".to_string())); + let spec = model_info.spec.as_ref().unwrap(); + assert_eq!(spec.author, Some("test-org".to_string())); + assert_eq!(spec.task, Some("text-generation".to_string())); + assert_eq!(spec.license, Some("mit".to_string())); + assert_eq!(spec.model_type, Some("gpt2".to_string())); + assert_eq!(spec.context_window, Some(2048)); + assert_eq!(spec.parameters, Some(7_000_000_000)); } #[test] @@ -320,7 +347,7 @@ mod tests { created_at: "2025-01-01T00:00:00Z".to_string(), updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test/simple".to_string(), - arch: None, + spec: None, }; registry.register_model(model).unwrap(); @@ -330,7 +357,7 @@ mod tests { let model_info = retrieved.unwrap(); assert_eq!(model_info.name, "test/simple-model"); - assert!(model_info.arch.is_none()); + assert!(model_info.spec.is_none()); } #[test] @@ -346,7 +373,7 @@ mod tests { created_at: "2025-01-01T00:00:00Z".to_string(), updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test/remove".to_string(), - arch: None, + spec: None, }; registry.register_model(model).unwrap(); @@ -400,7 +427,7 @@ mod tests { created_at: "2025-01-01T00:00:00Z".to_string(), updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), - arch: None, + spec: None, }; registry.register_model(model).unwrap(); @@ -414,7 +441,7 @@ mod tests { created_at: "2025-01-05T00:00:00Z".to_string(), updated_at: "2025-01-05T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), - arch: None, + spec: None, }; registry.register_model(updated_model).unwrap(); diff --git a/src/downloader/huggingface.rs b/src/downloader/huggingface.rs index dfc0a40..623d883 100644 --- a/src/downloader/huggingface.rs +++ b/src/downloader/huggingface.rs @@ -6,7 +6,7 @@ use indicatif::{ProgressBar, ProgressStyle}; use crate::downloader::downloader::{DownloadError, Downloader}; use crate::downloader::progress::{DownloadProgressManager, FileProgress}; -use crate::registry::model_registry::{ModelArchitecture, ModelInfo, ModelRegistry}; +use crate::registry::model_registry::{ModelInfo, ModelRegistry, ModelSpec}; use crate::utils::file::{self, format_model_name}; /// Adapter to bridge HuggingFace's Progress trait with our FileProgress @@ -35,6 +35,60 @@ impl HuggingFaceDownloader { pub fn new() -> Self { Self } + + async fn fetch_metadata_from_api( + model_name: &str, + ) -> ( + Option, + Option, + Option, + Option, + Option, + Option, + ) { + let url = format!("https://huggingface.co/api/models/{}", model_name); + let client = reqwest::Client::new(); + + match client.get(&url).send().await { + Ok(response) => { + if let Ok(json) = response.json::().await { + let author = json + .get("author") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let task = json + .get("pipeline_tag") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let license = json + .get("cardData") + .and_then(|card| card.get("license")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let model_type = json + .get("config") + .and_then(|config| config.get("model_type")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let parameters = json + .get("safetensors") + .and_then(|st| st.get("total")) + .and_then(|v| v.as_u64()); + + let storage = json.get("usedStorage").and_then(|v| v.as_u64()); + + (author, task, license, model_type, parameters, storage) + } else { + (None, None, None, None, None, None) + } + } + Err(_) => (None, None, None, None, None, None), + } + } } impl Default for HuggingFaceDownloader { @@ -195,34 +249,63 @@ impl Downloader for HuggingFaceDownloader { } let elapsed_time = start_time.elapsed(); - - // Get accumulated size from downloads - let downloaded_size = progress_manager.total_downloaded_bytes(); let model_cache_path = cache_dir.join(format_model_name(name)); // Register the model only if not totally cached if !model_totally_cached { - // Extract architecture info from config.json + // Fetch metadata from HuggingFace API + let ( + author_from_api, + task_from_api, + license_from_api, + model_type_from_api, + parameters_from_api, + storage_from_api, + ) = Self::fetch_metadata_from_api(name).await; + + // Extract context_window from config.json let config_path = snapshot_path.join("config.json"); - let arch = if config_path.exists() { + let context_window = if config_path.exists() { std::fs::read_to_string(&config_path) .ok() .and_then(|content| serde_json::from_str::(&content).ok()) - .and_then(|config| ModelArchitecture::from_config(&config)) + .and_then(|config| { + config + .get("text_config") + .and_then(|tc| tc.get("max_position_embeddings")) + .or_else(|| config.get("max_position_embeddings")) + .or_else(|| config.get("n_positions")) + .or_else(|| config.get("n_ctx")) + .and_then(|v| v.as_u64()) + .map(|v| v as u32) + }) } else { None }; + let spec = Some(ModelSpec { + author: author_from_api, + task: task_from_api, + license: license_from_api, + model_type: model_type_from_api, + parameters: parameters_from_api, + context_window, + }); + + // Use storage from API, fallback to accumulated download size + let model_size = + storage_from_api.unwrap_or_else(|| progress_manager.total_downloaded_bytes()); + let now = chrono::Local::now().to_rfc3339(); let model_info_record = ModelInfo { name: name.to_string(), provider: "huggingface".to_string(), revision: sha, - size: downloaded_size, + size: model_size, created_at: now.clone(), updated_at: now, cache_path: model_cache_path.to_string_lossy().to_string(), - arch, + spec, }; let registry = ModelRegistry::new(None); diff --git a/src/registry/model_registry.rs b/src/registry/model_registry.rs index 3505570..6916c8e 100644 --- a/src/registry/model_registry.rs +++ b/src/registry/model_registry.rs @@ -4,92 +4,21 @@ use std::fs; use std::path::PathBuf; use crate::utils::file; -use crate::utils::format::format_parameters; -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct ModelArchitecture { +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ModelSpec { + #[serde(skip_serializing_if = "Option::is_none")] + pub author: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub task: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub license: Option, #[serde(skip_serializing_if = "Option::is_none")] pub model_type: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub classes: Option>, + pub parameters: Option, #[serde(skip_serializing_if = "Option::is_none")] pub context_window: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub parameters: Option, -} - -impl ModelArchitecture { - /// Extract model architecture from config.json - pub fn from_config(config: &serde_json::Value) -> Option { - let model_type = config - .get("model_type") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - let classes = config - .get("architectures") - .and_then(|v| v.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(|s| s.to_string())) - .collect::>() - }) - .filter(|v| !v.is_empty()); - - let context_window = config - .get("n_positions") - .or_else(|| config.get("max_position_embeddings")) - .or_else(|| config.get("n_ctx")) - .and_then(|v| v.as_u64()) - .map(|v| v as u32); - - let parameters = Self::estimate_parameters(config); - - if model_type.is_some() - || classes.is_some() - || context_window.is_some() - || parameters.is_some() - { - Some(ModelArchitecture { - model_type, - classes, - context_window, - parameters, - }) - } else { - None - } - } - - /// Estimate model parameters from config - fn estimate_parameters(config: &serde_json::Value) -> Option { - let n_layer = config - .get("n_layer") - .or_else(|| config.get("num_hidden_layers")) - .and_then(|v| v.as_u64())?; - - let n_embd = config - .get("n_embd") - .or_else(|| config.get("hidden_size")) - .and_then(|v| v.as_u64())?; - - let vocab_size = config.get("vocab_size").and_then(|v| v.as_u64())?; - - let n_positions = config - .get("n_positions") - .or_else(|| config.get("max_position_embeddings")) - .and_then(|v| v.as_u64()) - .unwrap_or(2048); - - // Rough parameter estimation for transformer models - // Each layer: ~12 * n_embd^2 (attention + FFN) - // Embeddings: vocab_size * n_embd + n_positions * n_embd - let layer_params = 12 * n_layer * n_embd * n_embd; - let embedding_params = vocab_size * n_embd + n_positions * n_embd; - let total_params = layer_params + embedding_params; - - Some(format_parameters(total_params)) - } } #[derive(Debug, Serialize, Deserialize, Clone)] @@ -102,7 +31,7 @@ pub struct ModelInfo { pub updated_at: String, pub cache_path: String, #[serde(skip_serializing_if = "Option::is_none")] - pub arch: Option, + pub spec: Option, } pub struct ModelRegistry { @@ -227,7 +156,7 @@ mod tests { created_at: "2025-01-01T00:00:00Z".to_string(), updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), - arch: None, + spec: None, }; registry.register_model(model.clone()).unwrap(); @@ -250,7 +179,7 @@ mod tests { created_at: "2025-01-01T00:00:00Z".to_string(), updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), - arch: None, + spec: None, }; registry.register_model(model).unwrap(); @@ -273,7 +202,7 @@ mod tests { created_at: "2025-01-01T00:00:00Z".to_string(), updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), - arch: None, + spec: None, }; registry.register_model(model).unwrap(); @@ -309,7 +238,7 @@ mod tests { created_at: "2025-01-01T00:00:00Z".to_string(), updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), - arch: None, + spec: None, }; registry.register_model(model1).unwrap(); @@ -322,7 +251,7 @@ mod tests { created_at: "2025-01-02T00:00:00Z".to_string(), updated_at: "2025-01-02T00:00:00Z".to_string(), cache_path: "/tmp/test2".to_string(), - arch: None, + spec: None, }; registry.register_model(model2).unwrap(); @@ -355,7 +284,7 @@ mod tests { created_at: "2025-01-01T00:00:00Z".to_string(), updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: cache_dir.to_string_lossy().to_string(), - arch: None, + spec: None, }; registry.register_model(model).unwrap(); @@ -395,11 +324,13 @@ mod tests { created_at: "2025-01-01T00:00:00Z".to_string(), updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test/gpt".to_string(), - arch: Some(ModelArchitecture { + spec: Some(ModelSpec { model_type: Some("gpt2".to_string()), - classes: Some(vec!["GPT2LMHeadModel".to_string()]), + parameters: Some(7_000_000_000), context_window: Some(2048), - parameters: Some("7.00B".to_string()), + author: None, + task: None, + license: None, }), }; @@ -414,11 +345,10 @@ mod tests { assert_eq!(model_info.revision, "abc123def456"); assert_eq!(model_info.size, 7_000_000_000); - let arch = model_info.arch.unwrap(); - assert_eq!(arch.model_type, Some("gpt2".to_string())); - assert_eq!(arch.classes, Some(vec!["GPT2LMHeadModel".to_string()])); - assert_eq!(arch.context_window, Some(2048)); - assert_eq!(arch.parameters, Some("7.00B".to_string())); + let spec = model_info.spec.as_ref().unwrap(); + assert_eq!(spec.model_type, Some("gpt2".to_string())); + assert_eq!(spec.context_window, Some(2048)); + assert_eq!(spec.parameters, Some(7_000_000_000)); } #[test] @@ -434,7 +364,7 @@ mod tests { created_at: "2024-01-01T00:00:00Z".to_string(), updated_at: "2024-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test/legacy".to_string(), - arch: None, + spec: None, }; registry.register_model(model).unwrap(); @@ -444,77 +374,6 @@ mod tests { let model_info = retrieved.unwrap(); assert_eq!(model_info.name, "test/legacy-model"); - assert!(model_info.arch.is_none()); - } - - #[test] - fn test_model_architecture_from_config_gpt2() { - use serde_json::json; - - let config = json!({ - "model_type": "gpt2", - "architectures": ["GPT2LMHeadModel"], - "n_layer": 5, - "n_embd": 32, - "vocab_size": 1000, - "n_positions": 512 - }); - - let arch = ModelArchitecture::from_config(&config); - assert!(arch.is_some()); - - let arch = arch.unwrap(); - assert_eq!(arch.model_type, Some("gpt2".to_string())); - assert_eq!(arch.classes, Some(vec!["GPT2LMHeadModel".to_string()])); - assert_eq!(arch.context_window, Some(512)); - assert_eq!(arch.parameters, Some("109.82K".to_string())); - } - - #[test] - fn test_model_architecture_from_config_bert_style() { - use serde_json::json; - - let config = json!({ - "model_type": "bert", - "num_hidden_layers": 12, - "hidden_size": 768, - "vocab_size": 30000, - "max_position_embeddings": 512 - }); - - let arch = ModelArchitecture::from_config(&config); - assert!(arch.is_some()); - - let arch = arch.unwrap(); - assert_eq!(arch.model_type, Some("bert".to_string())); - assert_eq!(arch.context_window, Some(512)); - assert!(arch.parameters.unwrap().contains("M")); - } - - #[test] - fn test_model_architecture_from_config_partial() { - use serde_json::json; - - let config = json!({ - "model_type": "llama", - "n_ctx": 4096 - }); - - let arch = ModelArchitecture::from_config(&config); - assert!(arch.is_some()); - - let arch = arch.unwrap(); - assert_eq!(arch.model_type, Some("llama".to_string())); - assert_eq!(arch.context_window, Some(4096)); - assert_eq!(arch.parameters, None); - } - - #[test] - fn test_model_architecture_from_config_empty() { - use serde_json::json; - - let config = json!({}); - let arch = ModelArchitecture::from_config(&config); - assert_eq!(arch, None); + assert!(model_info.spec.is_none()); } }