From 96a4b01d7a558095b6473820bf12a977fa7ad8c9 Mon Sep 17 00:00:00 2001 From: 0x416e746f6e Date: Tue, 7 Apr 2026 11:28:09 +0200 Subject: [PATCH 1/7] chore: make detect() non-async refs #18 --- crates/attestation/src/azure/mod.rs | 32 ++++++++++++----------------- crates/attestation/src/lib.rs | 31 ++++++++++++++-------------- 2 files changed, 28 insertions(+), 35 deletions(-) diff --git a/crates/attestation/src/azure/mod.rs b/crates/attestation/src/azure/mod.rs index 583ac30..3dc06c1 100644 --- a/crates/attestation/src/azure/mod.rs +++ b/crates/attestation/src/azure/mod.rs @@ -396,38 +396,32 @@ impl RsaPubKey { } /// Detect whether we are on Azure and can make an Azure vTPM attestation -pub async fn detect_azure_cvm() -> Result { - let client = reqwest::Client::builder().no_proxy().timeout(Duration::from_secs(2)).build()?; - - let response = match client.get(AZURE_METADATA_API).header("Metadata", "true").send().await { - Ok(response) => response, +pub fn detect_azure_cvm() -> Result { + let agent = ureq::AgentBuilder::new().timeout(Duration::from_millis(200)).build(); + let resp = match agent.get(AZURE_METADATA_API).set("Metadata", "true").call() { + Ok(resp) => resp, Err(err) => { tracing::debug!("Azure CVM detection failed: Azure metadata API request failed: {err}"); return Ok(false); } }; - if !response.status().is_success() { + if !resp.status() != 200 { tracing::debug!( "Azure CVM detection failed: metadata API returned non-success status: {}", - response.status() + resp.status() ); return Ok(false); } // Ensure the response has a JSON content type - let content_type = response - .headers() - .get(CONTENT_TYPE) - .map(|value| value.to_str().map(str::to_owned)) - .transpose() - .map_err(|_| MaaError::AzureMetadataApiNonJsonResponse { content_type: None })?; - - if !content_type - .as_deref() - .is_some_and(|value| value.to_lowercase().starts_with("application/json")) - { - return Err(MaaError::AzureMetadataApiNonJsonResponse { content_type }); + let content_type = resp + .header(CONTENT_TYPE.as_str()) + .map(|value| value.to_owned()) + .ok_or_else(|| MaaError::AzureMetadataApiNonJsonResponse { content_type: None })?; + + if !content_type.to_lowercase().starts_with("application/json") { + return Err(MaaError::AzureMetadataApiNonJsonResponse { content_type: Some(content_type) }); } match az_tdx_vtpm::is_tdx_cvm() { diff --git a/crates/attestation/src/lib.rs b/crates/attestation/src/lib.rs index cf49ded..835590b 100644 --- a/crates/attestation/src/lib.rs +++ b/crates/attestation/src/lib.rs @@ -103,18 +103,18 @@ impl AttestationType { } /// Detect what platform we are on by attempting an attestation - pub async fn detect() -> Result { + pub fn detect() -> Result { // First attempt azure, if the feature is present #[cfg(feature = "azure")] { - if azure::detect_azure_cvm().await? { + if azure::detect_azure_cvm()? { return Ok(AttestationType::AzureTdx); } } // Otherwise try DCAP quote - this internally checks that the quote provider // is `tdx_guest` if configfs_tsm::create_tdx_quote([0; 64]).is_ok() { - if running_on_gcp().await? { + if running_on_gcp()? { return Ok(AttestationType::GcpTdx); } else { return Ok(AttestationType::DcapTdx); @@ -170,8 +170,8 @@ impl AttestationGenerator { /// Detect what confidential compute platform is present and create the /// appropriate attestation generator - pub async fn detect() -> Result { - Self::new_with_detection(None, None).await + pub fn detect() -> Result { + Self::new_with_detection(None, None) } /// Do not generate attestations @@ -181,7 +181,7 @@ impl AttestationGenerator { /// Create an [AttestationGenerator] detecting the attestation type if /// it is not given - pub async fn new_with_detection( + pub fn new_with_detection( attestation_type_string: Option, attestation_provider_url: Option, ) -> Result { @@ -196,7 +196,7 @@ impl AttestationGenerator { let attestation_type_string = attestation_type_string.unwrap_or_else(|| "auto".to_string()); let attestation_type = if attestation_type_string == "auto" { tracing::info!("Doing attestation type detection..."); - AttestationType::detect().await? + AttestationType::detect()? } else { serde_json::from_value(serde_json::Value::String(attestation_type_string))? }; @@ -497,10 +497,11 @@ fn log_attestation(attestation: &AttestationExchangeMessage) { /// Test whether it looks like we are running on GCP by hitting the metadata /// API -async fn running_on_gcp() -> Result { - let client = reqwest::Client::builder().timeout(Duration::from_millis(200)).build()?; +fn running_on_gcp() -> Result { + let client = + reqwest::blocking::Client::builder().timeout(Duration::from_millis(200)).build()?; - let resp = client.get(GCP_METADATA_API).send().await; + let resp = client.get(GCP_METADATA_API).send(); if let Ok(r) = resp { return Ok(r.status().is_success() && @@ -632,16 +633,14 @@ mod tests { addr } - #[tokio::test] - async fn attestation_detection_does_not_panic() { + fn attestation_detection_does_not_panic() { // We dont enforce what platform the test is run on, only that the function // does not panic - let _ = AttestationGenerator::new_with_detection(None, None).await; + let _ = AttestationGenerator::new_with_detection(None, None); } - #[tokio::test] - async fn running_on_gcp_check_does_not_panic() { - let _ = running_on_gcp().await; + fn running_on_gcp_check_does_not_panic() { + let _ = running_on_gcp(); } #[tokio::test] From ca43025d8d126806c1e4c0cfb69d784c24e55297 Mon Sep 17 00:00:00 2001 From: 0x416e746f6e Date: Tue, 7 Apr 2026 12:07:26 +0200 Subject: [PATCH 2/7] chore: rename primary_name to subject --- crates/attestation/src/lib.rs | 2 ++ crates/attested-tls/src/lib.rs | 42 ++++++++++++++++------------------ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/crates/attestation/src/lib.rs b/crates/attestation/src/lib.rs index 835590b..e7b8c7b 100644 --- a/crates/attestation/src/lib.rs +++ b/crates/attestation/src/lib.rs @@ -633,12 +633,14 @@ mod tests { addr } + #[test] fn attestation_detection_does_not_panic() { // We dont enforce what platform the test is run on, only that the function // does not panic let _ = AttestationGenerator::new_with_detection(None, None); } + #[test] fn running_on_gcp_check_does_not_panic() { let _ = running_on_gcp(); } diff --git a/crates/attested-tls/src/lib.rs b/crates/attested-tls/src/lib.rs index b384665..7d92aa2 100644 --- a/crates/attested-tls/src/lib.rs +++ b/crates/attested-tls/src/lib.rs @@ -88,7 +88,7 @@ struct ResolverState { /// Attestation generator used when renewing certificate attestation_generator: AttestationGenerator, /// Primary DNS name used as certificate subject / common name. - primary_name: String, + subject: String, /// DNS subject alternative names, including the primary name. subject_alt_names: Vec, } @@ -103,7 +103,7 @@ impl fmt::Debug for ResolverState { .field("key_pair_der_len", &self.key_pair_der.len()) .field("certificate_chain_len", &certificate_chain_len) .field("attestation_generator", &self.attestation_generator) - .field("primary_name", &self.primary_name) + .field("subject", &self.subject) .field("subject_alt_names", &self.subject_alt_names) .finish() } @@ -116,14 +116,14 @@ impl AttestedCertificateResolver { pub async fn new( attestation_generator: AttestationGenerator, ca: Option, - primary_name: String, + subject: String, subject_alt_names: Vec, certificate_validity_duration: Duration, ) -> Result { Self::new_with_provider( attestation_generator, ca, - primary_name, + subject, subject_alt_names, default_crypto_provider()?, certificate_validity_duration, @@ -135,7 +135,7 @@ impl AttestedCertificateResolver { pub async fn new_with_provider( attestation_generator: AttestationGenerator, ca: Option, - primary_name: String, + subject: String, subject_alt_names: Vec, provider: Arc, certificate_validity_duration: Duration, @@ -145,8 +145,7 @@ impl AttestedCertificateResolver { minimum: MIN_CERTIFICATE_VALIDITY_DURATION, }); } - let subject_alt_names = - normalized_subject_alt_names(primary_name.as_str(), subject_alt_names); + let subject_alt_names = normalized_subject_alt_names(subject.as_str(), subject_alt_names); // Generate keypair let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)?; @@ -157,7 +156,7 @@ impl AttestedCertificateResolver { let certificate = Self::issue_ra_cert_chain( &key_pair, ca.as_ref(), - primary_name.as_str(), + subject.as_str(), &subject_alt_names, &attestation_generator, certificate_validity_duration, @@ -170,7 +169,7 @@ impl AttestedCertificateResolver { ca: ca.map(Arc::new), key_pair_der, attestation_generator, - primary_name, + subject, subject_alt_names, }); @@ -185,12 +184,12 @@ impl AttestedCertificateResolver { async fn issue_ra_cert_chain( key: &KeyPair, ca: Option<&CaCert>, - primary_name: &str, + subject: &str, subject_alt_names: &[String], attestation_generator: &AttestationGenerator, certificate_validity_duration: Duration, ) -> Result>, AttestedTlsError> { - tracing::debug!("Generating new remote-attested certificate for {primary_name}"); + tracing::debug!("Generating new remote-attested certificate for {subject}"); let pubkey = key.public_key_der(); let now = SystemTime::now(); let not_after = now + certificate_validity_duration; @@ -199,14 +198,14 @@ impl AttestedCertificateResolver { pubkey, now, not_after, - primary_name, + subject, attestation_generator, ) .await?; let cert_request = CertRequest::builder() .key(key) - .subject(primary_name) + .subject(subject) .alt_names(subject_alt_names) .not_before(now) .not_after(not_after) @@ -244,11 +243,10 @@ impl AttestedCertificateResolver { pubkey: Vec, not_before: SystemTime, not_after: SystemTime, - primary_name: &str, + subject: &str, attestation_generator: &AttestationGenerator, ) -> Result { - let report_data = - create_report_data(pubkey, not_before, not_after, primary_name.as_bytes())?; + let report_data = create_report_data(pubkey, not_before, not_after, subject.as_bytes())?; let attestation = attestation_generator.generate_attestation(report_data).await?; Ok(VersionedAttestation::V0 { attestation: Attestation { @@ -294,7 +292,7 @@ impl AttestedCertificateResolver { next_delay = match Self::issue_ra_cert_chain( &key_pair, current.ca.as_deref(), - current.primary_name.as_str(), + current.subject.as_str(), ¤t.subject_alt_names, ¤t.attestation_generator, certificate_validity_duration, @@ -352,9 +350,9 @@ fn default_crypto_provider() -> Result, AttestedTlsError> { } /// Ensures that SAN contains the primary hostname -fn normalized_subject_alt_names(primary_name: &str, subject_alt_names: Vec) -> Vec { +fn normalized_subject_alt_names(subject: &str, subject_alt_names: Vec) -> Vec { let mut normalized = Vec::with_capacity(subject_alt_names.len() + 1); - normalized.push(primary_name.to_string()); + normalized.push(subject.to_string()); for name in subject_alt_names { if !normalized.iter().any(|existing| existing == &name) { @@ -1105,13 +1103,13 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn alternate_san_completes_a_handshake() { let provider: Arc = aws_lc_rs::default_provider().into(); - let primary_name = "foo"; + let subject = "foo"; let alternate_name = "bar"; let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), None, - primary_name.to_string(), - vec![alternate_name.to_string(), primary_name.to_string()], + subject.to_string(), + vec![alternate_name.to_string(), subject.to_string()], provider.clone(), Duration::from_secs(4), ) From 0f8ff058680ad87648ba395ca82d7c53b73f8e79 Mon Sep 17 00:00:00 2001 From: 0x416e746f6e Date: Tue, 7 Apr 2026 12:15:23 +0200 Subject: [PATCH 3/7] feat: allow using custom key-pairs --- crates/attested-tls/src/lib.rs | 48 ++++++++++++++++++++----- crates/attested-tls/tests/nested_tls.rs | 2 ++ 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/crates/attested-tls/src/lib.rs b/crates/attested-tls/src/lib.rs index 7d92aa2..5a497d5 100644 --- a/crates/attested-tls/src/lib.rs +++ b/crates/attested-tls/src/lib.rs @@ -16,7 +16,7 @@ pub use ra_tls::cert::CaCert; use ra_tls::{ attestation::{Attestation, AttestationQuote, VersionedAttestation}, cert::CertRequest, - rcgen::{KeyPair, PKCS_ECDSA_P256_SHA256}, + rcgen::KeyPair, }; use rustls::{ DigitallySignedStruct, @@ -115,6 +115,7 @@ impl AttestedCertificateResolver { /// certificates will be self signed pub async fn new( attestation_generator: AttestationGenerator, + key_pair: &KeyPair, ca: Option, subject: String, subject_alt_names: Vec, @@ -122,6 +123,7 @@ impl AttestedCertificateResolver { ) -> Result { Self::new_with_provider( attestation_generator, + key_pair, ca, subject, subject_alt_names, @@ -134,6 +136,7 @@ impl AttestedCertificateResolver { /// Also provide a crypto provider pub async fn new_with_provider( attestation_generator: AttestationGenerator, + key_pair: &KeyPair, ca: Option, subject: String, subject_alt_names: Vec, @@ -147,14 +150,12 @@ impl AttestedCertificateResolver { } let subject_alt_names = normalized_subject_alt_names(subject.as_str(), subject_alt_names); - // Generate keypair - let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256)?; let key_pair_der = key_pair.serialize_der(); - let key = Self::load_signing_key(&key_pair, provider)?; + let key = Self::load_signing_key(key_pair, provider)?; // Generate initial attested certificate let certificate = Self::issue_ra_cert_chain( - &key_pair, + key_pair, ca.as_ref(), subject.as_str(), &subject_alt_names, @@ -182,7 +183,7 @@ impl AttestedCertificateResolver { /// Create an attested certificate chain - either self-signed or with /// the provided CA async fn issue_ra_cert_chain( - key: &KeyPair, + key_pair: &KeyPair, ca: Option<&CaCert>, subject: &str, subject_alt_names: &[String], @@ -190,7 +191,7 @@ impl AttestedCertificateResolver { certificate_validity_duration: Duration, ) -> Result>, AttestedTlsError> { tracing::debug!("Generating new remote-attested certificate for {subject}"); - let pubkey = key.public_key_der(); + let pubkey = key_pair.public_key_der(); let now = SystemTime::now(); let not_after = now + certificate_validity_duration; @@ -204,7 +205,7 @@ impl AttestedCertificateResolver { .await?; let cert_request = CertRequest::builder() - .key(key) + .key(key_pair) .subject(subject) .alt_names(subject_alt_names) .not_before(now) @@ -812,7 +813,13 @@ pub enum AttestedTlsError { mod tests { use std::{io::Cursor, sync::Arc}; - use ra_tls::rcgen::{BasicConstraints, CertificateParams, IsCa}; + use ra_tls::rcgen::{ + BasicConstraints, + CertificateParams, + IsCa, + KeyPair, + PKCS_ECDSA_P256_SHA256, + }; use rustls::{ CertificateError, ClientConfig, @@ -860,8 +867,10 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn certificate_resolver_creates_initial_certificate() { let provider: Arc = aws_lc_rs::default_provider().into(); + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, "foo".to_string(), vec![], @@ -879,8 +888,10 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn certificate_resolver_rejects_too_short_validity_duration() { let provider: Arc = aws_lc_rs::default_provider().into(); + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let error = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, "foo".to_string(), vec![], @@ -897,8 +908,10 @@ mod tests { async fn server_and_client_configs_complete_a_handshake() { let provider: Arc = aws_lc_rs::default_provider().into(); let server_name = "foo"; + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, server_name.to_string(), vec![], @@ -947,12 +960,14 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn ca_signed_server_and_client_configs_complete_a_handshake() { let provider: Arc = aws_lc_rs::default_provider().into(); + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let server_name = "foo"; let ca = test_ca(); let ca_cert = CertificateDer::from_pem_slice(ca.pem_cert.as_bytes()).unwrap(); let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, Some(ca), server_name.to_string(), vec![], @@ -1009,8 +1024,10 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn certificate_is_renewed_before_expiry() { let provider: Arc = aws_lc_rs::default_provider().into(); + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, "foo".to_string(), vec![], @@ -1033,10 +1050,12 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn server_and_client_configs_complete_a_mutual_auth_handshake() { let provider: Arc = aws_lc_rs::default_provider().into(); + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let server_name = "foo"; let server_resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, server_name.to_string(), vec![], @@ -1048,6 +1067,7 @@ mod tests { let client_resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, "client".to_string(), vec![], @@ -1103,10 +1123,12 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn alternate_san_completes_a_handshake() { let provider: Arc = aws_lc_rs::default_provider().into(); + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let subject = "foo"; let alternate_name = "bar"; let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, subject.to_string(), vec![alternate_name.to_string(), subject.to_string()], @@ -1264,8 +1286,10 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn self_signed_attested_certificate_with_wrong_name_is_rejected() { let provider: Arc = aws_lc_rs::default_provider().into(); + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, "foo".to_string(), vec![], @@ -1298,8 +1322,10 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn certificate_binding_changes_when_identity_changes() { let provider: Arc = aws_lc_rs::default_provider().into(); + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, "foo".to_string(), vec![], @@ -1339,8 +1365,10 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn attestation_rejection_returns_application_verification_failure() { let provider: Arc = aws_lc_rs::default_provider().into(); + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, "foo".to_string(), vec![], @@ -1373,8 +1401,10 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn verifier_reuses_trusted_certificate_cache() { let provider: Arc = aws_lc_rs::default_provider().into(); + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, "foo".to_string(), vec![], diff --git a/crates/attested-tls/tests/nested_tls.rs b/crates/attested-tls/tests/nested_tls.rs index 5d916e5..10531da 100644 --- a/crates/attested-tls/tests/nested_tls.rs +++ b/crates/attested-tls/tests/nested_tls.rs @@ -84,8 +84,10 @@ fn plain_tls_config_pair(provider: Arc) -> (ServerConfig, Client /// Create attested server TLS config with mock DCAP attestation and /// self-signed certs async fn attested_server_config(server_name: &str, provider: Arc) -> ServerConfig { + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + &key_pair, None, server_name.to_string(), vec![], From 63c31e47ef4f602524818f790d3c2cc81aa1bc4d Mon Sep 17 00:00:00 2001 From: 0x416e746f6e Date: Tue, 7 Apr 2026 13:49:29 +0200 Subject: [PATCH 4/7] fix: re-export `ra_tls::rcgen` --- crates/attested-tls/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/attested-tls/src/lib.rs b/crates/attested-tls/src/lib.rs index 5a497d5..fc387d0 100644 --- a/crates/attested-tls/src/lib.rs +++ b/crates/attested-tls/src/lib.rs @@ -12,12 +12,12 @@ pub use attestation::{ AttestationType, AttestationVerifier, }; -pub use ra_tls::cert::CaCert; use ra_tls::{ attestation::{Attestation, AttestationQuote, VersionedAttestation}, cert::CertRequest, rcgen::KeyPair, }; +pub use ra_tls::{cert::CaCert, rcgen}; use rustls::{ DigitallySignedStruct, DistinguishedName, From c6644a4872aa3c11e78f19400b3b492407b1f8a5 Mon Sep 17 00:00:00 2001 From: 0x416e746f6e Date: Tue, 7 Apr 2026 15:36:33 +0200 Subject: [PATCH 5/7] fix: make `AttestedCertificateResolver::new()` non-async --- Cargo.lock | 18 ++++++++++-- crates/attestation/Cargo.toml | 1 + crates/attestation/src/lib.rs | 39 ++++++++++++------------- crates/attested-tls/src/lib.rs | 33 +++++---------------- crates/attested-tls/tests/nested_tls.rs | 5 ++-- 5 files changed, 46 insertions(+), 50 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index da26021..304dd3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,6 +321,7 @@ dependencies = [ "tokio-rustls", "tracing", "tss-esapi", + "ureq", "x509-parser 0.18.1", ] @@ -2006,7 +2007,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] @@ -3364,7 +3365,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] @@ -4388,11 +4389,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" dependencies = [ "base64 0.22.1", + "flate2", "log", "once_cell", + "rustls", + "rustls-pki-types", "serde", "serde_json", "url", + "webpki-roots 0.26.11", ] [[package]] @@ -4621,6 +4626,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.6", +] + [[package]] name = "webpki-roots" version = "1.0.6" diff --git a/crates/attestation/Cargo.toml b/crates/attestation/Cargo.toml index 81fcb85..16fb6aa 100644 --- a/crates/attestation/Cargo.toml +++ b/crates/attestation/Cargo.toml @@ -26,6 +26,7 @@ base64 = "0.22.1" reqwest = { version = "0.12.23", default-features = false, features = [ "rustls-tls-webpki-roots-no-provider", ] } +ureq = "2.12.1" tracing = "0.1.41" parity-scale-codec = "3.7.5" num-bigint = "0.4.6" diff --git a/crates/attestation/src/lib.rs b/crates/attestation/src/lib.rs index e7b8c7b..3b5f1c8 100644 --- a/crates/attestation/src/lib.rs +++ b/crates/attestation/src/lib.rs @@ -7,6 +7,7 @@ pub mod measurements; use std::{ fmt::{self, Display, Formatter}, + io::Read, net::IpAddr, time::{Duration, SystemTime, UNIX_EPOCH}, }; @@ -206,12 +207,12 @@ impl AttestationGenerator { } /// Generate an attestation exchange message with given input data - pub async fn generate_attestation( + pub fn generate_attestation( &self, input_data: [u8; 64], ) -> Result { if let Some(url) = &self.attestation_provider_url { - Self::use_attestation_provider(url, self.attestation_type, input_data).await + Self::use_attestation_provider(url, self.attestation_type, input_data) } else { Ok(AttestationExchangeMessage { attestation_type: self.attestation_type, @@ -247,27 +248,29 @@ impl AttestationGenerator { /// Generate an attestation by using an external service for the /// attestation generation - async fn use_attestation_provider( + fn use_attestation_provider( url: &str, attestation_type: AttestationType, input_data: [u8; 64], ) -> Result { let url = format!("{}/attest/{}", url, hex::encode(input_data)); - let response = reqwest::get(url) - .await + let mut response = ureq::get(&url) + .timeout(Duration::from_millis(1000)) + .call() .map_err(|err| AttestationError::AttestationProvider(err.to_string()))? - .bytes() - .await - .map_err(|err| AttestationError::AttestationProvider(err.to_string()))? - .to_vec(); + .into_reader(); + let mut body = Vec::new(); + response + .read_to_end(&mut body) + .map_err(|err| AttestationError::AttestationProvider(err.to_string()))?; // If the response is not already wrapped in an attestation exchange // message, wrap it in one - if let Ok(message) = AttestationExchangeMessage::decode(&mut &response[..]) { + if let Ok(message) = AttestationExchangeMessage::decode(&mut &body[..]) { Ok(message) } else { - Ok(AttestationExchangeMessage { attestation_type, attestation: response }) + Ok(AttestationExchangeMessage { attestation_type, attestation: body }) } } } @@ -498,14 +501,12 @@ fn log_attestation(attestation: &AttestationExchangeMessage) { /// Test whether it looks like we are running on GCP by hitting the metadata /// API fn running_on_gcp() -> Result { - let client = - reqwest::blocking::Client::builder().timeout(Duration::from_millis(200)).build()?; - - let resp = client.get(GCP_METADATA_API).send(); + let agent = ureq::AgentBuilder::new().timeout(Duration::from_millis(200)).build(); + let resp = agent.get(GCP_METADATA_API).call(); if let Ok(r) = resp { - return Ok(r.status().is_success() && - r.headers().get("Metadata-Flavor").map(|v| v == "Google").unwrap_or(false)); + return Ok(r.status() == 200 && + r.header("Metadata-Flavor").map(|v| v == "Google").unwrap_or(false)); } Ok(false) @@ -645,7 +646,7 @@ mod tests { let _ = running_on_gcp(); } - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn attestation_provider_response_is_wrapped_if_needed() { let input_data = [0u8; 64]; @@ -662,7 +663,6 @@ mod tests { AttestationType::GcpTdx, input_data, ) - .await .unwrap(); assert_eq!(decoded.attestation_type, AttestationType::None); assert_eq!(decoded.attestation, vec![1, 2, 3]); @@ -674,7 +674,6 @@ mod tests { AttestationType::DcapTdx, input_data, ) - .await .unwrap(); assert_eq!(wrapped.attestation_type, AttestationType::DcapTdx); assert_eq!(wrapped.attestation, vec![9, 8]); diff --git a/crates/attested-tls/src/lib.rs b/crates/attested-tls/src/lib.rs index fc387d0..baf40ff 100644 --- a/crates/attested-tls/src/lib.rs +++ b/crates/attested-tls/src/lib.rs @@ -113,7 +113,7 @@ impl AttestedCertificateResolver { /// Create a certificate resolver with a given attestation generator /// A private certificate authority can also be given - otherwise /// certificates will be self signed - pub async fn new( + pub fn new( attestation_generator: AttestationGenerator, key_pair: &KeyPair, ca: Option, @@ -130,11 +130,10 @@ impl AttestedCertificateResolver { default_crypto_provider()?, certificate_validity_duration, ) - .await } /// Also provide a crypto provider - pub async fn new_with_provider( + pub fn new_with_provider( attestation_generator: AttestationGenerator, key_pair: &KeyPair, ca: Option, @@ -161,8 +160,7 @@ impl AttestedCertificateResolver { &subject_alt_names, &attestation_generator, certificate_validity_duration, - ) - .await?; + )?; let state = Arc::new(ResolverState { key, @@ -182,7 +180,7 @@ impl AttestedCertificateResolver { /// Create an attested certificate chain - either self-signed or with /// the provided CA - async fn issue_ra_cert_chain( + fn issue_ra_cert_chain( key_pair: &KeyPair, ca: Option<&CaCert>, subject: &str, @@ -201,8 +199,7 @@ impl AttestedCertificateResolver { not_after, subject, attestation_generator, - ) - .await?; + )?; let cert_request = CertRequest::builder() .key(key_pair) @@ -240,7 +237,7 @@ impl AttestedCertificateResolver { /// Create an attestation, and format it to be used in certificate /// extension - async fn create_attestation_payload( + fn create_attestation_payload( pubkey: Vec, not_before: SystemTime, not_after: SystemTime, @@ -248,7 +245,7 @@ impl AttestedCertificateResolver { attestation_generator: &AttestationGenerator, ) -> Result { let report_data = create_report_data(pubkey, not_before, not_after, subject.as_bytes())?; - let attestation = attestation_generator.generate_attestation(report_data).await?; + let attestation = attestation_generator.generate_attestation(report_data)?; Ok(VersionedAttestation::V0 { attestation: Attestation { quote: ra_tls::attestation::AttestationQuote::DstackTdx( @@ -297,9 +294,7 @@ impl AttestedCertificateResolver { ¤t.subject_alt_names, ¤t.attestation_generator, certificate_validity_duration, - ) - .await - { + ) { Ok(certificate) => { *current.certificate.write().expect("Certificate lock poisoned") = certificate; @@ -877,7 +872,6 @@ mod tests { provider, Duration::from_secs(4), ) - .await .unwrap(); let certificate = resolver.state.certificate.read().unwrap(); @@ -898,7 +892,6 @@ mod tests { provider, CERTIFICATE_RENEWAL_RETRY_DELAY * 3, ) - .await .unwrap_err(); assert!(matches!(error, AttestedTlsError::InvalidCertificateValidityDuration { .. })); @@ -918,7 +911,6 @@ mod tests { provider.clone(), Duration::from_secs(4), ) - .await .unwrap(); let verifier = AttestedCertificateVerifier::new_with_provider( @@ -974,7 +966,6 @@ mod tests { provider.clone(), Duration::from_secs(4), ) - .await .unwrap(); let certificate_chain = resolver.state.certificate.read().unwrap().clone(); @@ -1034,7 +1025,6 @@ mod tests { provider, Duration::from_secs(4), ) - .await .unwrap(); let initial_certificate = resolver.state.certificate.read().unwrap().first().unwrap().clone(); @@ -1062,7 +1052,6 @@ mod tests { provider.clone(), Duration::from_secs(4), ) - .await .unwrap(); let client_resolver = AttestedCertificateResolver::new_with_provider( @@ -1074,7 +1063,6 @@ mod tests { provider.clone(), Duration::from_secs(4), ) - .await .unwrap(); let server_verifier = AttestedCertificateVerifier::new_with_provider( @@ -1135,7 +1123,6 @@ mod tests { provider.clone(), Duration::from_secs(4), ) - .await .unwrap(); let verifier = AttestedCertificateVerifier::new_with_provider( None, @@ -1296,7 +1283,6 @@ mod tests { provider.clone(), Duration::from_secs(4), ) - .await .unwrap(); let verifier = AttestedCertificateVerifier::new_with_provider( None, @@ -1332,7 +1318,6 @@ mod tests { provider.clone(), Duration::from_secs(4), ) - .await .unwrap(); let original_cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); let (original_report_data, original_not_after) = @@ -1375,7 +1360,6 @@ mod tests { provider.clone(), Duration::from_secs(4), ) - .await .unwrap(); let verifier = AttestedCertificateVerifier::new_with_provider( None, @@ -1411,7 +1395,6 @@ mod tests { provider.clone(), Duration::from_secs(4), ) - .await .unwrap(); let mut verifier = AttestedCertificateVerifier::new_with_provider( None, diff --git a/crates/attested-tls/tests/nested_tls.rs b/crates/attested-tls/tests/nested_tls.rs index 10531da..10e2165 100644 --- a/crates/attested-tls/tests/nested_tls.rs +++ b/crates/attested-tls/tests/nested_tls.rs @@ -18,7 +18,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; async fn nested_tls_uses_attested_tls_for_inner_session() { let provider: Arc = aws_lc_rs::default_provider().into(); let (outer_server, outer_client) = plain_tls_config_pair(provider.clone()); - let inner_server = attested_server_config("localhost", provider.clone()).await; + let inner_server = attested_server_config("localhost", provider.clone()); let inner_client = attested_client_config(provider.clone()); let acceptor = NestingTlsAcceptor::new(Arc::new(outer_server), Arc::new(inner_server)); @@ -83,7 +83,7 @@ fn plain_tls_config_pair(provider: Arc) -> (ServerConfig, Client /// Create attested server TLS config with mock DCAP attestation and /// self-signed certs -async fn attested_server_config(server_name: &str, provider: Arc) -> ServerConfig { +fn attested_server_config(server_name: &str, provider: Arc) -> ServerConfig { let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let resolver = AttestedCertificateResolver::new_with_provider( AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), @@ -94,7 +94,6 @@ async fn attested_server_config(server_name: &str, provider: Arc provider.clone(), std::time::Duration::from_secs(91), ) - .await .unwrap(); ServerConfig::builder_with_provider(provider) From 014c45fa8f536d98ad5b5549d08aa77b508c323d Mon Sep 17 00:00:00 2001 From: 0x416e746f6e Date: Tue, 14 Apr 2026 10:49:55 +0200 Subject: [PATCH 6/7] fix: be explicit w.r.t attestation types (if in the future we add a new attestation type, the catch-all case will backfire) --- crates/attestation/src/lib.rs | 8 +++++--- crates/attestation/src/measurements.rs | 14 +++++++++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/crates/attestation/src/lib.rs b/crates/attestation/src/lib.rs index 3b5f1c8..1bb423e 100644 --- a/crates/attestation/src/lib.rs +++ b/crates/attestation/src/lib.rs @@ -55,7 +55,7 @@ impl AttestationExchangeMessage { Err(AttestationError::AttestationTypeNotSupported) } } - _ => { + AttestationType::DcapTdx | AttestationType::GcpTdx | AttestationType::QemuTdx => { #[cfg(any(test, feature = "mock"))] { let quote = tdx_quote::Quote::from_bytes(&self.attestation) @@ -242,7 +242,9 @@ impl AttestationGenerator { Err(AttestationError::AttestationTypeNotSupported) } } - _ => dcap::create_dcap_attestation(input_data), + AttestationType::DcapTdx | AttestationType::GcpTdx | AttestationType::QemuTdx => { + dcap::create_dcap_attestation(input_data) + } } } @@ -394,7 +396,7 @@ impl AttestationVerifier { return Err(AttestationError::AttestationTypeNotSupported); } } - _ => { + AttestationType::DcapTdx | AttestationType::GcpTdx | AttestationType::QemuTdx => { dcap::verify_dcap_attestation( attestation_exchange_message.attestation, expected_input_data, diff --git a/crates/attestation/src/measurements.rs b/crates/attestation/src/measurements.rs index 97759f0..055de46 100644 --- a/crates/attestation/src/measurements.rs +++ b/crates/attestation/src/measurements.rs @@ -166,6 +166,7 @@ impl MultiMeasurements { let measurements_map: HashMap = serde_json::from_str(input)?; Ok(match attestation_type { + AttestationType::None => Self::NoAttestation, AttestationType::AzureTdx => Self::Azure( measurements_map .into_iter() @@ -179,8 +180,7 @@ impl MultiMeasurements { }) .collect::>()?, ), - AttestationType::None => Self::NoAttestation, - _ => { + AttestationType::DcapTdx | AttestationType::GcpTdx | AttestationType::QemuTdx => { let measurements_map = measurements_map .into_iter() .map(|(k, v)| { @@ -304,7 +304,9 @@ impl MeasurementRecord { measurements: match attestation_type { AttestationType::None => ExpectedMeasurements::NoAttestation, AttestationType::AzureTdx => ExpectedMeasurements::Azure(HashMap::new()), - _ => ExpectedMeasurements::Dcap(HashMap::new()), + AttestationType::DcapTdx | AttestationType::GcpTdx | AttestationType::QemuTdx => { + ExpectedMeasurements::Dcap(HashMap::new()) + } }, } } @@ -524,6 +526,7 @@ impl MeasurementPolicy { if let Some(measurements) = record.measurements { let expected_measurements = match attestation_type { + AttestationType::None => ExpectedMeasurements::NoAttestation, AttestationType::AzureTdx => { let azure_measurements = measurements .iter() @@ -535,8 +538,9 @@ impl MeasurementPolicy { )?; ExpectedMeasurements::Azure(azure_measurements) } - AttestationType::None => ExpectedMeasurements::NoAttestation, - _ => ExpectedMeasurements::Dcap( + AttestationType::DcapTdx | + AttestationType::GcpTdx | + AttestationType::QemuTdx => ExpectedMeasurements::Dcap( measurements .iter() .map(|(index_str, entry)| { From d49510d10858fe27bd0c6d5ef9b488db45d91bb2 Mon Sep 17 00:00:00 2001 From: 0x416e746f6e Date: Fri, 17 Apr 2026 17:34:13 +0200 Subject: [PATCH 7/7] feat: implement builder for cert resolver/verifier --- crates/attestation/Cargo.toml | 34 +- crates/attested-tls/src/lib.rs | 599 +++++++++++++----------- crates/attested-tls/tests/nested_tls.rs | 22 +- 3 files changed, 360 insertions(+), 295 deletions(-) diff --git a/crates/attestation/Cargo.toml b/crates/attestation/Cargo.toml index 16fb6aa..ffda6d1 100644 --- a/crates/attestation/Cargo.toml +++ b/crates/attestation/Cargo.toml @@ -8,31 +8,30 @@ repository = "https://github.com/flashbots/attested-tls" keywords = ["attestation", "CVM", "TDX"] [dependencies] +dcap-qvl = { workspace = true, features = ["danger-allow-tcb-override"] } pccs = { workspace = true } tokio = { workspace = true, features = ["fs"] } tokio-rustls = { workspace = true, default-features = false } -x509-parser = "0.18.0" -thiserror = "2.0.17" + anyhow = "1.0.100" -pem-rfc7468 = { version = "0.7.0", features = ["std"] } +base64 = "0.22.1" configfs-tsm = "0.0.2" -rand_core = { version = "0.6.4", features = ["getrandom"] } -dcap-qvl = { workspace = true, features = ["danger-allow-tcb-override"] } hex = "0.4.3" http = "1.3.1" -serde_json = "1.0.145" +num-bigint = "0.4.6" +once_cell = "1.21.3" +parity-scale-codec = "3.7.5" +pem-rfc7468 = { version = "0.7.0", features = ["std"] } +rand_core = { version = "0.6.4", features = ["getrandom"] } +reqwest = { version = "0.12.23", default-features = false, features = [ "rustls-tls-webpki-roots-no-provider" ] } serde = "1.0.228" -base64 = "0.22.1" -reqwest = { version = "0.12.23", default-features = false, features = [ - "rustls-tls-webpki-roots-no-provider", -] } -ureq = "2.12.1" +serde_json = "1.0.145" +thiserror = "2.0.17" +time = "0.3.47" tracing = "0.1.41" -parity-scale-codec = "3.7.5" -num-bigint = "0.4.6" +ureq = "2.12.1" webpki = { package = "rustls-webpki", version = "0.103.8" } -time = "0.3.47" -once_cell = "1.21.3" +x509-parser = "0.18.0" # Used for azure vTPM attestation support az-tdx-vtpm = { version = "0.7.4", optional = true } @@ -43,10 +42,11 @@ openssl = { version = "0.10.75", optional = true } tdx-quote = { version = "0.0.5", features = ["mock"], optional = true } [dev-dependencies] -tempfile = "3.23.0" -tdx-quote = { version = "0.0.5", features = ["mock"] } tokio-rustls = { workspace = true, default-features = true } + serde-saphyr = "0.0.22" +tdx-quote = { version = "0.0.5", features = ["mock"] } +tempfile = "3.23.0" [features] default = [] diff --git a/crates/attested-tls/src/lib.rs b/crates/attested-tls/src/lib.rs index baf40ff..41b68d8 100644 --- a/crates/attested-tls/src/lib.rs +++ b/crates/attested-tls/src/lib.rs @@ -15,7 +15,7 @@ pub use attestation::{ use ra_tls::{ attestation::{Attestation, AttestationQuote, VersionedAttestation}, cert::CertRequest, - rcgen::KeyPair, + rcgen::{KeyPair, PKCS_ECDSA_P256_SHA256}, }; pub use ra_tls::{cert::CaCert, rcgen}; use rustls::{ @@ -110,72 +110,29 @@ impl fmt::Debug for ResolverState { } impl AttestedCertificateResolver { - /// Create a certificate resolver with a given attestation generator - /// A private certificate authority can also be given - otherwise - /// certificates will be self signed - pub fn new( + /// Create a default TLS certificate resolver wrapping given attestation + /// generator + pub fn try_default( + subject: &str, attestation_generator: AttestationGenerator, - key_pair: &KeyPair, - ca: Option, - subject: String, - subject_alt_names: Vec, - certificate_validity_duration: Duration, ) -> Result { - Self::new_with_provider( - attestation_generator, - key_pair, - ca, - subject, - subject_alt_names, - default_crypto_provider()?, - certificate_validity_duration, - ) + Self::build(subject, attestation_generator).finish() } - /// Also provide a crypto provider - pub fn new_with_provider( + /// Build attested certificate resolver + pub fn build<'a, 'b>( + subject: &'b str, attestation_generator: AttestationGenerator, - key_pair: &KeyPair, - ca: Option, - subject: String, - subject_alt_names: Vec, - provider: Arc, - certificate_validity_duration: Duration, - ) -> Result { - if certificate_validity_duration < MIN_CERTIFICATE_VALIDITY_DURATION { - return Err(AttestedTlsError::InvalidCertificateValidityDuration { - minimum: MIN_CERTIFICATE_VALIDITY_DURATION, - }); - } - let subject_alt_names = normalized_subject_alt_names(subject.as_str(), subject_alt_names); - - let key_pair_der = key_pair.serialize_der(); - let key = Self::load_signing_key(key_pair, provider)?; - - // Generate initial attested certificate - let certificate = Self::issue_ra_cert_chain( - key_pair, - ca.as_ref(), - subject.as_str(), - &subject_alt_names, - &attestation_generator, - certificate_validity_duration, - )?; - - let state = Arc::new(ResolverState { - key, - certificate: RwLock::new(certificate), - ca: ca.map(Arc::new), - key_pair_der, + ) -> AttestedCertificateResolverBuilder<'a, 'b> { + AttestedCertificateResolverBuilder { attestation_generator, + ca: None, + certificate_validity: Duration::from_millis(300000), + key_pair: None, + provider: None, subject, - subject_alt_names, - }); - - // Start a loop which will periodically renew the certificate - Self::spawn_renewal_task(Arc::downgrade(&state), certificate_validity_duration); - - Ok(Self { state }) + subject_alt_names: None, + } } /// Create an attested certificate chain - either self-signed or with @@ -324,6 +281,114 @@ impl ResolvesServerCert for AttestedCertificateResolver { } } +pub struct AttestedCertificateResolverBuilder<'a, 'b> { + /// Configured to generate attestations + attestation_generator: AttestationGenerator, + /// CA to sign leaf certificates + ca: Option, + /// Duration of certificate validity + certificate_validity: Duration, + /// Key-pair to use + key_pair: Option<&'a KeyPair>, + /// Underlying cryptography provider + provider: Option>, + /// Certificate subject + subject: &'b str, + // Subject alternative names + subject_alt_names: Option>, +} + +impl<'a, 'b> AttestedCertificateResolverBuilder<'a, 'b> { + /// Use specified CA to sign leaf certificates + pub fn with_ca(mut self, ca: CaCert) -> Self { + self.ca = Some(ca); + self + } + + /// Set duration of certificates validity (default is 30 minutes) + pub fn with_certificate_validity(mut self, certificate_validity: Duration) -> Self { + self.certificate_validity = certificate_validity; + self + } + + /// Use specified key-pair (default is to use randomly generated one) + pub fn with_key_pair(mut self, key_pair: &'a KeyPair) -> Self { + self.key_pair = Some(key_pair); + self + } + + /// Use specified crypto provider + pub fn with_provider(mut self, provider: Arc) -> Self { + self.provider = Some(provider.clone()); + self + } + + /// Use specified subject alternative names on generated certificates + pub fn with_subject_alt_names(mut self, subject_alt_names: Vec) -> Self { + self.subject_alt_names = Some(subject_alt_names); + self + } + + /// Finish the build of AttestedCertificateResolver + pub fn finish(self) -> Result { + let provider = match self.provider { + None => default_crypto_provider()?, + Some(provider) => provider, + }; + + if self.certificate_validity < MIN_CERTIFICATE_VALIDITY_DURATION { + return Err(AttestedTlsError::InvalidCertificateValidityDuration { + minimum: MIN_CERTIFICATE_VALIDITY_DURATION, + }); + } + + let key_pair = match self.key_pair { + None => { + &KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).map_err(AttestedTlsError::from)? + } + Some(key_pair) => key_pair, + }; + + let key_pair_der = key_pair.serialize_der(); + let key = AttestedCertificateResolver::load_signing_key(key_pair, provider)?; + + let subject_alt_names = match self.subject_alt_names { + None => vec![self.subject.to_owned()], + Some(subject_alt_names) => { + normalized_subject_alt_names(self.subject, subject_alt_names) + } + }; + + // Generate initial attested certificate + let certificate = AttestedCertificateResolver::issue_ra_cert_chain( + key_pair, + self.ca.as_ref(), + self.subject, + &subject_alt_names, + &self.attestation_generator, + self.certificate_validity, + )?; + + let state = Arc::new(ResolverState { + key, + certificate: RwLock::new(certificate), + ca: self.ca.map(Arc::new), + key_pair_der, + attestation_generator: self.attestation_generator, + subject: self.subject.to_owned(), + subject_alt_names, + }); + + // Start a loop which will periodically renew the certificate + AttestedCertificateResolver::spawn_renewal_task( + Arc::downgrade(&state), + self.certificate_validity, + ); + + Ok(AttestedCertificateResolver { state }) + } +} + impl ResolvesClientCert for AttestedCertificateResolver { fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option> { self.current_certified_key() @@ -405,49 +470,30 @@ pub struct AttestedCertificateVerifier { } impl AttestedCertificateVerifier { - /// Create a certificate verifier with given attestation verification - /// and optionally a private CA root of trust - pub fn new( - root_store: Option, - attestation_verifier: AttestationVerifier, - ) -> Result { - Self::new_with_provider(root_store, attestation_verifier, default_crypto_provider()?) - } - - /// Also provide a crypto provider - pub fn new_with_provider( - root_store: Option, + /// Create a default TLS certificate verifier wrapping given attestation + /// verifier + pub fn try_default( attestation_verifier: AttestationVerifier, - provider: Arc, ) -> Result { - let (server_inner, client_inner) = match root_store { - Some(root_store) => { - let root_store = Arc::new(root_store); - let server_inner = WebPkiServerVerifier::builder_with_provider( - root_store.clone(), - provider.clone(), - ) - .build() - .map_err(AttestedTlsError::VerifierBuilder)?; - let client_inner = - WebPkiClientVerifier::builder_with_provider(root_store, provider.clone()) - .build() - .map_err(AttestedTlsError::VerifierBuilder)?; - - (Some(server_inner), Some(client_inner)) - } - None => (None, None), - }; - Ok(Self { - server_inner, - client_inner, - provider, + server_inner: None, + client_inner: None, + provider: default_crypto_provider()?, attestation_verifier, trusted_certificates: Default::default(), }) } + /// Create a TLS certificate verifier wrapping given attestation + /// verifier + pub fn build(attestation_verifier: AttestationVerifier) -> AttestedCertificateVerifierBuilder { + AttestedCertificateVerifierBuilder { + root_store: None, + provider: None, + attestation_verifier, + } + } + /// Given a TLS certificate, return the embedded attestation pub fn extract_custom_attestation_from_cert( cert: &CertificateDer<'_>, @@ -774,6 +820,63 @@ impl ClientCertVerifier for AttestedCertificateVerifier { } } +pub struct AttestedCertificateVerifierBuilder { + /// Configured for verifying attestations + attestation_verifier: AttestationVerifier, + /// Underlying cryptography provider + provider: Option>, + // Custom root of trust + root_store: Option>, +} + +impl AttestedCertificateVerifierBuilder { + /// Use specified crypto provider + pub fn with_provider(mut self, provider: Arc) -> Self { + self.provider = Some(provider.clone()); + self + } + + /// Use specified root of trust + pub fn with_root_store(mut self, root_store: RootCertStore) -> Self { + self.root_store = Some(Arc::new(root_store)); + self + } + + /// Finish the build of AttestedCertificateVerifier + pub fn finish(self) -> Result { + let provider = match self.provider { + None => default_crypto_provider()?, + Some(provider) => provider, + }; + + let (server_inner, client_inner) = match self.root_store { + Some(root_store) => { + let server_inner = WebPkiServerVerifier::builder_with_provider( + root_store.clone(), + provider.clone(), + ) + .build() + .map_err(AttestedTlsError::VerifierBuilder)?; + let client_inner = + WebPkiClientVerifier::builder_with_provider(root_store, provider.clone()) + .build() + .map_err(AttestedTlsError::VerifierBuilder)?; + + (Some(server_inner), Some(client_inner)) + } + None => (None, None), + }; + + Ok(AttestedCertificateVerifier { + server_inner, + client_inner, + provider, + attestation_verifier: self.attestation_verifier, + trusted_certificates: Default::default(), + }) + } +} + #[derive(Debug, Error)] pub enum AttestedTlsError { #[error("Certificate validity duration must be at least {minimum:?}")] @@ -863,15 +966,14 @@ mod tests { async fn certificate_resolver_creates_initial_certificate() { let provider: Arc = aws_lc_rs::default_provider().into(); let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + "foo", AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - "foo".to_string(), - vec![], - provider, - Duration::from_secs(4), ) + .with_key_pair(&key_pair) + .with_provider(provider) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); let certificate = resolver.state.certificate.read().unwrap(); @@ -883,15 +985,14 @@ mod tests { async fn certificate_resolver_rejects_too_short_validity_duration() { let provider: Arc = aws_lc_rs::default_provider().into(); let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); - let error = AttestedCertificateResolver::new_with_provider( + let error = AttestedCertificateResolver::build( + "foo", AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - "foo".to_string(), - vec![], - provider, - CERTIFICATE_RENEWAL_RETRY_DELAY * 3, ) + .with_provider(provider) + .with_key_pair(&key_pair) + .with_certificate_validity(CERTIFICATE_RENEWAL_RETRY_DELAY * 3) + .finish() .unwrap_err(); assert!(matches!(error, AttestedTlsError::InvalidCertificateValidityDuration { .. })); @@ -902,23 +1003,20 @@ mod tests { let provider: Arc = aws_lc_rs::default_provider().into(); let server_name = "foo"; let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + server_name, AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - server_name.to_string(), - vec![], - provider.clone(), - Duration::from_secs(4), ) + .with_provider(provider.clone()) + .with_key_pair(&key_pair) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); - let verifier = AttestedCertificateVerifier::new_with_provider( - None, - AttestationVerifier::mock(), - provider.clone(), - ) - .unwrap(); + let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider.clone()) + .finish() + .unwrap(); let server_config = ServerConfig::builder_with_provider(provider.clone()) .with_safe_default_protocol_versions() @@ -957,15 +1055,15 @@ mod tests { let ca = test_ca(); let ca_cert = CertificateDer::from_pem_slice(ca.pem_cert.as_bytes()).unwrap(); - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + server_name, AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - Some(ca), - server_name.to_string(), - vec![], - provider.clone(), - Duration::from_secs(4), ) + .with_provider(provider.clone()) + .with_key_pair(&key_pair) + .with_ca(ca) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); let certificate_chain = resolver.state.certificate.read().unwrap().clone(); @@ -975,12 +1073,11 @@ mod tests { let mut roots = RootCertStore::empty(); roots.add(ca_cert).unwrap(); - let verifier = AttestedCertificateVerifier::new_with_provider( - Some(roots), - AttestationVerifier::mock(), - provider.clone(), - ) - .unwrap(); + let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider.clone()) + .with_root_store(roots) + .finish() + .unwrap(); let server_config = ServerConfig::builder_with_provider(provider.clone()) .with_safe_default_protocol_versions() @@ -1016,15 +1113,14 @@ mod tests { async fn certificate_is_renewed_before_expiry() { let provider: Arc = aws_lc_rs::default_provider().into(); let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + "foo", AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - "foo".to_string(), - vec![], - provider, - Duration::from_secs(4), ) + .with_provider(provider.clone()) + .with_key_pair(&key_pair) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); let initial_certificate = resolver.state.certificate.read().unwrap().first().unwrap().clone(); @@ -1043,40 +1139,34 @@ mod tests { let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let server_name = "foo"; - let server_resolver = AttestedCertificateResolver::new_with_provider( + let server_resolver = AttestedCertificateResolver::build( + server_name, AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - server_name.to_string(), - vec![], - provider.clone(), - Duration::from_secs(4), ) + .with_provider(provider.clone()) + .with_key_pair(&key_pair) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); - let client_resolver = AttestedCertificateResolver::new_with_provider( + let client_resolver = AttestedCertificateResolver::build( + "client", AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - "client".to_string(), - vec![], - provider.clone(), - Duration::from_secs(4), ) + .with_provider(provider.clone()) + .with_key_pair(&key_pair) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); - let server_verifier = AttestedCertificateVerifier::new_with_provider( - None, - AttestationVerifier::mock(), - provider.clone(), - ) - .unwrap(); - let client_verifier = AttestedCertificateVerifier::new_with_provider( - None, - AttestationVerifier::mock(), - provider.clone(), - ) - .unwrap(); + let server_verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider.clone()) + .finish() + .unwrap(); + let client_verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider.clone()) + .finish() + .unwrap(); let server_config = ServerConfig::builder_with_provider(provider.clone()) .with_safe_default_protocol_versions() @@ -1114,22 +1204,20 @@ mod tests { let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); let subject = "foo"; let alternate_name = "bar"; - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + subject, AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - subject.to_string(), - vec![alternate_name.to_string(), subject.to_string()], - provider.clone(), - Duration::from_secs(4), - ) - .unwrap(); - let verifier = AttestedCertificateVerifier::new_with_provider( - None, - AttestationVerifier::mock(), - provider.clone(), ) + .with_subject_alt_names(vec![alternate_name.to_string(), subject.to_string()]) + .with_provider(provider.clone()) + .with_key_pair(&key_pair) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); + let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider.clone()) + .finish() + .unwrap(); let server_config = ServerConfig::builder_with_provider(provider.clone()) .with_safe_default_protocol_versions() @@ -1162,12 +1250,10 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn malformed_certificate_returns_bad_encoding() { let provider: Arc = aws_lc_rs::default_provider().into(); - let verifier = AttestedCertificateVerifier::new_with_provider( - None, - AttestationVerifier::mock(), - provider, - ) - .unwrap(); + let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider) + .finish() + .unwrap(); let cert = CertificateDer::from(vec![1_u8, 2, 3, 4]); let result = verify_server_cert_direct( @@ -1186,12 +1272,11 @@ mod tests { let cert = plain_self_signed_certificate("foo"); let mut roots = RootCertStore::empty(); roots.add(cert.clone()).unwrap(); - let verifier = AttestedCertificateVerifier::new_with_provider( - Some(roots), - AttestationVerifier::mock(), - provider, - ) - .unwrap(); + let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider) + .with_root_store(roots) + .finish() + .unwrap(); let result = verify_server_cert_direct( &verifier, @@ -1208,26 +1293,23 @@ mod tests { let provider: Arc = aws_lc_rs::default_provider().into(); let ca = test_ca(); let ca_cert = CertificateDer::from_pem_slice(ca.pem_cert.as_bytes()).unwrap(); - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + "foo", AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - None, - "foo".to_string(), - vec![], - provider.clone(), - Duration::from_secs(4), ) - .await + .with_provider(provider.clone()) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); let mut roots = RootCertStore::empty(); roots.add(ca_cert).unwrap(); - let verifier = AttestedCertificateVerifier::new_with_provider( - Some(roots), - AttestationVerifier::mock(), - provider, - ) - .unwrap(); + let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider) + .with_root_store(roots) + .finish() + .unwrap(); let result = verify_server_cert_direct( &verifier, @@ -1244,26 +1326,23 @@ mod tests { let provider: Arc = aws_lc_rs::default_provider().into(); let ca = test_ca(); let ca_cert = CertificateDer::from_pem_slice(ca.pem_cert.as_bytes()).unwrap(); - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + "client", AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - None, - "client".to_string(), - vec![], - provider.clone(), - Duration::from_secs(4), ) - .await + .with_provider(provider.clone()) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); let mut roots = RootCertStore::empty(); roots.add(ca_cert).unwrap(); - let verifier = AttestedCertificateVerifier::new_with_provider( - Some(roots), - AttestationVerifier::mock(), - provider, - ) - .unwrap(); + let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider) + .with_root_store(roots) + .finish() + .unwrap(); let result = verify_client_cert_direct(&verifier, &cert, UnixTime::now()); @@ -1274,22 +1353,19 @@ mod tests { async fn self_signed_attested_certificate_with_wrong_name_is_rejected() { let provider: Arc = aws_lc_rs::default_provider().into(); let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + "foo", AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - "foo".to_string(), - vec![], - provider.clone(), - Duration::from_secs(4), - ) - .unwrap(); - let verifier = AttestedCertificateVerifier::new_with_provider( - None, - AttestationVerifier::mock(), - provider, ) + .with_provider(provider.clone()) + .with_key_pair(&key_pair) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); + let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider) + .finish() + .unwrap(); let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); let result = verify_server_cert_direct( @@ -1309,15 +1385,14 @@ mod tests { async fn certificate_binding_changes_when_identity_changes() { let provider: Arc = aws_lc_rs::default_provider().into(); let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + "foo", AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - "foo".to_string(), - vec![], - provider.clone(), - Duration::from_secs(4), ) + .with_provider(provider.clone()) + .with_key_pair(&key_pair) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); let original_cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); let (original_report_data, original_not_after) = @@ -1351,22 +1426,19 @@ mod tests { async fn attestation_rejection_returns_application_verification_failure() { let provider: Arc = aws_lc_rs::default_provider().into(); let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + "foo", AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - "foo".to_string(), - vec![], - provider.clone(), - Duration::from_secs(4), - ) - .unwrap(); - let verifier = AttestedCertificateVerifier::new_with_provider( - None, - AttestationVerifier::expect_none(), - provider, ) + .with_provider(provider.clone()) + .with_key_pair(&key_pair) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); + let verifier = AttestedCertificateVerifier::build(AttestationVerifier::expect_none()) + .with_provider(provider) + .finish() + .unwrap(); let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); let result = verify_server_cert_direct( @@ -1386,22 +1458,19 @@ mod tests { async fn verifier_reuses_trusted_certificate_cache() { let provider: Arc = aws_lc_rs::default_provider().into(); let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + "foo", AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - "foo".to_string(), - vec![], - provider.clone(), - Duration::from_secs(4), - ) - .unwrap(); - let mut verifier = AttestedCertificateVerifier::new_with_provider( - None, - AttestationVerifier::mock(), - provider, ) + .with_provider(provider.clone()) + .with_key_pair(&key_pair) + .with_certificate_validity(Duration::from_secs(4)) + .finish() .unwrap(); + let mut verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider) + .finish() + .unwrap(); let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); let (expected_input_data, not_after) = AttestedCertificateVerifier::cert_binding_data(&cert).unwrap(); diff --git a/crates/attested-tls/tests/nested_tls.rs b/crates/attested-tls/tests/nested_tls.rs index 10e2165..068d54b 100644 --- a/crates/attested-tls/tests/nested_tls.rs +++ b/crates/attested-tls/tests/nested_tls.rs @@ -85,15 +85,13 @@ fn plain_tls_config_pair(provider: Arc) -> (ServerConfig, Client /// self-signed certs fn attested_server_config(server_name: &str, provider: Arc) -> ServerConfig { let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); - let resolver = AttestedCertificateResolver::new_with_provider( + let resolver = AttestedCertificateResolver::build( + server_name, AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), - &key_pair, - None, - server_name.to_string(), - vec![], - provider.clone(), - std::time::Duration::from_secs(91), ) + .with_provider(provider.clone()) + .with_key_pair(&key_pair) + .finish() .unwrap(); ServerConfig::builder_with_provider(provider) @@ -105,12 +103,10 @@ fn attested_server_config(server_name: &str, provider: Arc) -> S /// Create client TLS config with attestation verification fn attested_client_config(provider: Arc) -> ClientConfig { - let verifier = AttestedCertificateVerifier::new_with_provider( - None, - AttestationVerifier::mock(), - provider.clone(), - ) - .unwrap(); + let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + .with_provider(provider.clone()) + .finish() + .unwrap(); ClientConfig::builder_with_provider(provider) .with_safe_default_protocol_versions()