Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions attested-tls/src/attestation/measurements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,18 @@ impl MeasurementPolicy {
}
}

/// Accept any TDX attestation regardless of platform
pub fn tdx() -> Self {
Self {
accepted_measurements: vec![
MeasurementRecord::allow_any_measurement(AttestationType::DcapTdx),
MeasurementRecord::allow_any_measurement(AttestationType::QemuTdx),
MeasurementRecord::allow_any_measurement(AttestationType::GcpTdx),
MeasurementRecord::allow_any_measurement(AttestationType::AzureTdx),
],
}
}

/// Expect mock measurements used in tests
#[cfg(any(test, feature = "mock"))]
pub fn mock() -> Self {
Expand Down
11 changes: 6 additions & 5 deletions attested-tls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,9 @@ impl AttestedTlsClient {
pub async fn get_tls_cert(
&self,
server_name: &str,
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
let (mut tls_stream, _, _) = self.connect_tcp(server_name).await?;
) -> Result<(Vec<CertificateDer<'static>>, Option<MultiMeasurements>), AttestedTlsError> {
let (mut tls_stream, measurements, _attestation_type) =
self.connect_tcp(server_name).await?;

let (_io, server_connection) = tls_stream.get_ref();

Expand All @@ -431,7 +432,7 @@ impl AttestedTlsClient {

tls_stream.shutdown().await?;

Ok(remote_cert_chain)
Ok((remote_cert_chain, measurements))
}
}

Expand All @@ -440,7 +441,7 @@ pub async fn get_tls_cert(
server_name: String,
attestation_verifier: AttestationVerifier,
remote_certificate: Option<CertificateDer<'static>>,
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
) -> Result<(Vec<CertificateDer<'static>>, Option<MultiMeasurements>), AttestedTlsError> {
tracing::debug!("Getting remote TLS cert");
let attested_tls_client = AttestedTlsClient::new(
None,
Expand All @@ -458,7 +459,7 @@ pub async fn get_tls_cert_with_config(
server_name: &str,
attestation_verifier: AttestationVerifier,
client_config: ClientConfig,
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
) -> Result<(Vec<CertificateDer<'static>>, Option<MultiMeasurements>), AttestedTlsError> {
let attested_tls_client = AttestedTlsClient::new_with_tls_config(
client_config,
AttestationGenerator::with_no_attestation(),
Expand Down
15 changes: 9 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,21 @@ pub async fn get_tls_cert(
attestation_verifier: AttestationVerifier,
remote_certificate: Option<CertificateDer<'static>>,
allow_self_signed: bool,
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
if allow_self_signed {
) -> Result<(Vec<CertificateDer<'static>>, Option<MultiMeasurements>), AttestedTlsError> {
let (cert, measurements) = if allow_self_signed {
let client_tls_config = self_signed::client_tls_config_allow_self_signed()?;
attested_tls::get_tls_cert_with_config(
&server_name,
attestation_verifier,
client_tls_config,
)
.await
.await?
} else {
attested_tls::get_tls_cert(server_name, attestation_verifier, remote_certificate).await
}
attested_tls::get_tls_cert(server_name, attestation_verifier, remote_certificate).await?
};

debug!("[get-tls-cert] Connected to proxy server with measurements: {measurements:?}");
Ok((cert, measurements))
}

/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
Expand Down Expand Up @@ -1114,7 +1117,7 @@ mod tests {
proxy_server.accept().await.unwrap();
});

let retrieved_chain = get_tls_cert_with_config(
let (retrieved_chain, _measurements) = get_tls_cert_with_config(
&proxy_server_addr.to_string(),
AttestationVerifier::mock(),
client_config,
Expand Down
38 changes: 32 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use anyhow::{anyhow, ensure};
use attested_tls::attestation::measurements::MultiMeasurements;
use clap::{Parser, Subcommand};
use std::{
fs::File,
Expand Down Expand Up @@ -126,6 +127,9 @@ enum CliCommand {
/// Enables verification of self-signed TLS certificates
#[arg(long)]
allow_self_signed: bool,
/// Filename to write measurements as JSON to
#[arg(long)]
out_measurements: Option<PathBuf>,
},
/// Serve a filesystem path over an attested channel
AttestedFileServer {
Expand Down Expand Up @@ -201,12 +205,22 @@ async fn main() -> anyhow::Result<()> {
MeasurementPolicy::from_file_or_url(server_measurements).await?
}
None => {
let allowed_server_attestation_type: AttestationType = serde_json::from_value(
serde_json::Value::String(cli.allowed_remote_attestation_type.ok_or(anyhow!(
match cli
.allowed_remote_attestation_type
.ok_or(anyhow!(
"Either a measurements file or an allowed attestation type must be provided"
))?),
)?;
MeasurementPolicy::single_attestation_type(allowed_server_attestation_type)
))?
.to_lowercase()
.as_str()
{
"tdx" => MeasurementPolicy::tdx(),
attestation_type => {
let allowed_server_attestation_type: AttestationType = serde_json::from_value(
serde_json::Value::String(attestation_type.to_string()),
)?;
MeasurementPolicy::single_attestation_type(allowed_server_attestation_type)
}
}
}
};

Expand Down Expand Up @@ -340,6 +354,7 @@ async fn main() -> anyhow::Result<()> {
server,
tls_ca_certificate,
allow_self_signed,
out_measurements,
} => {
let remote_tls_cert = match tls_ca_certificate {
Some(remote_cert_filename) => Some(
Expand All @@ -350,13 +365,24 @@ async fn main() -> anyhow::Result<()> {
),
None => None,
};
let cert_chain = get_tls_cert(
let (cert_chain, measurements) = get_tls_cert(
server,
attestation_verifier,
remote_tls_cert,
allow_self_signed,
)
.await?;

// If the user chose to write measurements to a file as JSON
if let Some(path_to_write_measurements) = out_measurements {
std::fs::write(
path_to_write_measurements,
measurements
.unwrap_or(MultiMeasurements::NoAttestation)
.to_header_format()?
.as_bytes(),
)?;
}
println!("{}", certs_to_pem_string(&cert_chain)?);
}
CliCommand::AttestedFileServer {
Expand Down