Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 26 additions & 50 deletions boring/src/mlkem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ impl MlKemPrivateKey {
pub fn generate(algorithm: Algorithm) -> Result<(MlKemPublicKey, MlKemPrivateKey), ErrorStack> {
match algorithm {
Algorithm::MlKem768 => {
let (pk, sk) = MlKem768PrivateKey::generate();
let (pk, sk) = MlKem768PrivateKey::generate()?;
Ok((
MlKemPublicKey(Either::MlKem768(pk)),
MlKemPrivateKey(Either::MlKem768(sk)),
))
}
Algorithm::MlKem1024 => {
let (pk, sk) = MlKem1024PrivateKey::generate();
let (pk, sk) = MlKem1024PrivateKey::generate()?;
Ok((
MlKemPublicKey(Either::MlKem1024(pk)),
MlKemPrivateKey(Either::MlKem1024(sk)),
Expand Down Expand Up @@ -222,39 +222,27 @@ impl MlKem768PrivateKey {
pub const CIPHERTEXT_BYTES: usize = ffi::MLKEM768_CIPHERTEXT_BYTES as usize;

/// Generate a new key pair.
#[must_use]
fn generate() -> (Box<MlKem768PublicKey>, Box<MlKem768PrivateKey>) {
fn generate() -> Result<(Box<MlKem768PublicKey>, Box<MlKem768PrivateKey>), ErrorStack> {
// SAFETY: all buffers are out parameters, correctly sized
unsafe {
ffi::init();
let mut public_key_bytes: MaybeUninit<[u8; MlKem768PublicKey::PUBLIC_KEY_BYTES]> =
MaybeUninit::uninit();
let mut seed: MaybeUninit<MlKemPrivateKeySeed> = MaybeUninit::uninit();
let mut bytes = [0; MlKem768PublicKey::PUBLIC_KEY_BYTES];
let mut seed = [0; PRIVATE_KEY_SEED_BYTES];
let mut expanded: MaybeUninit<ffi::MLKEM768_private_key> = MaybeUninit::uninit();

ffi::MLKEM768_generate_key(
public_key_bytes.as_mut_ptr().cast(),
seed.as_mut_ptr().cast(),
bytes.as_mut_ptr().cast(),
seed.as_mut_ptr(),
expanded.as_mut_ptr(),
);

let bytes = public_key_bytes.assume_init();

// Parse the public key bytes to get the parsed struct
let mut cbs = cbs_init(&bytes);
let mut parsed: MaybeUninit<ffi::MLKEM768_public_key> = MaybeUninit::uninit();
ffi::MLKEM768_parse_public_key(parsed.as_mut_ptr(), &mut cbs);

(
Box::new(MlKem768PublicKey {
bytes,
parsed: parsed.assume_init(),
}),
Ok((
Box::new(MlKem768PublicKey::from_slice(&bytes)?),
Box::new(MlKem768PrivateKey {
seed: seed.assume_init(),
seed,
expanded: expanded.assume_init(),
}),
)
))
}
}

Expand Down Expand Up @@ -439,39 +427,27 @@ impl MlKem1024PrivateKey {
pub const CIPHERTEXT_BYTES: usize = ffi::MLKEM1024_CIPHERTEXT_BYTES as usize;

/// Generate a new key pair.
#[must_use]
fn generate() -> (Box<MlKem1024PublicKey>, Box<MlKem1024PrivateKey>) {
fn generate() -> Result<(Box<MlKem1024PublicKey>, Box<MlKem1024PrivateKey>), ErrorStack> {
// SAFETY: all buffers are out parameters, correctly sized
unsafe {
ffi::init();
let mut public_key_bytes: MaybeUninit<[u8; MlKem1024PublicKey::PUBLIC_KEY_BYTES]> =
MaybeUninit::uninit();
let mut seed: MaybeUninit<MlKemPrivateKeySeed> = MaybeUninit::uninit();
let mut bytes = [0; MlKem1024PublicKey::PUBLIC_KEY_BYTES];
let mut seed = [0; PRIVATE_KEY_SEED_BYTES];
let mut expanded: MaybeUninit<ffi::MLKEM1024_private_key> = MaybeUninit::uninit();

ffi::MLKEM1024_generate_key(
public_key_bytes.as_mut_ptr().cast(),
seed.as_mut_ptr().cast(),
bytes.as_mut_ptr().cast(),
seed.as_mut_ptr(),
expanded.as_mut_ptr(),
);

let bytes = public_key_bytes.assume_init();

// Parse the public key bytes to get the parsed struct
let mut cbs = cbs_init(&bytes);
let mut parsed: MaybeUninit<ffi::MLKEM1024_public_key> = MaybeUninit::uninit();
ffi::MLKEM1024_parse_public_key(parsed.as_mut_ptr(), &mut cbs);

(
Box::new(MlKem1024PublicKey {
bytes,
parsed: parsed.assume_init(),
}),
Ok((
Box::new(MlKem1024PublicKey::from_slice(&bytes)?),
Box::new(MlKem1024PrivateKey {
seed: seed.assume_init(),
seed,
expanded: expanded.assume_init(),
}),
)
))
}
}

Expand Down Expand Up @@ -649,15 +625,15 @@ mod tests {

#[test]
fn roundtrip() {
let (pk, sk) = <$priv>::generate();
let (pk, sk) = <$priv>::generate().unwrap();
let (ct, ss1) = pk.encapsulate();
let ss2 = sk.decapsulate(&ct);
assert_eq!(ss1, ss2);
}

#[test]
fn seed_roundtrip() {
let (pk, sk) = <$priv>::generate();
let (pk, sk) = <$priv>::generate().unwrap();
let sk2 = <$priv>::from_seed(&sk.seed).unwrap();
let (ct, ss1) = pk.encapsulate();
let ss2 = sk2.decapsulate(&ct);
Expand All @@ -666,7 +642,7 @@ mod tests {

#[test]
fn derive_pubkey() {
let (pk, sk) = <$priv>::generate();
let (pk, sk) = <$priv>::generate().unwrap();
assert_eq!(pk.bytes, sk.public_key().unwrap().bytes);
}

Expand All @@ -678,14 +654,14 @@ mod tests {

#[test]
fn from_slice_roundtrip() {
let (pk, _) = <$priv>::generate();
let (pk, _) = <$priv>::generate().unwrap();
let pk2 = <$pub>::from_slice(&pk.bytes).unwrap();
assert_eq!(pk.bytes, pk2.bytes);
}

#[test]
fn implicit_rejection() {
let (_, sk) = <$priv>::generate();
let (_, sk) = <$priv>::generate().unwrap();
let bad_ct = [0x42u8; $ct_len];
// bad ciphertext still "works", just returns deterministic garbage
let ss1 = sk.decapsulate(&bad_ct);
Expand All @@ -695,7 +671,7 @@ mod tests {

#[test]
fn debug_redacts_seed() {
let (_, sk) = <$priv>::generate();
let (_, sk) = <$priv>::generate().unwrap();
let dbg = format!("{:?}", sk);
assert!(dbg.contains("redacted"));
}
Expand Down
Loading