Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 61 additions & 34 deletions src/cli/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,33 +185,56 @@ pub async fn run(cli: Cli) {
Ok(Some(model)) => {
println!("Name: {}", model.name);
println!("Kind: Model");
println!("Metadata:");
println!(" Created: {}", format_time_ago(&model.created_at));
println!(" Updated: {}", format_time_ago(&model.updated_at));

println!("Spec:");
// Architecture section (only if info is available)
if let Some(arch) = &model.arch {
println!(" Architecture:");
if let Some(model_type) = &arch.model_type {
println!(" Type: {}", model_type);
}
if let Some(classes) = &arch.classes {
println!(" Classes: {}", classes.join(", "));
}
if let Some(parameters) = &arch.parameters {
println!(" Parameters: {}", parameters);
}
if let Some(context_window) = arch.context_window {
println!(" Context Window: {}", context_window);
}
if let Some(spec) = &model.spec {
println!(
" Author: {}",
spec.author.as_deref().unwrap_or("N/A")
);
println!(
" Task: {}",
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Context Window is formatted with format_parameters, which produces abbreviated K/M/B output (e.g., 2048 -> "2.05K"). Context window is a token count and should generally be shown as an exact integer (or with a dedicated formatter) to avoid confusing output.

Copilot uses AI. Check for mistakes.
spec.task.as_deref().unwrap_or("N/A")
);
println!(
" License: {}",
spec.license
.as_ref()
.map(|s| s.to_uppercase())
.unwrap_or_else(|| "N/A".to_string())
);
println!(
" Model Type: {}",
spec.model_type.as_deref().unwrap_or("N/A")
);
println!(
" Parameters: {}",
spec.parameters
.map(crate::utils::format::format_parameters)
.unwrap_or_else(|| "N/A".to_string())
);
println!(
" Context Window: {}",
spec.context_window
.map(|w| crate::utils::format::format_parameters(w as u64))
.unwrap_or_else(|| "N/A".to_string())
);
} else {
println!(" Author: N/A");
println!(" Task: N/A");
println!(" License: N/A");
println!(" Model Type: N/A");
println!(" Parameters: N/A");
println!(" Context Window: N/A");
}
// Registry section
println!(" Registry:");
println!(" Provider: {}", model.provider);
println!(" Revision: {}", model.revision);
println!(" Size: {}", format_size_decimal(model.size));
println!(" Cache Path: {}", model.cache_path);
println!("Status:");
println!(" Created: {}", format_time_ago(&model.created_at));
println!(" Updated: {}", format_time_ago(&model.updated_at));
}
Ok(None) => {
eprintln!("Model not found: {}", args.model);
Expand All @@ -233,7 +256,7 @@ pub async fn run(cli: Cli) {
#[cfg(test)]
mod tests {
use super::*;
use crate::registry::model_registry::{ModelArchitecture, ModelInfo};
use crate::registry::model_registry::ModelInfo;
use tempfile::TempDir;

#[test]
Expand All @@ -258,7 +281,7 @@ mod tests {
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test".to_string(),
arch: None,
spec: None,
};

registry.register_model(model).unwrap();
Expand All @@ -282,11 +305,13 @@ mod tests {
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-02T00:00:00Z".to_string(),
cache_path: "/tmp/test/gpt".to_string(),
arch: Some(ModelArchitecture {
spec: Some(crate::registry::model_registry::ModelSpec {
model_type: Some("gpt2".to_string()),
classes: Some(vec!["GPT2LMHeadModel".to_string()]),
parameters: Some(7_000_000_000),
context_window: Some(2048),
parameters: Some("7.00B".to_string()),
author: Some("test-org".to_string()),
task: Some("text-generation".to_string()),
license: Some("mit".to_string()),
}),
};

Expand All @@ -300,11 +325,13 @@ mod tests {
assert_eq!(model_info.created_at, "2025-01-01T00:00:00Z");
assert_eq!(model_info.updated_at, "2025-01-02T00:00:00Z");

let arch = model_info.arch.unwrap();
assert_eq!(arch.model_type, Some("gpt2".to_string()));
assert_eq!(arch.classes, Some(vec!["GPT2LMHeadModel".to_string()]));
assert_eq!(arch.context_window, Some(2048));
assert_eq!(arch.parameters, Some("7.00B".to_string()));
let spec = model_info.spec.as_ref().unwrap();
assert_eq!(spec.author, Some("test-org".to_string()));
assert_eq!(spec.task, Some("text-generation".to_string()));
assert_eq!(spec.license, Some("mit".to_string()));
assert_eq!(spec.model_type, Some("gpt2".to_string()));
assert_eq!(spec.context_window, Some(2048));
assert_eq!(spec.parameters, Some(7_000_000_000));
}

#[test]
Expand All @@ -320,7 +347,7 @@ mod tests {
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test/simple".to_string(),
arch: None,
spec: None,
};

registry.register_model(model).unwrap();
Expand All @@ -330,7 +357,7 @@ mod tests {

let model_info = retrieved.unwrap();
assert_eq!(model_info.name, "test/simple-model");
assert!(model_info.arch.is_none());
assert!(model_info.spec.is_none());
}

#[test]
Expand All @@ -346,7 +373,7 @@ mod tests {
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test/remove".to_string(),
arch: None,
spec: None,
};

registry.register_model(model).unwrap();
Expand Down Expand Up @@ -400,7 +427,7 @@ mod tests {
created_at: "2025-01-01T00:00:00Z".to_string(),
updated_at: "2025-01-01T00:00:00Z".to_string(),
cache_path: "/tmp/test".to_string(),
arch: None,
spec: None,
};

registry.register_model(model).unwrap();
Expand All @@ -414,7 +441,7 @@ mod tests {
created_at: "2025-01-05T00:00:00Z".to_string(),
updated_at: "2025-01-05T00:00:00Z".to_string(),
cache_path: "/tmp/test".to_string(),
arch: None,
spec: None,
};

registry.register_model(updated_model).unwrap();
Expand Down
101 changes: 92 additions & 9 deletions src/downloader/huggingface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use indicatif::{ProgressBar, ProgressStyle};

use crate::downloader::downloader::{DownloadError, Downloader};
use crate::downloader::progress::{DownloadProgressManager, FileProgress};
use crate::registry::model_registry::{ModelArchitecture, ModelInfo, ModelRegistry};
use crate::registry::model_registry::{ModelInfo, ModelRegistry, ModelSpec};
use crate::utils::file::{self, format_model_name};

/// Adapter to bridge HuggingFace's Progress trait with our FileProgress
Expand Down Expand Up @@ -35,6 +35,60 @@ impl HuggingFaceDownloader {
pub fn new() -> Self {
Self
}

async fn fetch_metadata_from_api(
model_name: &str,
) -> (
Option<String>,
Option<String>,
Option<String>,
Option<String>,
Option<u64>,
Option<u64>,
) {
let url = format!("https://huggingface.co/api/models/{}", model_name);
let client = reqwest::Client::new();

match client.get(&url).send().await {
Comment on lines +49 to +52
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

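The HuggingFace API URL is built by interpolating model_name directly into the URL string, but model_name may contain characters that require URL encoding. hf_hub likely handles this internally for downloads, so this extra request can fail even when the snapshot download succeeds. Consider constructing the URL with reqwest::Url (or similar) to ensure proper escaping of the repo id. Note that repo ids usually have the form owner/name, and path_segments_mut().push()/extend() percent-encode any "/" inside a segment as %2F, so the id should be split on "/" and each component pushed as its own path segment rather than pushing the whole id as a single segment.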

Suggested change
let url = format!("https://huggingface.co/api/models/{}", model_name);
let client = reqwest::Client::new();
match client.get(&url).send().await {
let mut url = match reqwest::Url::parse("https://huggingface.co/") {
Ok(url) => url,
Err(_) => return (None, None, None, None, None, None),
};
if let Ok(mut path_segments) = url.path_segments_mut() {
path_segments.push("api");
path_segments.push("models");
path_segments.push(model_name);
} else {
return (None, None, None, None, None, None);
}
let client = reqwest::Client::new();
match client.get(url).send().await {

Copilot uses AI. Check for mistakes.
Ok(response) => {
if let Ok(json) = response.json::<serde_json::Value>().await {
let author = json
Comment on lines +49 to +55
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fetch_metadata_from_api uses a default reqwest::Client with no timeout and it ignores non-success HTTP statuses. This can cause pull to hang indefinitely on network stalls and silently treat 4xx/5xx responses as “no metadata”. Consider using Client::builder().timeout(...) (or per-request timeout) and calling response.error_for_status() before attempting to parse JSON, returning/logging an error when the API call fails.

Copilot uses AI. Check for mistakes.
.get("author")
.and_then(|v| v.as_str())
.map(|s| s.to_string());

let task = json
.get("pipeline_tag")
.and_then(|v| v.as_str())
.map(|s| s.to_string());

let license = json
.get("cardData")
.and_then(|card| card.get("license"))
.and_then(|v| v.as_str())
.map(|s| s.to_string());

let model_type = json
.get("config")
.and_then(|config| config.get("model_type"))
.and_then(|v| v.as_str())
.map(|s| s.to_string());

let parameters = json
.get("safetensors")
.and_then(|st| st.get("total"))
.and_then(|v| v.as_u64());

let storage = json.get("usedStorage").and_then(|v| v.as_u64());

(author, task, license, model_type, parameters, storage)
} else {
(None, None, None, None, None, None)
}
}
Err(_) => (None, None, None, None, None, None),
}
}
}

impl Default for HuggingFaceDownloader {
Expand Down Expand Up @@ -195,34 +249,63 @@ impl Downloader for HuggingFaceDownloader {
}

let elapsed_time = start_time.elapsed();

// Get accumulated size from downloads
let downloaded_size = progress_manager.total_downloaded_bytes();
let model_cache_path = cache_dir.join(format_model_name(name));

// Register the model only if not totally cached
if !model_totally_cached {
// Extract architecture info from config.json
// Fetch metadata from HuggingFace API
let (
author_from_api,
task_from_api,
license_from_api,
model_type_from_api,
parameters_from_api,
storage_from_api,
) = Self::fetch_metadata_from_api(name).await;

// Extract context_window from config.json
let config_path = snapshot_path.join("config.json");
let arch = if config_path.exists() {
let context_window = if config_path.exists() {
std::fs::read_to_string(&config_path)
.ok()
.and_then(|content| serde_json::from_str::<serde_json::Value>(&content).ok())
.and_then(|config| ModelArchitecture::from_config(&config))
.and_then(|config| {
config
.get("text_config")
.and_then(|tc| tc.get("max_position_embeddings"))
.or_else(|| config.get("max_position_embeddings"))
.or_else(|| config.get("n_positions"))
.or_else(|| config.get("n_ctx"))
.and_then(|v| v.as_u64())
.map(|v| v as u32)
})
} else {
None
};

let spec = Some(ModelSpec {
author: author_from_api,
task: task_from_api,
license: license_from_api,
model_type: model_type_from_api,
parameters: parameters_from_api,
context_window,
});
Comment on lines +286 to +293
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spec is always set to Some(ModelSpec { ... }) even when all fields are None, which will serialize into "spec": {} (since the inner fields are skipped). If the intent is “no metadata”, consider setting spec to None when all extracted fields are missing to avoid storing an empty object and to keep the registry format cleaner.

Suggested change
let spec = Some(ModelSpec {
author: author_from_api,
task: task_from_api,
license: license_from_api,
model_type: model_type_from_api,
parameters: parameters_from_api,
context_window,
});
let model_spec = ModelSpec {
author: author_from_api,
task: task_from_api,
license: license_from_api,
model_type: model_type_from_api,
parameters: parameters_from_api,
context_window,
};
let spec = if model_spec.author.is_some()
|| model_spec.task.is_some()
|| model_spec.license.is_some()
|| model_spec.model_type.is_some()
|| model_spec.parameters.is_some()
|| model_spec.context_window.is_some()
{
Some(model_spec)
} else {
None
};

Copilot uses AI. Check for mistakes.

// Use storage from API, fallback to accumulated download size
let model_size =
storage_from_api.unwrap_or_else(|| progress_manager.total_downloaded_bytes());

let now = chrono::Local::now().to_rfc3339();
let model_info_record = ModelInfo {
name: name.to_string(),
provider: "huggingface".to_string(),
revision: sha,
size: downloaded_size,
size: model_size,
created_at: now.clone(),
updated_at: now,
cache_path: model_cache_path.to_string_lossy().to_string(),
arch,
spec,
};

let registry = ModelRegistry::new(None);
Expand Down
Loading
Loading