Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion ml-kem/src/decapsulation_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
B32, EncapsulationKey, Seed, SharedKey,
crypto::{G, J},
param::{DecapsulationKeySize, ExpandedDecapsulationKey, KemParams},
pke::{DecryptionKey, EncryptionKey},
pke::{Ciphertext1, Ciphertext2, DecryptionKey, EncryptionKey},
};
use array::{
Array, ArraySize,
Expand Down Expand Up @@ -110,6 +110,16 @@ where
let d = Some(d);
Self { dk_pke, ek, d, z }
}

/// Decapsulates the given [`Ciphertext1`] and [`Ciphertext2`] a.k.a. "incremental encapsulated key".
pub fn decapsulate_incremental(&self, c1: &Ciphertext1<P>, c2: &Ciphertext2<P>) -> SharedKey {
let mp = self.dk_pke.decrypt_split(c1, c2);
let (Kp, rp) = G(&[&mp, &self.ek.h()]);
let Kbar = J(&[self.z.as_slice(), c1.as_ref(), c2.as_ref()]);
let cp = self.ek.ek_pke().encrypt(&mp, &rp);
let (cp1, cp2) = P::split_ct(&cp);
B32::conditional_select(&Kbar, &Kp, cp1.ct_eq(c1) & cp2.ct_eq(c2))
}
}

// Handwritten to omit `d` in the comparisons, so keys initialized from seeds compare equally to
Expand Down
72 changes: 71 additions & 1 deletion ml-kem/src/encapsulation_key.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,45 @@
use crate::{
B32, SharedKey,
algebra::NttVector,
crypto::{G, H},
kem::{InvalidKey, Kem, Key, KeyExport, KeySizeUser, TryKeyInit},
param::{EncapsulationKeySize, KemParams},
pke::EncryptionKey,
pke::{Ciphertext1, Ciphertext2, EncryptionKey},
};
use array::sizes::U32;
use kem::{Ciphertext, Encapsulate, Generate};
use rand_core::CryptoRng;

#[cfg(feature = "zeroize")]
use zeroize::{Zeroize, ZeroizeOnDrop};

/// A temporary secret produced by the first incremental encapsulation step,
/// to be used by the second one to finish encapsulation.
#[derive(Clone, Debug)]
pub struct EncapsulationSecret<P>
where
P: KemParams,
{
m: B32,
r: B32,
es: NttVector<P::K>,
}

#[cfg(feature = "zeroize")]
impl<P> Drop for EncapsulationSecret<P>
where
P: KemParams,
{
fn drop(&mut self) {
self.m.zeroize();
self.r.zeroize();
self.es.zeroize();
}
}

#[cfg(feature = "zeroize")]
impl<P> ZeroizeOnDrop for EncapsulationSecret<P> where P: KemParams {}

/// An `EncapsulationKey` provides the ability to encapsulate a shared key so that it can only be
/// decapsulated by the holder of the corresponding decapsulation key.
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -61,6 +92,45 @@ where
pub(crate) fn h(&self) -> B32 {
self.h
}

/// Encapsulates incrementally with the given randomness. This is useful for testing against known vectors.
///
/// # Warning
/// Do NOT use this function unless you know what you're doing. If you fail to use all uniform
/// random bytes even once, you can have catastrophic security failure.
#[cfg_attr(not(feature = "hazmat"), doc(hidden))]
pub fn encapsulate_incremental_1_deterministic(
&self,
m: &B32,
) -> (Ciphertext1<P>, EncapsulationSecret<P>, SharedKey) {
let (K, r) = G(&[m, &self.h]);
let (c1, es) = self.ek_pke.encrypt_incremental_1(&r);
(c1, EncapsulationSecret { m: *m, r, es }, K)
}

/// Finish incremental encapsulation.
pub fn encapsulate_incremental_2(
&self,
encapsulation_secret: EncapsulationSecret<P>,
) -> Ciphertext2<P> {
self.ek_pke.encrypt_incremental_2(
&encapsulation_secret.m,
&encapsulation_secret.r,
&encapsulation_secret.es,
)
}

/// Encapsulates incrementally a fresh [`SharedKey`] generated using the supplied random number generator `R`.
pub fn encapsulate_incremental_1_with_rng<R>(
&self,
rng: &mut R,
) -> (Ciphertext1<P>, EncapsulationSecret<P>, SharedKey)
where
R: CryptoRng + ?Sized,
{
let m = B32::generate_from_rng(rng);
self.encapsulate_incremental_1_deterministic(&m)
}
}

