diff --git a/Cargo.toml b/Cargo.toml index b9730e0..c21faba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] async-trait = "0.1" axum = { version = "0.7.9", features = ["tracing", "multipart", "macros"] } -axum-server = { version = "0.8.0", features = ["rustls", "rustls-pemfile", "tls-rustls"] } +axum-server = { version = "0.8.0", features = ["rustls", "tls-rustls"] } azure_blob_uploader = "0.1.4" azure_storage = "0.21.0" azure_storage_blobs = "0.21.0" @@ -19,14 +19,12 @@ futures = "0.3.31" futures-util = "0.3.31" headers = "0.4.0" hex = "0.4.3" -hmac = "0.13.0" http-body-util = "0.1.2" -jwt = "0.16.0" +jsonwebtoken = "10" rand = "0.8.5" reqwest = { version = "0.12.9", features = ["blocking"] } rustls = "0.23.20" serde = { version = "1.0.216", features = ["derive"] } -sha2 = "0.11.0" sysinfo = { version = "0.35.2", features = ["serde"] } tempfile = "3.14.0" tokio = { version = "1.42.0", features = ["rt", "rt-multi-thread", "macros"] } @@ -39,8 +37,6 @@ tracing-subscriber = "0.3.19" [dev-dependencies] reqwest = { version = "0.12.9", features = ["blocking", "multipart"] } -hmac = "0.13.0" -jwt = "0.16.0" -sha2 = "0.11.0" +jsonwebtoken = "10" tokio = { version = "1.42.0", features = ["rt", "rt-multi-thread", "macros", "time", "process"] } tempfile = "3.14.0" diff --git a/src/azure.rs b/src/azure.rs index 3e31c6a..21c8e7f 100644 --- a/src/azure.rs +++ b/src/azure.rs @@ -106,6 +106,7 @@ mod tests { } } +#[allow(dead_code)] fn calculate_checksum(filename: &String, data: &[u8]) { let hash = sha2_512::default().update(data).finalize(); let digest = hash.digest(); @@ -217,6 +218,7 @@ async fn write_file_to_blob_streaming( /// Write file to Azure blob storage (legacy version using Vec) /// TBD: Rework, do not keep whole file as Vec in memory!!! +#[allow(dead_code)] async fn write_file_to_blob( filename: String, data: Vec, @@ -378,7 +380,6 @@ async fn get_file_from_blob(filename: String) -> ReceivedFile { let blob_client = ClientBuilder::new(storage_account, storage_credential) .blob_client(storage_container, storage_blob); let blob_url_res = blob_client.url(); - let mut blob_url = "".to_string(); let mut received_file = ReceivedFile { original_filename: "".to_string(), cached_filename: "".to_string(), @@ -387,16 +388,13 @@ async fn get_file_from_blob(filename: String) -> ReceivedFile { }; received_file.original_filename = filename.clone(); - match blob_url_res { - Ok(url) => { - //println!("Blob URL: {}", url); - blob_url = url.to_string(); - } + let mut blob_url = match blob_url_res { + Ok(url) => url.to_string(), Err(e) => { eprintln!("Error getting blob URL: {:?}", e); return received_file; } - } + }; // append SAS token to blob URL blob_url.push_str(storage_sastoken); // we generate a hash of the filename to use as cache filename @@ -527,6 +525,7 @@ async fn get_file_from_blob(filename: String) -> ReceivedFile { // Implement set tags for Azure blob storage // tags are in format "key=value" +#[allow(dead_code)] async fn azure_set_filename_tags( filename: String, user_tags: Vec<(String, String)>, @@ -571,7 +570,8 @@ async fn azure_set_filename_tags( } } -async fn azure_list_files(directory: String) -> Vec { +#[allow(dead_code)] +async fn azure_list_files(_directory: String) -> Vec { let azure_cfg = Arc::new(get_azure_credentials("azure")); let storage_account = azure_cfg.account.as_str(); let storage_key = azure_cfg.key.clone(); diff --git a/src/local.rs b/src/local.rs index b7c7172..246ef49 100644 --- a/src/local.rs +++ b/src/local.rs @@ -100,6 +100,7 @@ fn get_metadata_file_path(filename: &str) -> PathBuf { } /// Calculate SHA-512 checksum of file data +#[allow(dead_code)] fn calculate_checksum(filename: &str, data: &[u8]) { let hash = sha2_512::default().update(data).finalize(); let digest = hash.digest(); @@ -157,6 +158,7 @@ async fn write_file_to_local_streaming( } /// Write file to local storage (legacy version using Vec) +#[allow(dead_code)] fn write_file_to_local( filename: String, data: Vec, @@ -298,6 +300,7 @@ fn list_files_in_local(directory: String) -> Vec { } /// Set tags for local storage (stored in metadata) +#[allow(dead_code)] fn set_tags_for_local_file( filename: String, user_tags: Vec<(String, String)>, diff --git a/src/main.rs b/src/main.rs index 2bd0d97..56d328a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -190,6 +190,7 @@ impl<'a> tokio::io::AsyncRead for FieldStream<'a> { } } +#[allow(dead_code)] #[async_trait] trait Driver: Send + Sync { async fn write_file( @@ -1149,6 +1150,7 @@ async fn driver_get_file(filepath: String) -> ReceivedFile { driver.get_file(filepath).await } +#[allow(dead_code)] async fn write_file_driver( filename: String, data: Vec, diff --git a/src/storcaching.rs b/src/storcaching.rs index a669931..5e1be56 100644 --- a/src/storcaching.rs +++ b/src/storcaching.rs @@ -1,4 +1,4 @@ -use crate::{debug_log, get_config_content}; +use crate::get_config_content; use serde::Deserialize; use std::cmp::Ordering; use std::collections::{BinaryHeap, HashSet}; @@ -325,7 +325,6 @@ pub async fn cache_loop(cache_dir: &str) { let cleanup_chunk_size = config.cleanup_chunk_size.max(1); let mut deleted_entries_counter: u64 = 0; let mut reclaimed_bytes_counter: u64 = 0; - let mut cached_file_count: usize = 0; let mut next_log = Instant::now(); loop { @@ -333,7 +332,7 @@ pub async fn cache_loop(cache_dir: &str) { enforce_cache_file_limit(cache_dir, cleanup_chunk_size).await; deleted_entries_counter += limit_outcome.deleted_entries; reclaimed_bytes_counter += limit_outcome.reclaimed_bytes; - cached_file_count = file_count; + let mut cached_file_count = file_count; let mut free_space = freediskspace_percent(cache_dir).await; if free_space < DISK_SPACE_LOW_PERCENT { diff --git a/src/storjwt.rs b/src/storjwt.rs index 7072c2a..8fd4235 100644 --- a/src/storjwt.rs +++ b/src/storjwt.rs @@ -1,36 +1,62 @@ -use crate::{debug_log, get_config_content}; -use hmac::{Hmac, Mac}; -use jwt::{Header, SignWithKey, Token, VerifyWithKey}; -use sha2::Sha256; +use crate::get_config_content; +use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; +use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; use toml::value::Table; -pub fn verify_jwt_token(token_str: &str) -> Result, jwt::Error> { - // config.toml, jwt_secret parameter + +#[derive(Debug, Serialize, Deserialize)] +struct Claims { + email: String, +} + +fn verify_with_key_str( + token_str: &str, + key_str: &str, +) -> Result, jsonwebtoken::errors::Error> { + let key = DecodingKey::from_secret(key_str.as_bytes()); + let mut validation = Validation::default(); + validation.required_spec_claims.clear(); + validation.validate_exp = false; + let token_data = decode::(token_str, &key, &validation)?; + let mut claims = BTreeMap::new(); + claims.insert("email".to_string(), token_data.claims.email); + Ok(claims) +} + +pub fn verify_jwt_token( + token_str: &str, +) -> Result, jsonwebtoken::errors::Error> { let toml_cfg = get_config_content(); let parsed_toml = toml_cfg.parse::().unwrap(); - let key_str = parsed_toml["jwt_secret"].as_str().unwrap(); - let key: Hmac = Hmac::new_from_slice(key_str.as_bytes())?; - let verify_result = token_str.verify_with_key(&key); - let token: Token, _> = match verify_result { - Ok(token) => token, - Err(e) => { - eprintln!("JWT verification error: {:?}", e); - return Err(e); - } - }; - //let header = token.header(); - let claims = token.claims(); - let email = claims.get("email"); - match email { - Some(email) => { - debug_log!("email: {}", email); + + // If only unified_secret is configured, it serves as jwt_secret as well. + // Try jwt_secret first, then fall through to unified_secret. + if let Some(key_str) = parsed_toml.get("jwt_secret").and_then(|v| v.as_str()) { + match verify_with_key_str(token_str, key_str) { + Ok(claims) => { + debug_log!("email: {}", claims["email"]); + return Ok(claims); + } + Err(e) => { + debug_log!("JWT verification with jwt_secret failed: {:?}", e); + } } - None => { - debug_log!("email not found"); - return Err(jwt::Error::InvalidSignature); + } + + if let Some(unified) = parsed_toml.get("unified_secret").and_then(|v| v.as_str()) { + match verify_with_key_str(token_str, unified) { + Ok(claims) => { + debug_log!("email (unified_secret): {}", claims["email"]); + return Ok(claims); + } + Err(e) => { + eprintln!("JWT verification with unified_secret also failed: {:?}", e); + return Err(e); + } } } - Ok(claims.clone()) + + Err(jsonwebtoken::errors::ErrorKind::InvalidSignature.into()) } pub fn generate_jwt_secret() { @@ -45,13 +71,18 @@ pub fn generate_jwt_secret() { debug_log!("jwt_secret=\"{}\"", secret); } -pub fn generate_jwt_token(email: &str) -> Result { +pub fn generate_jwt_token(email: &str) -> Result { let toml_cfg = get_config_content(); let parsed_toml = toml_cfg.parse::
().unwrap(); - let key_str = parsed_toml["jwt_secret"].as_str().unwrap(); - let key: Hmac = Hmac::new_from_slice(key_str.as_bytes())?; - let mut claims = BTreeMap::new(); - claims.insert("email".to_string(), email.to_string()); - let token_str = claims.sign_with_key(&key)?; - Ok(token_str) + // For token generation, prefer jwt_secret, fall back to unified_secret + let key_str = parsed_toml + .get("jwt_secret") + .or_else(|| parsed_toml.get("unified_secret")) + .and_then(|v| v.as_str()) + .expect("config must define jwt_secret or unified_secret"); + let key = EncodingKey::from_secret(key_str.as_bytes()); + let claims = Claims { + email: email.to_string(), + }; + encode(&Header::default(), &claims, &key) } diff --git a/tests/e2e.rs b/tests/e2e.rs index 6d92576..df5e0cc 100644 --- a/tests/e2e.rs +++ b/tests/e2e.rs @@ -4,11 +4,9 @@ // End-to-end tests for kernelci-storage using the local storage driver. // These tests start the actual server binary and exercise the HTTP API. -use hmac::{Hmac, Mac}; -use jwt::SignWithKey; +use jsonwebtoken::{encode, EncodingKey, Header}; use reqwest::blocking::multipart; -use sha2::Sha256; -use std::collections::BTreeMap; +use serde::Serialize; use std::net::TcpListener; use std::path::PathBuf; use std::process::{Child, Command}; @@ -18,6 +16,11 @@ const JWT_SECRET: &str = "test-secret-for-e2e-testing"; const TEST_EMAIL: &str = "test@kernelci.org"; const RESTRICTED_EMAIL: &str = "restricted@kernelci.org"; +#[derive(Serialize)] +struct Claims { + email: String, +} + /// Find an available TCP port fn get_free_port() -> u16 { let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind to ephemeral port"); @@ -26,10 +29,11 @@ fn get_free_port() -> u16 { /// Generate a JWT token for the given email using the test secret fn generate_token(email: &str) -> String { - let key: Hmac = Hmac::new_from_slice(JWT_SECRET.as_bytes()).unwrap(); - let mut claims = BTreeMap::new(); - claims.insert("email".to_string(), email.to_string()); - claims.sign_with_key(&key).unwrap() + let key = EncodingKey::from_secret(JWT_SECRET.as_bytes()); + let claims = Claims { + email: email.to_string(), + }; + encode(&Header::default(), &claims, &key).unwrap() } /// Test server handle - kills the server process on drop