diff --git a/ml-kem/src/decapsulation_key.rs b/ml-kem/src/decapsulation_key.rs
index 50e62af..5ad75e5 100644
--- a/ml-kem/src/decapsulation_key.rs
+++ b/ml-kem/src/decapsulation_key.rs
@@ -2,7 +2,7 @@ use crate::{
B32, EncapsulationKey, Seed, SharedKey,
crypto::{G, J},
param::{DecapsulationKeySize, ExpandedDecapsulationKey, KemParams},
- pke::{DecryptionKey, EncryptionKey},
+ pke::{Ciphertext1, Ciphertext2, DecryptionKey, EncryptionKey},
};
use array::{
Array, ArraySize,
@@ -110,6 +110,16 @@ where
let d = Some(d);
Self { dk_pke, ek, d, z }
}
+
+ /// Decapsulates the given [`Ciphertext1`] and [`Ciphertext2`] a.k.a. "incremental encapsulated key".
+ pub fn decapsulate_incremental(&self, c1: &Ciphertext1
, c2: &Ciphertext2
) -> SharedKey {
+ let mp = self.dk_pke.decrypt_split(c1, c2);
+ let (Kp, rp) = G(&[&mp, &self.ek.h()]);
+ let Kbar = J(&[self.z.as_slice(), c1.as_ref(), c2.as_ref()]);
+ let cp = self.ek.ek_pke().encrypt(&mp, &rp);
+ let (cp1, cp2) = P::split_ct(&cp);
+ B32::conditional_select(&Kbar, &Kp, cp1.ct_eq(c1) & cp2.ct_eq(c2))
+ }
}
// Handwritten to omit `d` in the comparisons, so keys initialized from seeds compare equally to
diff --git a/ml-kem/src/encapsulation_key.rs b/ml-kem/src/encapsulation_key.rs
index f09e246..7c9f8f7 100644
--- a/ml-kem/src/encapsulation_key.rs
+++ b/ml-kem/src/encapsulation_key.rs
@@ -1,14 +1,45 @@
use crate::{
B32, SharedKey,
+ algebra::NttVector,
crypto::{G, H},
kem::{InvalidKey, Kem, Key, KeyExport, KeySizeUser, TryKeyInit},
param::{EncapsulationKeySize, KemParams},
- pke::EncryptionKey,
+ pke::{Ciphertext1, Ciphertext2, EncryptionKey},
};
use array::sizes::U32;
use kem::{Ciphertext, Encapsulate, Generate};
use rand_core::CryptoRng;
+#[cfg(feature = "zeroize")]
+use zeroize::{Zeroize, ZeroizeOnDrop};
+
+/// A temporary secret produced by the first incremental encapsulation step,
+/// to be used by the second one to finish encapsulation.
+#[derive(Clone, Debug)]
+pub struct EncapsulationSecret
+where
+ P: KemParams,
+{
+ m: B32,
+ r: B32,
+ es: NttVector,
+}
+
+#[cfg(feature = "zeroize")]
+impl Drop for EncapsulationSecret
+where
+ P: KemParams,
+{
+ fn drop(&mut self) {
+ self.m.zeroize();
+ self.r.zeroize();
+ self.es.zeroize();
+ }
+}
+
+#[cfg(feature = "zeroize")]
+impl
ZeroizeOnDrop for EncapsulationSecret
where P: KemParams {}
+
/// An `EncapsulationKey` provides the ability to encapsulate a shared key so that it can only be
/// decapsulated by the holder of the corresponding decapsulation key.
#[derive(Clone, Debug)]
@@ -61,6 +92,45 @@ where
pub(crate) fn h(&self) -> B32 {
self.h
}
+
+ /// Encapsulates incrementally with the given randomness. This is useful for testing against known vectors.
+ ///
+ /// # Warning
+ /// Do NOT use this function unless you know what you're doing. If you fail to use all uniform
+ /// random bytes even once, you can have catastrophic security failure.
+ #[cfg_attr(not(feature = "hazmat"), doc(hidden))]
+ pub fn encapsulate_incremental_1_deterministic(
+ &self,
+ m: &B32,
+ ) -> (Ciphertext1
, EncapsulationSecret
, SharedKey) {
+ let (K, r) = G(&[m, &self.h]);
+ let (c1, es) = self.ek_pke.encrypt_incremental_1(&r);
+ (c1, EncapsulationSecret { m: *m, r, es }, K)
+ }
+
+ /// Finish incremental encapsulation.
+ pub fn encapsulate_incremental_2(
+ &self,
+ encapsulation_secret: EncapsulationSecret
,
+ ) -> Ciphertext2
{
+ self.ek_pke.encrypt_incremental_2(
+ &encapsulation_secret.m,
+ &encapsulation_secret.r,
+ &encapsulation_secret.es,
+ )
+ }
+
+ /// Encapsulates incrementally a fresh [`SharedKey`] generated using the supplied random number generator `R`.
+ pub fn encapsulate_incremental_1_with_rng(
+ &self,
+ rng: &mut R,
+ ) -> (Ciphertext1, EncapsulationSecret
, SharedKey)
+ where
+ R: CryptoRng + ?Sized,
+ {
+ let m = B32::generate_from_rng(rng);
+ self.encapsulate_incremental_1_deterministic(&m)
+ }
}
impl
Encapsulate for EncapsulationKey
diff --git a/ml-kem/src/lib.rs b/ml-kem/src/lib.rs
index 4c7cf7a..f3e28d5 100644
--- a/ml-kem/src/lib.rs
+++ b/ml-kem/src/lib.rs
@@ -72,7 +72,7 @@ pub use array::{self, ArraySize};
pub use decapsulation_key::DecapsulationKey;
#[allow(deprecated)]
pub use decapsulation_key::ExpandedKeyEncoding;
-pub use encapsulation_key::EncapsulationKey;
+pub use encapsulation_key::{EncapsulationKey, EncapsulationSecret};
pub use kem::{
self, Ciphertext, Decapsulate, Encapsulate, FromSeed, Generate, InvalidKey, Kem, Key,
KeyExport, KeyInit, KeySizeUser, TryKeyInit,
@@ -81,6 +81,7 @@ pub use ml_kem_512::MlKem512;
pub use ml_kem_768::MlKem768;
pub use ml_kem_1024::MlKem1024;
pub use param::{ExpandedDecapsulationKey, ParameterSet};
+pub use pke::{Ciphertext1, Ciphertext2};
use array::{
Array,
@@ -133,9 +134,18 @@ pub mod ml_kem_512 {
/// can only be decapsulated by the holder of the corresponding decapsulation key.
pub type EncapsulationKey = crate::EncapsulationKey;
+ /// An ML-KEM-512 `EncapsulationSecret` is a temporary secret used for incremental encapsulation.
+ pub type EncapsulationSecret = crate::EncapsulationSecret;
+
/// Encoded ML-KEM-512 ciphertexts.
pub type Ciphertext = kem::Ciphertext;
+ /// Encoded ML-KEM-512 first incremental ciphertexts.
+ pub type Ciphertext1 = crate::Ciphertext1;
+
+ /// Encoded ML-KEM-512 second incrementalciphertexts.
+ pub type Ciphertext2 = crate::Ciphertext2;
+
/// Legacy expanded decapsulation keys. Prefer seeds instead.
#[doc(hidden)]
#[deprecated(since = "0.3.0", note = "use `Seed` instead")]
@@ -181,9 +191,18 @@ pub mod ml_kem_768 {
/// can only be decapsulated by the holder of the corresponding decapsulation key.
pub type EncapsulationKey = crate::EncapsulationKey;
- /// Encoded ML-KEM-512 ciphertexts.
+ /// An ML-KEM-768 `EncapsulationSecret` is a temporary secret used for incremental encapsulation.
+ pub type EncapsulationSecret = crate::EncapsulationSecret;
+
+ /// Encoded ML-KEM-768 ciphertexts.
pub type Ciphertext = kem::Ciphertext;
+ /// Encoded ML-KEM-768 first incremental ciphertexts.
+ pub type Ciphertext1 = crate::Ciphertext1;
+
+ /// Encoded ML-KEM-768 second incrementalciphertexts.
+ pub type Ciphertext2 = crate::Ciphertext2;
+
/// Legacy expanded decapsulation keys. Prefer seeds instead.
#[doc(hidden)]
#[deprecated(since = "0.3.0", note = "use `Seed` instead")]
@@ -228,9 +247,18 @@ pub mod ml_kem_1024 {
/// it can only be decapsulated by the holder of the corresponding decapsulation key.
pub type EncapsulationKey = crate::EncapsulationKey;
- /// Encoded ML-KEM-512 ciphertexts.
+ /// An ML-KEM-1024 `EncapsulationSecret` is a temporary secret used for incremental encapsulation.
+ pub type EncapsulationSecret = crate::EncapsulationSecret;
+
+ /// Encoded ML-KEM-1024 ciphertexts.
pub type Ciphertext = kem::Ciphertext;
+ /// Encoded ML-KEM-1024 first incremental ciphertexts.
+ pub type Ciphertext1 = crate::Ciphertext1;
+
+ /// Encoded ML-KEM-1024 second incrementalciphertexts.
+ pub type Ciphertext2 = crate::Ciphertext2;
+
/// Legacy expanded decapsulation keys. Prefer seeds instead.
#[doc(hidden)]
#[deprecated(since = "0.3.0", note = "use `Seed` instead")]
diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs
index fae743b..8d9d0b4 100644
--- a/ml-kem/src/pke.rs
+++ b/ml-kem/src/pke.rs
@@ -88,6 +88,10 @@ where
pub(crate) fn decrypt(&self, ciphertext: &Ciphertext) -> B32 {
let (c1, c2) = P::split_ct(ciphertext);
+ self.decrypt_split(c1, c2)
+ }
+
+ pub(crate) fn decrypt_split(&self, c1: &Ciphertext1
, c2: &Ciphertext2
) -> B32 {
let mut u: Vector = Encode::::decode(c1);
u.decompress::();
@@ -123,6 +127,11 @@ where
rho: B32,
}
+/// First ciphertext for incremental encapsulation
+pub type Ciphertext1 = array::Array>;
+/// Second ciphertext for incremental encapsulation
+pub type Ciphertext2 = array::Array>;
+
impl EncryptionKey
where
P: PkeParams,
@@ -152,6 +161,45 @@ where
P::concat_ct(c1, c2)
}
+ /// Encrypt the specified message for the holder of the corresponding decryption key, using the
+ /// provided randomness, according the `K-PKE.Encrypt` procedure.
+ pub(crate) fn encrypt_incremental_1(
+ &self,
+ randomness: &B32,
+ ) -> (Ciphertext1
, NttVector) {
+ let r = sample_poly_vec_cbd::(randomness, 0);
+ let e1 = sample_poly_vec_cbd::(randomness, P::K::U8);
+
+ let A_hat_t: NttMatrix = matrix_sample_ntt(&self.rho, true);
+ let r_hat: NttVector = r.ntt();
+ let ATr: Vector = (&A_hat_t * &r_hat).ntt_inverse();
+ let mut u = ATr + e1;
+
+ let c1 = Encode::::encode(u.compress::());
+
+ (c1, r_hat)
+ }
+
+ /// Encrypt the specified message for the holder of the corresponding decryption key, using the
+ /// provided randomness, according the `K-PKE.Encrypt` procedure.
+ pub(crate) fn encrypt_incremental_2(
+ &self,
+ message: &B32,
+ randomness: &B32,
+ r_hat: &NttVector,
+ ) -> Ciphertext2 {
+ let prf_output = PRF::(randomness, 2 * P::K::U8);
+ let e2: Polynomial = sample_poly_cbd::(&prf_output);
+
+ let mut mu: Polynomial = Encode::::decode(message);
+ mu.decompress::();
+
+ let tTr: Polynomial = (&self.t_hat * r_hat).ntt_inverse();
+ let mut v = &(&tTr + &e2) + μ
+
+ Encode::::encode(v.compress::())
+ }
+
/// Represent this encryption key as a byte array `(t_hat || rho)`
pub(crate) fn to_bytes(&self) -> EncodedEncryptionKey {
let t_hat = P::encode_u12(&self.t_hat);
diff --git a/ml-kem/tests/encap-decap.rs b/ml-kem/tests/encap-decap.rs
index 4109f27..84a3364 100644
--- a/ml-kem/tests/encap-decap.rs
+++ b/ml-kem/tests/encap-decap.rs
@@ -12,6 +12,8 @@ use std::{fs::read_to_string, path::PathBuf};
pub trait EncapsulateDeterministic {
/// Returns `(ciphertext, shared_secret)`.
fn encapsulate_deterministic(&self, m: &ArrayN) -> (Vec, Vec);
+ /// Returns `(ciphertext, shared_secret)`.
+ fn encapsulate_incremental_deterministic(&self, m: &ArrayN) -> (Vec, Vec);
}
impl EncapsulateDeterministic for EncapsulationKey512 {
@@ -19,6 +21,13 @@ impl EncapsulateDeterministic for EncapsulationKey512 {
let (c, k) = self.encapsulate_deterministic(m);
(c.to_vec(), k.to_vec())
}
+ fn encapsulate_incremental_deterministic(&self, m: &ArrayN) -> (Vec, Vec) {
+ let (c1, es, k) = self.encapsulate_incremental_1_deterministic(m);
+ let c2 = self.encapsulate_incremental_2(es);
+ let mut c = c1.to_vec();
+ c.extend_from_slice(&c2);
+ (c, k.to_vec())
+ }
}
impl EncapsulateDeterministic for EncapsulationKey768 {
@@ -26,6 +35,13 @@ impl EncapsulateDeterministic for EncapsulationKey768 {
let (c, k) = self.encapsulate_deterministic(m);
(c.to_vec(), k.to_vec())
}
+ fn encapsulate_incremental_deterministic(&self, m: &ArrayN) -> (Vec, Vec) {
+ let (c1, es, k) = self.encapsulate_incremental_1_deterministic(m);
+ let c2 = self.encapsulate_incremental_2(es);
+ let mut c = c1.to_vec();
+ c.extend_from_slice(&c2);
+ (c, k.to_vec())
+ }
}
impl EncapsulateDeterministic for EncapsulationKey1024 {
@@ -33,6 +49,40 @@ impl EncapsulateDeterministic for EncapsulationKey1024 {
let (c, k) = self.encapsulate_deterministic(m);
(c.to_vec(), k.to_vec())
}
+ fn encapsulate_incremental_deterministic(&self, m: &ArrayN) -> (Vec, Vec) {
+ let (c1, es, k) = self.encapsulate_incremental_1_deterministic(m);
+ let c2 = self.encapsulate_incremental_2(es);
+ let mut c = c1.to_vec();
+ c.extend_from_slice(&c2);
+ (c, k.to_vec())
+ }
+}
+
+/// A helper trait for deterministic incremental decapsulation tests
+pub trait DecapsulateIncremental {
+ /// Returns `shared_secret`.
+ fn decapsulate_incremental(&self, c: &C) -> Vec;
+}
+
+impl DecapsulateIncremental for DecapsulationKey512 {
+ fn decapsulate_incremental(&self, c: &ml_kem_512::Ciphertext) -> Vec {
+ let (c1, c2) = c.split_ref();
+ self.decapsulate_incremental(c1, c2).to_vec()
+ }
+}
+
+impl DecapsulateIncremental for DecapsulationKey768 {
+ fn decapsulate_incremental(&self, c: &ml_kem_768::Ciphertext) -> Vec {
+ let (c1, c2) = c.split_ref();
+ self.decapsulate_incremental(c1, c2).to_vec()
+ }
+}
+
+impl DecapsulateIncremental for DecapsulationKey1024 {
+ fn decapsulate_incremental(&self, c: &ml_kem_1024::Ciphertext) -> Vec {
+ let (c1, c2) = c.split_ref();
+ self.decapsulate_incremental(c1, c2).to_vec()
+ }
}
#[test]
@@ -76,6 +126,11 @@ where
assert_eq!(k.as_slice(), tc.k.as_slice());
assert_eq!(c.as_slice(), tc.c.as_slice());
+
+ let (c, k) = ek.encapsulate_incremental_deterministic(&m);
+
+ assert_eq!(k.as_slice(), tc.k.as_slice());
+ assert_eq!(c.as_slice(), tc.c.as_slice());
}
fn verify_decap_group(tg: &acvp::DecapTestGroup) {
@@ -92,12 +147,19 @@ fn verify_decap_group(tg: &acvp::DecapTestGroup) {
fn verify_decap(tc: &acvp::DecapTestCase, dk_slice: &[u8])
where
K: Kem,
- K::DecapsulationKey: Decapsulate + Decapsulator + ExpandedKeyEncoding,
+ K::DecapsulationKey: Decapsulate
+ + DecapsulateIncremental>
+ + Decapsulator
+ + ExpandedKeyEncoding,
{
let dk = K::DecapsulationKey::from_expanded_bytes(dk_slice.try_into().unwrap()).unwrap();
let c = Ciphertext::::try_from(tc.c.as_slice()).unwrap();
+
let k = dk.decapsulate(&c);
assert_eq!(k.as_slice(), tc.k.as_slice());
+
+ let k = dk.decapsulate_incremental(&c);
+ assert_eq!(k.as_slice(), tc.k.as_slice());
}
mod acvp {