impl<P> Encapsulate for EncapsulationKey<P>
Expand Down
34 changes: 31 additions & 3 deletions ml-kem/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub use array::{self, ArraySize};
pub use decapsulation_key::DecapsulationKey;
#[allow(deprecated)]
pub use decapsulation_key::ExpandedKeyEncoding;
pub use encapsulation_key::EncapsulationKey;
pub use encapsulation_key::{EncapsulationKey, EncapsulationSecret};
pub use kem::{
self, Ciphertext, Decapsulate, Encapsulate, FromSeed, Generate, InvalidKey, Kem, Key,
KeyExport, KeyInit, KeySizeUser, TryKeyInit,
Expand All @@ -81,6 +81,7 @@ pub use ml_kem_512::MlKem512;
pub use ml_kem_768::MlKem768;
pub use ml_kem_1024::MlKem1024;
pub use param::{ExpandedDecapsulationKey, ParameterSet};
pub use pke::{Ciphertext1, Ciphertext2};

use array::{
Array,
Expand Down Expand Up @@ -133,9 +134,18 @@ pub mod ml_kem_512 {
/// can only be decapsulated by the holder of the corresponding decapsulation key.
pub type EncapsulationKey = crate::EncapsulationKey<MlKem512>;

/// An ML-KEM-512 `EncapsulationSecret` is a temporary secret used for incremental encapsulation.
pub type EncapsulationSecret = crate::EncapsulationSecret<MlKem512>;

/// Encoded ML-KEM-512 ciphertexts.
pub type Ciphertext = kem::Ciphertext<MlKem512>;

/// Encoded ML-KEM-512 first incremental ciphertexts.
pub type Ciphertext1 = crate::Ciphertext1<MlKem512>;

/// Encoded ML-KEM-512 second incrementalciphertexts.
pub type Ciphertext2 = crate::Ciphertext2<MlKem512>;

/// Legacy expanded decapsulation keys. Prefer seeds instead.
#[doc(hidden)]
#[deprecated(since = "0.3.0", note = "use `Seed` instead")]
Expand Down Expand Up @@ -181,9 +191,18 @@ pub mod ml_kem_768 {
/// can only be decapsulated by the holder of the corresponding decapsulation key.
pub type EncapsulationKey = crate::EncapsulationKey<MlKem768>;

/// Encoded ML-KEM-512 ciphertexts.
/// An ML-KEM-768 `EncapsulationSecret` is a temporary secret used for incremental encapsulation.
pub type EncapsulationSecret = crate::EncapsulationSecret<MlKem768>;

/// Encoded ML-KEM-768 ciphertexts.
pub type Ciphertext = kem::Ciphertext<MlKem768>;

/// Encoded ML-KEM-768 first incremental ciphertexts.
pub type Ciphertext1 = crate::Ciphertext1<MlKem768>;

/// Encoded ML-KEM-768 second incrementalciphertexts.
pub type Ciphertext2 = crate::Ciphertext2<MlKem768>;

/// Legacy expanded decapsulation keys. Prefer seeds instead.
#[doc(hidden)]
#[deprecated(since = "0.3.0", note = "use `Seed` instead")]
Expand Down Expand Up @@ -228,9 +247,18 @@ pub mod ml_kem_1024 {
/// it can only be decapsulated by the holder of the corresponding decapsulation key.
pub type EncapsulationKey = crate::EncapsulationKey<MlKem1024>;

/// Encoded ML-KEM-512 ciphertexts.
/// An ML-KEM-1024 `EncapsulationSecret` is a temporary secret used for incremental encapsulation.
pub type EncapsulationSecret = crate::EncapsulationSecret<MlKem1024>;

/// Encoded ML-KEM-1024 ciphertexts.
pub type Ciphertext = kem::Ciphertext<MlKem1024>;

/// Encoded ML-KEM-1024 first incremental ciphertexts.
pub type Ciphertext1 = crate::Ciphertext1<MlKem1024>;

/// Encoded ML-KEM-1024 second incrementalciphertexts.
pub type Ciphertext2 = crate::Ciphertext2<MlKem1024>;

/// Legacy expanded decapsulation keys. Prefer seeds instead.
#[doc(hidden)]
#[deprecated(since = "0.3.0", note = "use `Seed` instead")]
Expand Down
48 changes: 48 additions & 0 deletions ml-kem/src/pke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ where
pub(crate) fn decrypt(&self, ciphertext: &Ciphertext<P>) -> B32 {
let (c1, c2) = P::split_ct(ciphertext);

self.decrypt_split(c1, c2)
}

pub(crate) fn decrypt_split(&self, c1: &Ciphertext1<P>, c2: &Ciphertext2<P>) -> B32 {
let mut u: Vector<P::K> = Encode::<P::Du>::decode(c1);
u.decompress::<P::Du>();

Expand Down Expand Up @@ -123,6 +127,11 @@ where
rho: B32,
}

/// First ciphertext for incremental encapsulation
pub type Ciphertext1<P> = array::Array<u8, crate::param::EncodedUSize<P>>;
/// Second ciphertext for incremental encapsulation
pub type Ciphertext2<P> = array::Array<u8, crate::param::EncodedVSize<P>>;

impl<P> EncryptionKey<P>
where
P: PkeParams,
Expand Down Expand Up @@ -152,6 +161,45 @@ where
P::concat_ct(c1, c2)
}

/// Encrypt the specified message for the holder of the corresponding decryption key, using the
/// provided randomness, according the `K-PKE.Encrypt` procedure.
pub(crate) fn encrypt_incremental_1(
&self,
randomness: &B32,
) -> (Ciphertext1<P>, NttVector<P::K>) {
let r = sample_poly_vec_cbd::<P::Eta1, P::K>(randomness, 0);
let e1 = sample_poly_vec_cbd::<P::Eta2, P::K>(randomness, P::K::U8);

let A_hat_t: NttMatrix<P::K> = matrix_sample_ntt(&self.rho, true);
let r_hat: NttVector<P::K> = r.ntt();
let ATr: Vector<P::K> = (&A_hat_t * &r_hat).ntt_inverse();
let mut u = ATr + e1;

let c1 = Encode::<P::Du>::encode(u.compress::<P::Du>());

(c1, r_hat)
}

/// Encrypt the specified message for the holder of the corresponding decryption key, using the
/// provided randomness, according the `K-PKE.Encrypt` procedure.
pub(crate) fn encrypt_incremental_2(
&self,
message: &B32,
randomness: &B32,
r_hat: &NttVector<P::K>,
) -> Ciphertext2<P> {
let prf_output = PRF::<P::Eta2>(randomness, 2 * P::K::U8);
let e2: Polynomial = sample_poly_cbd::<P::Eta2>(&prf_output);

let mut mu: Polynomial = Encode::<U1>::decode(message);
mu.decompress::<U1>();

let tTr: Polynomial = (&self.t_hat * r_hat).ntt_inverse();
let mut v = &(&tTr + &e2) + &mu;

Encode::<P::Dv>::encode(v.compress::<P::Dv>())
}

/// Represent this encryption key as a byte array `(t_hat || rho)`
pub(crate) fn to_bytes(&self) -> EncodedEncryptionKey<P> {
let t_hat = P::encode_u12(&self.t_hat);
Expand Down
64 changes: 63 additions & 1 deletion ml-kem/tests/encap-decap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,77 @@ use std::{fs::read_to_string, path::PathBuf};
pub trait EncapsulateDeterministic {
/// Returns `(ciphertext, shared_secret)`.
fn encapsulate_deterministic(&self, m: &ArrayN<u8, 32>) -> (Vec<u8>, Vec<u8>);
/// Returns `(ciphertext, shared_secret)`.
fn encapsulate_incremental_deterministic(&self, m: &ArrayN<u8, 32>) -> (Vec<u8>, Vec<u8>);
}

impl EncapsulateDeterministic for EncapsulationKey512 {
fn encapsulate_deterministic(&self, m: &ArrayN<u8, 32>) -> (Vec<u8>, Vec<u8>) {
let (c, k) = self.encapsulate_deterministic(m);
(c.to_vec(), k.to_vec())
}
fn encapsulate_incremental_deterministic(&self, m: &ArrayN<u8, 32>) -> (Vec<u8>, Vec<u8>) {
let (c1, es, k) = self.encapsulate_incremental_1_deterministic(m);
let c2 = self.encapsulate_incremental_2(es);
let mut c = c1.to_vec();
c.extend_from_slice(&c2);
(c, k.to_vec())
}
}

impl EncapsulateDeterministic for EncapsulationKey768 {
fn encapsulate_deterministic(&self, m: &ArrayN<u8, 32>) -> (Vec<u8>, Vec<u8>) {
let (c, k) = self.encapsulate_deterministic(m);
(c.to_vec(), k.to_vec())
}
fn encapsulate_incremental_deterministic(&self, m: &ArrayN<u8, 32>) -> (Vec<u8>, Vec<u8>) {
let (c1, es, k) = self.encapsulate_incremental_1_deterministic(m);
let c2 = self.encapsulate_incremental_2(es);
let mut c = c1.to_vec();
c.extend_from_slice(&c2);
(c, k.to_vec())
}
}

impl EncapsulateDeterministic for EncapsulationKey1024 {
fn encapsulate_deterministic(&self, m: &ArrayN<u8, 32>) -> (Vec<u8>, Vec<u8>) {
let (c, k) = self.encapsulate_deterministic(m);
(c.to_vec(), k.to_vec())
}
fn encapsulate_incremental_deterministic(&self, m: &ArrayN<u8, 32>) -> (Vec<u8>, Vec<u8>) {
let (c1, es, k) = self.encapsulate_incremental_1_deterministic(m);
let c2 = self.encapsulate_incremental_2(es);
let mut c = c1.to_vec();
c.extend_from_slice(&c2);
(c, k.to_vec())
}
}

/// A helper trait for deterministic incremental decapsulation tests
pub trait DecapsulateIncremental<C> {
/// Returns `shared_secret`.
fn decapsulate_incremental(&self, c: &C) -> Vec<u8>;
}

impl DecapsulateIncremental<ml_kem_512::Ciphertext> for DecapsulationKey512 {
fn decapsulate_incremental(&self, c: &ml_kem_512::Ciphertext) -> Vec<u8> {
let (c1, c2) = c.split_ref();
self.decapsulate_incremental(c1, c2).to_vec()
}
}

impl DecapsulateIncremental<ml_kem_768::Ciphertext> for DecapsulationKey768 {
fn decapsulate_incremental(&self, c: &ml_kem_768::Ciphertext) -> Vec<u8> {
let (c1, c2) = c.split_ref();
self.decapsulate_incremental(c1, c2).to_vec()
}
}

impl DecapsulateIncremental<ml_kem_1024::Ciphertext> for DecapsulationKey1024 {
fn decapsulate_incremental(&self, c: &ml_kem_1024::Ciphertext) -> Vec<u8> {
let (c1, c2) = c.split_ref();
self.decapsulate_incremental(c1, c2).to_vec()
}
}

#[test]
Expand Down Expand Up @@ -76,6 +126,11 @@ where

assert_eq!(k.as_slice(), tc.k.as_slice());
assert_eq!(c.as_slice(), tc.c.as_slice());

let (c, k) = ek.encapsulate_incremental_deterministic(&m);

assert_eq!(k.as_slice(), tc.k.as_slice());
assert_eq!(c.as_slice(), tc.c.as_slice());
}

fn verify_decap_group(tg: &acvp::DecapTestGroup) {
Expand All @@ -92,12 +147,19 @@ fn verify_decap_group(tg: &acvp::DecapTestGroup) {
fn verify_decap<K>(tc: &acvp::DecapTestCase, dk_slice: &[u8])
where
K: Kem,
K::DecapsulationKey: Decapsulate + Decapsulator<Kem = K> + ExpandedKeyEncoding,
K::DecapsulationKey: Decapsulate
+ DecapsulateIncremental<Ciphertext<K>>
+ Decapsulator<Kem = K>
+ ExpandedKeyEncoding,
{
let dk = K::DecapsulationKey::from_expanded_bytes(dk_slice.try_into().unwrap()).unwrap();
let c = Ciphertext::<K>::try_from(tc.c.as_slice()).unwrap();

let k = dk.decapsulate(&c);
assert_eq!(k.as_slice(), tc.k.as_slice());

let k = dk.decapsulate_incremental(&c);
assert_eq!(k.as_slice(), tc.k.as_slice());
}

mod acvp {
Expand Down