From 2b27a56fd58972458672d4cd6f02a751f96c6e15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20Frauenschl=C3=A4ger?= Date: Tue, 14 Apr 2026 19:36:31 +0200 Subject: [PATCH 1/4] Add ML-KEM (FIPS 203) post-quantum KEM support Implement full client-server ML-KEM (Module-Lattice-Based Key Encapsulation Mechanism) support across all wolfHSM layers, enabling post-quantum key exchange operations to be offloaded to the HSM. Client API (wh_client_crypto): - Key management: import, export, set/get key ID - Key generation: MakeExportKey (ephemeral) and MakeCacheKey (server-cached) - Encapsulation and decapsulation operations - DMA variants for all operations Server handling (wh_server_crypto): - Request handlers for ML-KEM keygen, encapsulate, and decapsulate - Auto-import with evict-after-use for uncached keys - DMA request handlers Crypto callback integration (wh_client_cryptocb): - Register PQC KEM keygen/encaps/decaps handlers so wolfCrypt ML-KEM calls are transparently forwarded to the HSM via WH_DEV_ID Message layer (wh_message_crypto): - Define request/response structures for keygen, encapsulate, decapsulate - Endian translation functions for cross-platform support Shared utilities (wh_crypto): - ML-KEM key serialization/deserialization with automatic level probing Supports all three ML-KEM parameter sets (512, 768, 1024). Includes tests for all operations and DMA paths, and benchmarks for keygen, encaps, and decaps at each security level. Also fixes key export response to use actual stored key length from NVM metadata instead of the request size. --- benchmark/README.md | 2 +- benchmark/bench_modules/wh_bench_mod_all.h | 45 + benchmark/bench_modules/wh_bench_mod_mlkem.c | 343 ++++ benchmark/config/user_settings.h | 4 + benchmark/wh_bench.c | 56 + benchmark/wh_bench_ops.h | 2 +- .../posix/wh_posix_server/user_settings.h | 4 + src/wh_client_crypto.c | 928 ++++++++++ src/wh_client_cryptocb.c | 197 +++ src/wh_crypto.c | 106 ++ src/wh_message_crypto.c | 222 +++ src/wh_server_crypto.c | 1512 +++++++++++++---- src/wh_server_keystore.c | 18 +- test/config/user_settings.h | 4 + test/wh_test_crypto.c | 740 ++++++++ wolfhsm/wh_client_crypto.h | 200 +++ wolfhsm/wh_crypto.h | 11 + wolfhsm/wh_message_crypto.h | 173 ++ wolfhsm/wh_server_crypto.h | 11 + 19 files changed, 4233 insertions(+), 345 deletions(-) create mode 100644 benchmark/bench_modules/wh_bench_mod_mlkem.c diff --git a/benchmark/README.md b/benchmark/README.md index 71e7e7321..b46fb0e04 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -10,7 +10,7 @@ The wolfHSM benchmarks provide a framework for testing and measuring the perform - Hash functions (SHA-2, SHA-3) - Message Authentication Codes (HMAC, CMAC) - Public Key Cryptography (RSA, ECC, Curve25519) -- Post-Quantum Cryptography (ML-DSA) +- Post-Quantum Cryptography (ML-DSA, ML-KEM) - Basic communication (Echo) The benchmark system measures the runtime of registered operations, as well as reports the throughput in either operations per second or bytes per second depending on the algorithm. diff --git a/benchmark/bench_modules/wh_bench_mod_all.h b/benchmark/bench_modules/wh_bench_mod_all.h index 458da0104..629be885a 100644 --- a/benchmark/bench_modules/wh_bench_mod_all.h +++ b/benchmark/bench_modules/wh_bench_mod_all.h @@ -379,4 +379,49 @@ int wh_Bench_Mod_MlDsa87KeyGen(whClientContext* client, whBenchOpContext* ctx, int wh_Bench_Mod_MlDsa87KeyGenDma(whClientContext* client, whBenchOpContext* ctx, int id, void* params); +/* + * ML-KEM benchmark module prototypes (wh_bench_mod_mlkem.c) + */ +int wh_Bench_Mod_MlKem512KeyGen(whClientContext* client, whBenchOpContext* ctx, + int id, void* params); +int wh_Bench_Mod_MlKem512KeyGenDma(whClientContext* client, + whBenchOpContext* ctx, int id, void* params); +int wh_Bench_Mod_MlKem512Encaps(whClientContext* client, whBenchOpContext* ctx, + int id, void* params); +int wh_Bench_Mod_MlKem512EncapsDma(whClientContext* client, + whBenchOpContext* ctx, int id, void* params); +int wh_Bench_Mod_MlKem512Decaps(whClientContext* client, whBenchOpContext* ctx, + int id, void* params); +int wh_Bench_Mod_MlKem512DecapsDma(whClientContext* client, + whBenchOpContext* ctx, int id, void* params); + +int wh_Bench_Mod_MlKem768KeyGen(whClientContext* client, whBenchOpContext* ctx, + int id, void* params); +int wh_Bench_Mod_MlKem768KeyGenDma(whClientContext* client, + whBenchOpContext* ctx, int id, void* params); +int wh_Bench_Mod_MlKem768Encaps(whClientContext* client, whBenchOpContext* ctx, + int id, void* params); +int wh_Bench_Mod_MlKem768EncapsDma(whClientContext* client, + whBenchOpContext* ctx, int id, void* params); +int wh_Bench_Mod_MlKem768Decaps(whClientContext* client, whBenchOpContext* ctx, + int id, void* params); +int wh_Bench_Mod_MlKem768DecapsDma(whClientContext* client, + whBenchOpContext* ctx, int id, void* params); + +int wh_Bench_Mod_MlKem1024KeyGen(whClientContext* client, + whBenchOpContext* ctx, int id, void* params); +int wh_Bench_Mod_MlKem1024KeyGenDma(whClientContext* client, + whBenchOpContext* ctx, int id, + void* params); +int wh_Bench_Mod_MlKem1024Encaps(whClientContext* client, + whBenchOpContext* ctx, int id, void* params); +int wh_Bench_Mod_MlKem1024EncapsDma(whClientContext* client, + whBenchOpContext* ctx, int id, + void* params); +int wh_Bench_Mod_MlKem1024Decaps(whClientContext* client, + whBenchOpContext* ctx, int id, void* params); +int wh_Bench_Mod_MlKem1024DecapsDma(whClientContext* client, + whBenchOpContext* ctx, int id, + void* params); + #endif /* WH_BENCH_MOD_ALL_H_ */ diff --git a/benchmark/bench_modules/wh_bench_mod_mlkem.c b/benchmark/bench_modules/wh_bench_mod_mlkem.c new file mode 100644 index 000000000..5d75b6a8c --- /dev/null +++ b/benchmark/bench_modules/wh_bench_mod_mlkem.c @@ -0,0 +1,343 @@ +/* + * Copyright (C) 2026 wolfSSL Inc. + * + * This file is part of wolfHSM. + * + * wolfHSM is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * wolfHSM is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with wolfHSM. If not, see . + */ +#include + +#include "wh_bench_mod.h" +#include "wolfhsm/wh_error.h" +#include "wolfhsm/wh_client.h" +#include "wolfhsm/wh_client_crypto.h" + +#if !defined(WOLFHSM_CFG_NO_CRYPTO) && defined(WOLFHSM_CFG_BENCH_ENABLE) +#include "wolfssl/wolfcrypt/wc_mlkem.h" + +#if defined(WOLFSSL_HAVE_MLKEM) + +static int _benchMlKemKeyGen(whClientContext* client, whBenchOpContext* ctx, + int id, int securityLevel, int devId) +{ + int ret = WH_ERROR_OK; + int i; + + for (i = 0; i < WOLFHSM_CFG_BENCH_KG_ITERS && ret == WH_ERROR_OK; i++) { + MlKemKey key[1]; + int benchStartRet; + int benchStopRet; + + ret = wc_MlKemKey_Init(key, securityLevel, NULL, devId); + if (ret != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed to wc_MlKemKey_Init %d\n", ret); + break; + } + + benchStartRet = wh_Bench_StartOp(ctx, id); +#ifdef WOLFHSM_CFG_DMA + if (devId == WH_DEV_ID_DMA) { + ret = wh_Client_MlKemMakeExportKeyDma(client, securityLevel, key); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = wh_Client_MlKemMakeExportKey(client, securityLevel, key); + } + benchStopRet = wh_Bench_StopOp(ctx, id); + + if (benchStartRet != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed to wh_Bench_StartOp %d\n", benchStartRet); + ret = benchStartRet; + } + else if (ret != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed ML-KEM keygen %d\n", ret); + } + else if (benchStopRet != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed to wh_Bench_StopOp %d\n", benchStopRet); + ret = benchStopRet; + } + + wc_MlKemKey_Free(key); + } + + return ret; +} + +static int _benchMlKemEncaps(whClientContext* client, whBenchOpContext* ctx, + int id, int securityLevel, int devId) +{ + int ret = WH_ERROR_OK; + int i; + MlKemKey key[1]; + byte ct[WC_ML_KEM_MAX_CIPHER_TEXT_SIZE]; + byte ss[WC_ML_KEM_SS_SZ]; + + ret = wc_MlKemKey_Init(key, securityLevel, NULL, devId); + if (ret != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed to wc_MlKemKey_Init %d\n", ret); + return ret; + } + +#ifdef WOLFHSM_CFG_DMA + if (devId == WH_DEV_ID_DMA) { + ret = wh_Client_MlKemMakeExportKeyDma(client, securityLevel, key); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = wh_Client_MlKemMakeExportKey(client, securityLevel, key); + } + if (ret != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed ML-KEM key setup %d\n", ret); + wc_MlKemKey_Free(key); + return ret; + } + + for (i = 0; i < WOLFHSM_CFG_BENCH_PK_ITERS && ret == WH_ERROR_OK; i++) { + word32 ctLen = sizeof(ct); + word32 ssLen = sizeof(ss); + int benchStartRet; + int benchStopRet; + + memset(ct, 0, sizeof(ct)); + memset(ss, 0, sizeof(ss)); + + benchStartRet = wh_Bench_StartOp(ctx, id); +#ifdef WOLFHSM_CFG_DMA + if (devId == WH_DEV_ID_DMA) { + ret = wh_Client_MlKemEncapsulateDma(client, key, ct, &ctLen, ss, + &ssLen); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = wh_Client_MlKemEncapsulate(client, key, ct, &ctLen, ss, &ssLen); + } + benchStopRet = wh_Bench_StopOp(ctx, id); + + if (benchStartRet != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed to wh_Bench_StartOp %d\n", benchStartRet); + ret = benchStartRet; + } + else if (ret != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed ML-KEM encapsulate %d\n", ret); + } + else if (benchStopRet != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed to wh_Bench_StopOp %d\n", benchStopRet); + ret = benchStopRet; + } + } + + wc_MlKemKey_Free(key); + return ret; +} + +static int _benchMlKemDecaps(whClientContext* client, whBenchOpContext* ctx, + int id, int securityLevel, int devId) +{ + int ret = WH_ERROR_OK; + int i; + MlKemKey key[1]; + byte ct[WC_ML_KEM_MAX_CIPHER_TEXT_SIZE]; + byte ssEnc[WC_ML_KEM_SS_SZ]; + byte ssDec[WC_ML_KEM_SS_SZ]; + word32 ctLen = sizeof(ct); + word32 ssEncLen = sizeof(ssEnc); + + ret = wc_MlKemKey_Init(key, securityLevel, NULL, devId); + if (ret != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed to wc_MlKemKey_Init %d\n", ret); + return ret; + } + +#ifdef WOLFHSM_CFG_DMA + if (devId == WH_DEV_ID_DMA) { + ret = wh_Client_MlKemMakeExportKeyDma(client, securityLevel, key); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = wh_Client_MlKemMakeExportKey(client, securityLevel, key); + } + if (ret != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed ML-KEM key setup %d\n", ret); + wc_MlKemKey_Free(key); + return ret; + } + +#ifdef WOLFHSM_CFG_DMA + if (devId == WH_DEV_ID_DMA) { + ret = wh_Client_MlKemEncapsulateDma(client, key, ct, &ctLen, ssEnc, + &ssEncLen); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = wh_Client_MlKemEncapsulate(client, key, ct, &ctLen, ssEnc, + &ssEncLen); + } + if (ret != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed ML-KEM setup encapsulate %d\n", ret); + wc_MlKemKey_Free(key); + return ret; + } + + for (i = 0; i < WOLFHSM_CFG_BENCH_PK_ITERS && ret == WH_ERROR_OK; i++) { + word32 ssDecLen = sizeof(ssDec); + int benchStartRet; + int benchStopRet; + + memset(ssDec, 0, sizeof(ssDec)); + + benchStartRet = wh_Bench_StartOp(ctx, id); +#ifdef WOLFHSM_CFG_DMA + if (devId == WH_DEV_ID_DMA) { + ret = wh_Client_MlKemDecapsulateDma(client, key, ct, ctLen, ssDec, + &ssDecLen); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = wh_Client_MlKemDecapsulate(client, key, ct, ctLen, ssDec, + &ssDecLen); + } + benchStopRet = wh_Bench_StopOp(ctx, id); + + if (benchStartRet != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed to wh_Bench_StartOp %d\n", benchStartRet); + ret = benchStartRet; + } + else if (ret != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed ML-KEM decapsulate %d\n", ret); + } + else if ((ssDecLen != ssEncLen) || + (memcmp(ssDec, ssEnc, ssEncLen) != 0)) { + WH_BENCH_PRINTF("ML-KEM decapsulate mismatch\n"); + ret = WH_ERROR_ABORTED; + } + else if (benchStopRet != WH_ERROR_OK) { + WH_BENCH_PRINTF("Failed to wh_Bench_StopOp %d\n", benchStopRet); + ret = benchStopRet; + } + } + + wc_MlKemKey_Free(key); + return ret; +} + +#define WH_DEFINE_MLKEM_BENCH_NON_DMA_FNS(_Suffix, _Level) \ +int wh_Bench_Mod_MlKem##_Suffix##KeyGen(whClientContext* client, \ + whBenchOpContext* ctx, int id, \ + void* params) \ +{ \ + (void)params; \ + return _benchMlKemKeyGen(client, ctx, id, _Level, WH_DEV_ID); \ +} \ + \ +int wh_Bench_Mod_MlKem##_Suffix##Encaps(whClientContext* client, \ + whBenchOpContext* ctx, int id, \ + void* params) \ +{ \ + (void)params; \ + return _benchMlKemEncaps(client, ctx, id, _Level, WH_DEV_ID); \ +} \ + \ +int wh_Bench_Mod_MlKem##_Suffix##Decaps(whClientContext* client, \ + whBenchOpContext* ctx, int id, \ + void* params) \ +{ \ + (void)params; \ + return _benchMlKemDecaps(client, ctx, id, _Level, WH_DEV_ID); \ +} + +#ifdef WOLFHSM_CFG_DMA +#define WH_DEFINE_MLKEM_BENCH_DMA_FNS(_Suffix, _Level) \ +int wh_Bench_Mod_MlKem##_Suffix##KeyGenDma(whClientContext* client, \ + whBenchOpContext* ctx, int id, \ + void* params) \ +{ \ + (void)params; \ + return _benchMlKemKeyGen(client, ctx, id, _Level, WH_DEV_ID_DMA); \ +} \ + \ +int wh_Bench_Mod_MlKem##_Suffix##EncapsDma(whClientContext* client, \ + whBenchOpContext* ctx, int id, \ + void* params) \ +{ \ + (void)params; \ + return _benchMlKemEncaps(client, ctx, id, _Level, WH_DEV_ID_DMA); \ +} \ + \ +int wh_Bench_Mod_MlKem##_Suffix##DecapsDma(whClientContext* client, \ + whBenchOpContext* ctx, int id, \ + void* params) \ +{ \ + (void)params; \ + return _benchMlKemDecaps(client, ctx, id, _Level, WH_DEV_ID_DMA); \ +} +#else +#define WH_DEFINE_MLKEM_BENCH_DMA_FNS(_Suffix, _Level) \ +int wh_Bench_Mod_MlKem##_Suffix##KeyGenDma(whClientContext* client, \ + whBenchOpContext* ctx, int id, \ + void* params) \ +{ \ + (void)client; \ + (void)ctx; \ + (void)id; \ + (void)params; \ + (void)_Level; \ + return WH_ERROR_NOTIMPL; \ +} \ + \ +int wh_Bench_Mod_MlKem##_Suffix##EncapsDma(whClientContext* client, \ + whBenchOpContext* ctx, int id, \ + void* params) \ +{ \ + (void)client; \ + (void)ctx; \ + (void)id; \ + (void)params; \ + (void)_Level; \ + return WH_ERROR_NOTIMPL; \ +} \ + \ +int wh_Bench_Mod_MlKem##_Suffix##DecapsDma(whClientContext* client, \ + whBenchOpContext* ctx, int id, \ + void* params) \ +{ \ + (void)client; \ + (void)ctx; \ + (void)id; \ + (void)params; \ + (void)_Level; \ + return WH_ERROR_NOTIMPL; \ +} +#endif /* WOLFHSM_CFG_DMA */ + +#ifndef WOLFSSL_NO_ML_KEM_512 +WH_DEFINE_MLKEM_BENCH_NON_DMA_FNS(512, WC_ML_KEM_512) +WH_DEFINE_MLKEM_BENCH_DMA_FNS(512, WC_ML_KEM_512) +#endif +#ifndef WOLFSSL_NO_ML_KEM_768 +WH_DEFINE_MLKEM_BENCH_NON_DMA_FNS(768, WC_ML_KEM_768) +WH_DEFINE_MLKEM_BENCH_DMA_FNS(768, WC_ML_KEM_768) +#endif +#ifndef WOLFSSL_NO_ML_KEM_1024 +WH_DEFINE_MLKEM_BENCH_NON_DMA_FNS(1024, WC_ML_KEM_1024) +WH_DEFINE_MLKEM_BENCH_DMA_FNS(1024, WC_ML_KEM_1024) +#endif + +#endif /* WOLFSSL_HAVE_MLKEM */ +#endif /* !WOLFHSM_CFG_NO_CRYPTO && WOLFHSM_CFG_BENCH_ENABLE */ diff --git a/benchmark/config/user_settings.h b/benchmark/config/user_settings.h index 81a5264a7..c2e5b6794 100644 --- a/benchmark/config/user_settings.h +++ b/benchmark/config/user_settings.h @@ -155,6 +155,10 @@ extern "C" { #define WOLFSSL_DILITHIUM_NO_MAKE_KEY #endif +/* ML-KEM Options */ +#define WOLFSSL_HAVE_MLKEM +#define WOLFSSL_WC_MLKEM + /** Composite features */ #define HAVE_HKDF #define HAVE_CMAC_KDF diff --git a/benchmark/wh_bench.c b/benchmark/wh_bench.c index 15773f568..dc68f04df 100644 --- a/benchmark/wh_bench.c +++ b/benchmark/wh_bench.c @@ -254,6 +254,34 @@ typedef enum BenchModuleIdx { BENCH_MODULE_IDX_ML_DSA_87_KEY_GEN_DMA, #endif /* !(WOLFSSL_NO_ML_DSA_87) */ #endif /* HAVE_DILITHIUM */ + +/* ML-KEM */ +#if defined(WOLFSSL_HAVE_MLKEM) +#ifndef WOLFSSL_NO_ML_KEM_512 + BENCH_MODULE_IDX_ML_KEM_512_KEY_GEN, + BENCH_MODULE_IDX_ML_KEM_512_KEY_GEN_DMA, + BENCH_MODULE_IDX_ML_KEM_512_ENCAPS, + BENCH_MODULE_IDX_ML_KEM_512_ENCAPS_DMA, + BENCH_MODULE_IDX_ML_KEM_512_DECAPS, + BENCH_MODULE_IDX_ML_KEM_512_DECAPS_DMA, +#endif /* !WOLFSSL_NO_ML_KEM_512 */ +#ifndef WOLFSSL_NO_ML_KEM_768 + BENCH_MODULE_IDX_ML_KEM_768_KEY_GEN, + BENCH_MODULE_IDX_ML_KEM_768_KEY_GEN_DMA, + BENCH_MODULE_IDX_ML_KEM_768_ENCAPS, + BENCH_MODULE_IDX_ML_KEM_768_ENCAPS_DMA, + BENCH_MODULE_IDX_ML_KEM_768_DECAPS, + BENCH_MODULE_IDX_ML_KEM_768_DECAPS_DMA, +#endif /* !WOLFSSL_NO_ML_KEM_768 */ +#ifndef WOLFSSL_NO_ML_KEM_1024 + BENCH_MODULE_IDX_ML_KEM_1024_KEY_GEN, + BENCH_MODULE_IDX_ML_KEM_1024_KEY_GEN_DMA, + BENCH_MODULE_IDX_ML_KEM_1024_ENCAPS, + BENCH_MODULE_IDX_ML_KEM_1024_ENCAPS_DMA, + BENCH_MODULE_IDX_ML_KEM_1024_DECAPS, + BENCH_MODULE_IDX_ML_KEM_1024_DECAPS_DMA, +#endif /* !WOLFSSL_NO_ML_KEM_1024 */ +#endif /* WOLFSSL_HAVE_MLKEM */ #endif /* !(WOLFHSM_CFG_NO_CRYPTO) */ /* number of modules. This must be the last entry and will be used as the * size of the global modules array */ @@ -440,6 +468,34 @@ static BenchModule g_benchModules[] = { [BENCH_MODULE_IDX_ML_DSA_87_KEY_GEN_DMA] = {"ML-DSA-87-KEY-GEN-DMA", wh_Bench_Mod_MlDsa87KeyGenDma, BENCH_THROUGHPUT_OPS, 0, NULL}, #endif /* !(WOLFSSL_NO_ML_DSA_87) */ #endif /* HAVE_DILITHIUM */ + + /* ML-KEM */ +#if defined(WOLFSSL_HAVE_MLKEM) +#ifndef WOLFSSL_NO_ML_KEM_512 + [BENCH_MODULE_IDX_ML_KEM_512_KEY_GEN] = {"ML-KEM-512-KEY-GEN", wh_Bench_Mod_MlKem512KeyGen, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_512_KEY_GEN_DMA] = {"ML-KEM-512-KEY-GEN-DMA", wh_Bench_Mod_MlKem512KeyGenDma, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_512_ENCAPS] = {"ML-KEM-512-ENCAPS", wh_Bench_Mod_MlKem512Encaps, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_512_ENCAPS_DMA] = {"ML-KEM-512-ENCAPS-DMA", wh_Bench_Mod_MlKem512EncapsDma, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_512_DECAPS] = {"ML-KEM-512-DECAPS", wh_Bench_Mod_MlKem512Decaps, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_512_DECAPS_DMA] = {"ML-KEM-512-DECAPS-DMA", wh_Bench_Mod_MlKem512DecapsDma, BENCH_THROUGHPUT_OPS, 0, NULL}, +#endif /* !WOLFSSL_NO_ML_KEM_512 */ +#ifndef WOLFSSL_NO_ML_KEM_768 + [BENCH_MODULE_IDX_ML_KEM_768_KEY_GEN] = {"ML-KEM-768-KEY-GEN", wh_Bench_Mod_MlKem768KeyGen, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_768_KEY_GEN_DMA] = {"ML-KEM-768-KEY-GEN-DMA", wh_Bench_Mod_MlKem768KeyGenDma, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_768_ENCAPS] = {"ML-KEM-768-ENCAPS", wh_Bench_Mod_MlKem768Encaps, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_768_ENCAPS_DMA] = {"ML-KEM-768-ENCAPS-DMA", wh_Bench_Mod_MlKem768EncapsDma, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_768_DECAPS] = {"ML-KEM-768-DECAPS", wh_Bench_Mod_MlKem768Decaps, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_768_DECAPS_DMA] = {"ML-KEM-768-DECAPS-DMA", wh_Bench_Mod_MlKem768DecapsDma, BENCH_THROUGHPUT_OPS, 0, NULL}, +#endif /* !WOLFSSL_NO_ML_KEM_768 */ +#ifndef WOLFSSL_NO_ML_KEM_1024 + [BENCH_MODULE_IDX_ML_KEM_1024_KEY_GEN] = {"ML-KEM-1024-KEY-GEN", wh_Bench_Mod_MlKem1024KeyGen, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_1024_KEY_GEN_DMA] = {"ML-KEM-1024-KEY-GEN-DMA", wh_Bench_Mod_MlKem1024KeyGenDma, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_1024_ENCAPS] = {"ML-KEM-1024-ENCAPS", wh_Bench_Mod_MlKem1024Encaps, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_1024_ENCAPS_DMA] = {"ML-KEM-1024-ENCAPS-DMA", wh_Bench_Mod_MlKem1024EncapsDma, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_1024_DECAPS] = {"ML-KEM-1024-DECAPS", wh_Bench_Mod_MlKem1024Decaps, BENCH_THROUGHPUT_OPS, 0, NULL}, + [BENCH_MODULE_IDX_ML_KEM_1024_DECAPS_DMA] = {"ML-KEM-1024-DECAPS-DMA", wh_Bench_Mod_MlKem1024DecapsDma, BENCH_THROUGHPUT_OPS, 0, NULL}, +#endif /* !WOLFSSL_NO_ML_KEM_1024 */ +#endif /* WOLFSSL_HAVE_MLKEM */ #endif /* !(WOLFHSM_CFG_NO_CRYPTO) */ }; /* clang-format on */ diff --git a/benchmark/wh_bench_ops.h b/benchmark/wh_bench_ops.h index 30329664f..d2d0d3aad 100644 --- a/benchmark/wh_bench_ops.h +++ b/benchmark/wh_bench_ops.h @@ -26,7 +26,7 @@ #include /* Maximum number of operations that can be registered */ -#define MAX_BENCH_OPS 101 +#define MAX_BENCH_OPS 119 /* Maximum length of operation name */ #define MAX_OP_NAME 64 diff --git a/examples/posix/wh_posix_server/user_settings.h b/examples/posix/wh_posix_server/user_settings.h index 99e120ca3..8bdd5bc93 100644 --- a/examples/posix/wh_posix_server/user_settings.h +++ b/examples/posix/wh_posix_server/user_settings.h @@ -151,6 +151,10 @@ extern "C" { #define WOLFSSL_DILITHIUM_NO_MAKE_KEY #endif +/* ML-KEM Options */ +#define WOLFSSL_HAVE_MLKEM +#define WOLFSSL_WC_MLKEM + /** Composite features */ #define HAVE_HKDF #define HAVE_CMAC_KDF diff --git a/src/wh_client_crypto.c b/src/wh_client_crypto.c index 96ec5747e..dd6b1f250 100644 --- a/src/wh_client_crypto.c +++ b/src/wh_client_crypto.c @@ -55,6 +55,7 @@ #include "wolfssl/wolfcrypt/curve25519.h" #include "wolfssl/wolfcrypt/ed25519.h" #include "wolfssl/wolfcrypt/dilithium.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfssl/wolfcrypt/sha256.h" #include "wolfssl/wolfcrypt/sha512.h" #endif @@ -127,6 +128,17 @@ static int _MlDsaMakeKeyDma(whClientContext* ctx, int level, #endif /* WOLFHSM_CFG_DMA */ #endif /* HAVE_DILITHIUM */ +#ifdef WOLFSSL_HAVE_MLKEM +static int _MlKemMakeKey(whClientContext* ctx, int level, + whKeyId* inout_key_id, whNvmFlags flags, + uint16_t label_len, uint8_t* label, MlKemKey* key); +#ifdef WOLFHSM_CFG_DMA +static int _MlKemMakeKeyDma(whClientContext* ctx, int level, + whKeyId* inout_key_id, whNvmFlags flags, + uint16_t label_len, uint8_t* label, MlKemKey* key); +#endif /* WOLFHSM_CFG_DMA */ +#endif /* WOLFSSL_HAVE_MLKEM */ + static uint8_t* _createCryptoRequest(uint8_t* reqBuf, uint16_t type, uint32_t affinity); static uint8_t* _createCryptoRequestWithSubtype(uint8_t* reqBuf, uint16_t type, @@ -7970,4 +7982,920 @@ int wh_Client_MlDsaCheckPrivKeyDma(whClientContext* ctx, MlDsaKey* key, #endif /* WOLFHSM_CFG_DMA */ #endif /* HAVE_DILITHIUM */ +#ifdef WOLFSSL_HAVE_MLKEM + +int wh_Client_MlKemSetKeyId(MlKemKey* key, whKeyId keyId) +{ + if (key == NULL) { + return WH_ERROR_BADARGS; + } + + key->devCtx = WH_KEYID_TO_DEVCTX(keyId); + return WH_ERROR_OK; +} + +int wh_Client_MlKemGetKeyId(MlKemKey* key, whKeyId* outId) +{ + if ((key == NULL) || (outId == NULL)) { + return WH_ERROR_BADARGS; + } + + *outId = WH_DEVCTX_TO_KEYID(key->devCtx); + return WH_ERROR_OK; +} + +int wh_Client_MlKemImportKey(whClientContext* ctx, MlKemKey* key, + whKeyId* inout_keyId, whNvmFlags flags, + uint16_t label_len, uint8_t* label) +{ + int ret = WH_ERROR_OK; + whKeyId key_id = WH_KEYID_ERASED; + byte* buffer = NULL; + uint16_t buffer_len = 0; + word32 allocSz = 0; + + if ((ctx == NULL) || (key == NULL) || + ((label_len != 0) && (label == NULL))) { + return WH_ERROR_BADARGS; + } + + /* Use exact key size based on level to avoid over-allocation */ + ret = wc_MlKemKey_PrivateKeySize(key, &allocSz); + if (ret != 0) { + /* Fall back to public key size if no private key */ + ret = wc_MlKemKey_PublicKeySize(key, &allocSz); + } + if (ret != 0 || allocSz == 0) { + return WH_ERROR_BADARGS; + } + + buffer = (byte*)XMALLOC(allocSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); + if (buffer == NULL) { + return WH_ERROR_ABORTED; + } + + if (inout_keyId != NULL) { + key_id = *inout_keyId; + } + + ret = wh_Crypto_MlKemSerializeKey(key, (uint16_t)allocSz, buffer, + &buffer_len); + WH_DEBUG_CLIENT_VERBOSE("MlKemImportKey: serialize ret:%d, len:%u\n", + ret, (unsigned int)buffer_len); + if (ret == WH_ERROR_OK) { + ret = wh_Client_KeyCache(ctx, flags, label, label_len, buffer, + buffer_len, &key_id); + if ((ret == WH_ERROR_OK) && (inout_keyId != NULL)) { + *inout_keyId = key_id; + } + } + WH_DEBUG_CLIENT_VERBOSE("MlKemImportKey: ret:%d keyId:%u\n", ret, key_id); + + wc_ForceZero(buffer, allocSz); + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + return ret; +} + +int wh_Client_MlKemExportKey(whClientContext* ctx, whKeyId keyId, MlKemKey* key, + uint16_t label_len, uint8_t* label) +{ + int ret = WH_ERROR_OK; + byte* buffer = NULL; + uint16_t buffer_len = WC_ML_KEM_MAX_PRIVATE_KEY_SIZE; + + if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) { + return WH_ERROR_BADARGS; + } + + buffer = (byte*)XMALLOC(WC_ML_KEM_MAX_PRIVATE_KEY_SIZE, NULL, + DYNAMIC_TYPE_TMP_BUFFER); + if (buffer == NULL) { + return WH_ERROR_ABORTED; + } + + ret = + wh_Client_KeyExport(ctx, keyId, label, label_len, buffer, &buffer_len); + WH_DEBUG_CLIENT_VERBOSE("MlKemExportKey: export ret:%d, len:%u\n", + ret, (unsigned int)buffer_len); + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_MlKemDeserializeKey(buffer, buffer_len, key); + } + WH_DEBUG_CLIENT_VERBOSE("MlKemExportKey: keyId:%x ret:%d\n", keyId, ret); + + wc_ForceZero(buffer, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + return ret; +} + +static int _MlKemMakeKey(whClientContext* ctx, int level, + whKeyId* inout_key_id, whNvmFlags flags, + uint16_t label_len, uint8_t* label, MlKemKey* key) +{ + int ret = WH_ERROR_OK; + whKeyId key_id = WH_KEYID_ERASED; + uint8_t* dataPtr = NULL; + whMessageCrypto_MlKemKeyGenRequest* req = NULL; + whMessageCrypto_MlKemKeyGenResponse* res = NULL; + uint16_t group = WH_MESSAGE_GROUP_CRYPTO; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len; + uint16_t res_len; + + if (ctx == NULL) { + return WH_ERROR_BADARGS; + } + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_MlKemKeyGenRequest*)_createCryptoRequestWithSubtype( + dataPtr, WC_PK_TYPE_PQC_KEM_KEYGEN, WC_PQC_KEM_TYPE_KYBER, + ctx->cryptoAffinity); + + if (inout_key_id != NULL) { + key_id = *inout_key_id; + } + + req_len = sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + /* Defense in depth: ensure request fits in comm buffer */ + if (req_len > WOLFHSM_CFG_COMM_DATA_LEN) { + return WH_ERROR_BADARGS; + } + + memset(req, 0, sizeof(*req)); + req->level = level; + req->flags = flags; + req->keyId = key_id; + req->access = WH_NVM_ACCESS_ANY; + if ((label != NULL) && (label_len > 0)) { + if (label_len > WH_NVM_LABEL_LEN) { + label_len = WH_NVM_LABEL_LEN; + } + memcpy(req->label, label, label_len); + } + + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + WH_DEBUG_CLIENT_VERBOSE("MlKemMakeKey: Req sent:level:%d, ret:%d\n", + level, ret); + if (ret != WH_ERROR_OK) { + return ret; + } + + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &res_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + if (ret != WH_ERROR_OK) { + return ret; + } + + ret = _getCryptoResponse(dataPtr, WC_PK_TYPE_PQC_KEM_KEYGEN, + (uint8_t**)&res); + if (ret >= 0) { + key_id = (whKeyId)res->keyId; + WH_DEBUG_CLIENT_VERBOSE("MlKemMakeKey: Res recv:" + "keyId:%u, len:%u, ret:%d\n", + (unsigned int)res->keyId, + (unsigned int)res->len, ret); + if (inout_key_id != NULL) { + *inout_key_id = key_id; + } + if (key != NULL) { + wh_Client_MlKemSetKeyId(key, key_id); + if ((flags & WH_NVM_FLAGS_EPHEMERAL) != 0) { + uint8_t* key_raw = (uint8_t*)(res + 1); + word32 max_resp = WOLFHSM_CFG_COMM_DATA_LEN - + (word32)((uint8_t*)key_raw - dataPtr); + if (res->len > max_resp) { + ret = WH_ERROR_BADARGS; + } + else { + ret = wh_Crypto_MlKemDeserializeKey( + key_raw, (uint16_t)res->len, key); + } + } + } + } + return ret; +} + +int wh_Client_MlKemMakeCacheKey(whClientContext* ctx, int level, + whKeyId* inout_key_id, whNvmFlags flags, + uint16_t label_len, uint8_t* label) +{ + if (inout_key_id == NULL) { + return WH_ERROR_BADARGS; + } + + return _MlKemMakeKey(ctx, level, inout_key_id, flags, label_len, + label, NULL); +} + +int wh_Client_MlKemMakeExportKey(whClientContext* ctx, int level, + MlKemKey* key) +{ + if (key == NULL) { + return WH_ERROR_BADARGS; + } + + return _MlKemMakeKey(ctx, level, NULL, WH_NVM_FLAGS_EPHEMERAL, 0, + NULL, key); +} + +int wh_Client_MlKemEncapsulate(whClientContext* ctx, MlKemKey* key, + byte* ct, word32* inout_ct_len, byte* ss, + word32* inout_ss_len) +{ + int ret = WH_ERROR_OK; + uint8_t* dataPtr = NULL; + whMessageCrypto_MlKemEncapsRequest* req = NULL; + whMessageCrypto_MlKemEncapsResponse* res = NULL; + + whKeyId key_id; + int evict = 0; + + if ((ctx == NULL) || (key == NULL) || (ct == NULL) || (ss == NULL) || + (inout_ct_len == NULL) || (inout_ss_len == NULL)) { + return WH_ERROR_BADARGS; + } + + if ((*inout_ct_len == 0) || (*inout_ss_len == 0)) { + return WH_ERROR_BADARGS; + } + + key_id = WH_DEVCTX_TO_KEYID(key->devCtx); + if (WH_KEYID_ISERASED(key_id)) { + uint8_t keyLabel[] = "TempMlKemEncaps"; + whNvmFlags flags = WH_NVM_FLAGS_USAGE_DERIVE; + ret = wh_Client_MlKemImportKey(ctx, key, &key_id, flags, + sizeof(keyLabel), keyLabel); + if (ret == WH_ERROR_OK) { + evict = 1; + } + } + + if (ret == WH_ERROR_OK) { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + uint32_t options = 0; + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_MlKemEncapsRequest*) + _createCryptoRequestWithSubtype(dataPtr, WC_PK_TYPE_PQC_KEM_ENCAPS, + WC_PQC_KEM_TYPE_KYBER, + ctx->cryptoAffinity); + + if (req_len <= WOLFHSM_CFG_COMM_DATA_LEN) { + if (evict != 0) { + options |= WH_MESSAGE_CRYPTO_MLKEM_ENCAPS_OPTIONS_EVICT; + } + + memset(req, 0, sizeof(*req)); + req->options = options; + req->level = key->type; + req->keyId = key_id; + + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + WH_DEBUG_CLIENT_VERBOSE("MlKemEncapsulate: Req sent:keyId:%u, " + "level:%u, ret:%d\n", + (unsigned int)key_id, + (unsigned int)key->type, ret); + if (ret == WH_ERROR_OK) { + evict = 0; + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, WC_PK_TYPE_PQC_KEM_ENCAPS, + (uint8_t**)&res); + if (ret >= 0) { + uint8_t* resp_data = (uint8_t*)(res + 1); + word32 out_ct_len = res->ctSz; + word32 out_ss_len = res->ssSz; + word32 max_resp = WOLFHSM_CFG_COMM_DATA_LEN - + (word32)((uint8_t*)resp_data - dataPtr); + WH_DEBUG_CLIENT_VERBOSE("MlKemEncapsulate: Res recv:" + "ctSz:%u, ssSz:%u, ret:%d\n", + (unsigned int)out_ct_len, + (unsigned int)out_ss_len, ret); + if (out_ct_len + out_ss_len > max_resp || + *inout_ct_len < out_ct_len || + *inout_ss_len < out_ss_len) { + ret = WH_ERROR_BADARGS; + } + else { + memcpy(ct, resp_data, out_ct_len); + memcpy(ss, resp_data + out_ct_len, out_ss_len); + *inout_ct_len = out_ct_len; + *inout_ss_len = out_ss_len; + } + } + } + } + else { + ret = WH_ERROR_BADARGS; + } + } + + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } + + if (ret != WH_ERROR_OK) { + wc_ForceZero(ss, *inout_ss_len); + } + + return ret; +} + +int wh_Client_MlKemDecapsulate(whClientContext* ctx, MlKemKey* key, + const byte* ct, word32 ct_len, byte* ss, + word32* inout_ss_len) +{ + int ret = WH_ERROR_OK; + uint8_t* dataPtr = NULL; + whMessageCrypto_MlKemDecapsRequest* req = NULL; + whMessageCrypto_MlKemDecapsResponse* res = NULL; + + whKeyId key_id; + int evict = 0; + + if ((ctx == NULL) || (key == NULL) || ((ct == NULL) && (ct_len > 0)) || + (ss == NULL) || (inout_ss_len == NULL)) { + return WH_ERROR_BADARGS; + } + + key_id = WH_DEVCTX_TO_KEYID(key->devCtx); + if (WH_KEYID_ISERASED(key_id)) { + uint8_t keyLabel[] = "TempMlKemDecaps"; + whNvmFlags flags = WH_NVM_FLAGS_USAGE_DERIVE; + ret = wh_Client_MlKemImportKey(ctx, key, &key_id, flags, + sizeof(keyLabel), keyLabel); + if (ret == WH_ERROR_OK) { + evict = 1; + } + } + + if (ret == WH_ERROR_OK) { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO; + uint16_t action = WC_ALGO_TYPE_PK; + uint32_t options = 0; + uint64_t total_len = (uint64_t)sizeof(whMessageCrypto_GenericRequestHeader) + + sizeof(*req) + ct_len; + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_MlKemDecapsRequest*) + _createCryptoRequestWithSubtype(dataPtr, WC_PK_TYPE_PQC_KEM_DECAPS, + WC_PQC_KEM_TYPE_KYBER, + ctx->cryptoAffinity); + + if (total_len <= WOLFHSM_CFG_COMM_DATA_LEN) { + uint8_t* req_ct = (uint8_t*)(req + 1); + uint16_t req_len = (uint16_t)total_len; + + if (evict != 0) { + options |= WH_MESSAGE_CRYPTO_MLKEM_DECAPS_OPTIONS_EVICT; + } + + memset(req, 0, sizeof(*req)); + req->options = options; + req->level = key->type; + req->keyId = key_id; + req->ctSz = ct_len; + if ((ct != NULL) && (ct_len > 0)) { + memcpy(req_ct, ct, ct_len); + } + + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + WH_DEBUG_CLIENT_VERBOSE("MlKemDecapsulate: Req sent:keyId:%u, " + "ctSz:%u, ret:%d\n", + (unsigned int)key_id, + (unsigned int)ct_len, ret); + if (ret == WH_ERROR_OK) { + evict = 0; + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, WC_PK_TYPE_PQC_KEM_DECAPS, + (uint8_t**)&res); + if (ret >= 0) { + uint8_t* resp_ss = (uint8_t*)(res + 1); + word32 out_ss_len = res->ssSz; + word32 max_resp = WOLFHSM_CFG_COMM_DATA_LEN - + (word32)((uint8_t*)resp_ss - dataPtr); + WH_DEBUG_CLIENT_VERBOSE("MlKemDecapsulate: Res recv:" + "ssSz:%u, ret:%d\n", + (unsigned int)out_ss_len, ret); + if (out_ss_len > max_resp || + *inout_ss_len < out_ss_len) { + ret = WH_ERROR_BADARGS; + } + else { + memcpy(ss, resp_ss, out_ss_len); + *inout_ss_len = out_ss_len; + } + } + } + } + else { + ret = WH_ERROR_BADARGS; + } + } + + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } + + if (ret != WH_ERROR_OK) { + wc_ForceZero(ss, *inout_ss_len); + } + + return ret; +} + +#ifdef WOLFHSM_CFG_DMA +int wh_Client_MlKemImportKeyDma(whClientContext* ctx, MlKemKey* key, + whKeyId* inout_keyId, whNvmFlags flags, + uint16_t label_len, uint8_t* label) +{ + int ret = WH_ERROR_OK; + whKeyId key_id = WH_KEYID_ERASED; + byte* buffer = NULL; + uint16_t buffer_len = 0; + word32 allocSz = 0; + + if ((ctx == NULL) || (key == NULL) || + ((label_len != 0) && (label == NULL))) { + return WH_ERROR_BADARGS; + } + + /* Use exact key size based on level to avoid over-allocation */ + ret = wc_MlKemKey_PrivateKeySize(key, &allocSz); + if (ret != 0) { + ret = wc_MlKemKey_PublicKeySize(key, &allocSz); + } + if (ret != 0 || allocSz == 0) { + return WH_ERROR_BADARGS; + } + + buffer = (byte*)XMALLOC(allocSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); + if (buffer == NULL) { + return WH_ERROR_ABORTED; + } + + if (inout_keyId != NULL) { + key_id = *inout_keyId; + } + + ret = wh_Crypto_MlKemSerializeKey(key, (uint16_t)allocSz, buffer, + &buffer_len); + WH_DEBUG_CLIENT_VERBOSE("MlKemImportKeyDma: serialize ret:%d, len:%u\n", + ret, (unsigned int)buffer_len); + if (ret == WH_ERROR_OK) { + ret = wh_Client_KeyCacheDma(ctx, flags, label, label_len, buffer, + buffer_len, &key_id); + if ((ret == WH_ERROR_OK) && (inout_keyId != NULL)) { + *inout_keyId = key_id; + } + } + WH_DEBUG_CLIENT_VERBOSE("MlKemImportKeyDma: ret:%d keyId:%u\n", + ret, key_id); + + wc_ForceZero(buffer, allocSz); + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + return ret; +} + +int wh_Client_MlKemExportKeyDma(whClientContext* ctx, whKeyId keyId, + MlKemKey* key, uint16_t label_len, + uint8_t* label) +{ + int ret = WH_ERROR_OK; + byte* buffer = NULL; + uint16_t buffer_len = WC_ML_KEM_MAX_PRIVATE_KEY_SIZE; + + if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) { + return WH_ERROR_BADARGS; + } + + buffer = (byte*)XMALLOC(WC_ML_KEM_MAX_PRIVATE_KEY_SIZE, NULL, + DYNAMIC_TYPE_TMP_BUFFER); + if (buffer == NULL) { + return WH_ERROR_ABORTED; + } + memset(buffer, 0, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); + + ret = wh_Client_KeyExportDma(ctx, keyId, buffer, buffer_len, label, + label_len, &buffer_len); + WH_DEBUG_CLIENT_VERBOSE("MlKemExportKeyDma: export ret:%d, len:%u\n", + ret, (unsigned int)buffer_len); + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_MlKemDeserializeKey(buffer, buffer_len, key); + } + WH_DEBUG_CLIENT_VERBOSE("MlKemExportKeyDma: keyId:%x ret:%d\n", + keyId, ret); + + wc_ForceZero(buffer, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + return ret; +} + +static int _MlKemMakeKeyDma(whClientContext* ctx, int level, + whKeyId* inout_key_id, whNvmFlags flags, + uint16_t label_len, uint8_t* label, MlKemKey* key) +{ + int ret = WH_ERROR_OK; + whKeyId key_id = WH_KEYID_ERASED; + byte* buffer = NULL; + uint8_t* dataPtr = NULL; + whMessageCrypto_MlKemKeyGenDmaRequest* req = NULL; + whMessageCrypto_MlKemKeyGenDmaResponse* res = NULL; + uintptr_t keyAddr = 0; + uint32_t allocSz = 0; + + if ((ctx == NULL) || (key == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wc_MlKemKey_PrivateKeySize(key, &allocSz); + if (ret != 0) { + return WH_ERROR_BADARGS; + } + else { + buffer = (byte*)XMALLOC(allocSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); + if (buffer == NULL) { + return WH_ERROR_ABORTED; + } + } + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_MlKemKeyGenDmaRequest*)_createCryptoRequestWithSubtype( + dataPtr, WC_PK_TYPE_PQC_KEM_KEYGEN, WC_PQC_KEM_TYPE_KYBER, + ctx->cryptoAffinity); + + if (inout_key_id != NULL) { + key_id = *inout_key_id; + } + + uint16_t group = WH_MESSAGE_GROUP_CRYPTO_DMA; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + if (req_len > WOLFHSM_CFG_COMM_DATA_LEN) { + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + return WH_ERROR_BADARGS; + } + + memset(req, 0, sizeof(*req)); + req->level = level; + req->flags = flags; + req->keyId = key_id; + req->access = WH_NVM_ACCESS_ANY; + req->key.sz = allocSz; + + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)buffer, (void**)&keyAddr, allocSz, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whDmaFlags){0}); + if (ret == WH_ERROR_OK) { + req->key.addr = (uint64_t)(uintptr_t)keyAddr; + } + + if ((label != NULL) && (label_len > 0)) { + if (label_len > WH_NVM_LABEL_LEN) { + label_len = WH_NVM_LABEL_LEN; + } + memcpy(req->label, label, label_len); + req->labelSize = label_len; + } + + if (ret == WH_ERROR_OK) { + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + } + if (ret == WH_ERROR_OK) { + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)buffer, (void**)&keyAddr, allocSz, + WH_DMA_OPER_CLIENT_WRITE_POST, (whDmaFlags){0}); + + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, WC_PK_TYPE_PQC_KEM_KEYGEN, + (uint8_t**)&res); + if (ret >= 0) { + key_id = (whKeyId)res->keyId; + if (inout_key_id != NULL) { + *inout_key_id = key_id; + } + if (key != NULL) { + wh_Client_MlKemSetKeyId(key, key_id); + if ((flags & WH_NVM_FLAGS_EPHEMERAL) != 0) { + if (res->keySize > allocSz) { + ret = WH_ERROR_BADARGS; + } + else { + ret = wh_Crypto_MlKemDeserializeKey( + buffer, (uint16_t)res->keySize, key); + } + } + } + } + } + + wc_ForceZero(buffer, allocSz); + XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + return ret; +} + +int wh_Client_MlKemMakeExportKeyDma(whClientContext* ctx, int level, + MlKemKey* key) +{ + if (key == NULL) { + return WH_ERROR_BADARGS; + } + + return _MlKemMakeKeyDma(ctx, level, NULL, WH_NVM_FLAGS_EPHEMERAL, 0, NULL, + key); +} + +int wh_Client_MlKemEncapsulateDma(whClientContext* ctx, MlKemKey* key, + byte* ct, word32* inout_ct_len, byte* ss, + word32* inout_ss_len) +{ + int ret = WH_ERROR_OK; + uint8_t* dataPtr = NULL; + whMessageCrypto_MlKemEncapsDmaRequest* req = NULL; + whMessageCrypto_MlKemEncapsDmaResponse* res = NULL; + uintptr_t ctAddr = 0; + whKeyId key_id; + int evict = 0; + uint32_t options = 0; + word32 origCtSz; + + if ((ctx == NULL) || (key == NULL) || (ct == NULL) || (ss == NULL) || + (inout_ct_len == NULL) || (inout_ss_len == NULL)) { + return WH_ERROR_BADARGS; + } + + origCtSz = *inout_ct_len; + + key_id = WH_DEVCTX_TO_KEYID(key->devCtx); + if (WH_KEYID_ISERASED(key_id)) { + uint8_t keyLabel[] = "TempMlKemEncaps"; + whNvmFlags flags = WH_NVM_FLAGS_USAGE_DERIVE; + ret = wh_Client_MlKemImportKeyDma(ctx, key, &key_id, flags, + sizeof(keyLabel), keyLabel); + if (ret == WH_ERROR_OK) { + evict = 1; + } + } + + if (ret == WH_ERROR_OK) { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO_DMA; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_MlKemEncapsDmaRequest*) + _createCryptoRequestWithSubtype(dataPtr, WC_PK_TYPE_PQC_KEM_ENCAPS, + WC_PQC_KEM_TYPE_KYBER, + ctx->cryptoAffinity); + + if (req_len <= WOLFHSM_CFG_COMM_DATA_LEN) { + if (evict != 0) { + options |= WH_MESSAGE_CRYPTO_MLKEM_ENCAPS_OPTIONS_EVICT; + } + + memset(req, 0, sizeof(*req)); + req->options = options; + req->level = key->type; + req->keyId = key_id; + + req->ct.sz = origCtSz; + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)ct, (void**)&ctAddr, req->ct.sz, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whDmaFlags){0}); + if (ret == WH_ERROR_OK) { + req->ct.addr = ctAddr; + } + + if (ret == WH_ERROR_OK) { + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + } + if (ret == WH_ERROR_OK) { + evict = 0; + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, WC_PK_TYPE_PQC_KEM_ENCAPS, + (uint8_t**)&res); + if (ret >= 0) { + /* ct was transferred via DMA, ss is inline in response */ + uint8_t* resp_ss = (uint8_t*)(res + 1); + word32 max_resp = WOLFHSM_CFG_COMM_DATA_LEN - + (word32)((uint8_t*)resp_ss - dataPtr); + if (res->ctLen > origCtSz || + res->ssLen > max_resp || + res->ssLen > *inout_ss_len) { + ret = WH_ERROR_BADARGS; + } + else { + memcpy(ss, resp_ss, res->ssLen); + *inout_ct_len = res->ctLen; + *inout_ss_len = res->ssLen; + } + } + } + + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)ct, (void**)&ctAddr, origCtSz, + WH_DMA_OPER_CLIENT_WRITE_POST, (whDmaFlags){0}); + } + else { + ret = WH_ERROR_BADARGS; + } + } + + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } + + if (ret != WH_ERROR_OK) { + wc_ForceZero(ss, *inout_ss_len); + } + + return ret; +} + +int wh_Client_MlKemDecapsulateDma(whClientContext* ctx, MlKemKey* key, + const byte* ct, word32 ct_len, byte* ss, + word32* inout_ss_len) +{ + int ret = WH_ERROR_OK; + uint8_t* dataPtr = NULL; + whMessageCrypto_MlKemDecapsDmaRequest* req = NULL; + whMessageCrypto_MlKemDecapsDmaResponse* res = NULL; + uintptr_t ctAddr = 0; + whKeyId key_id; + int evict = 0; + uint32_t options = 0; + + if ((ctx == NULL) || (key == NULL) || ((ct == NULL) && (ct_len > 0)) || + (ss == NULL) || (inout_ss_len == NULL)) { + return WH_ERROR_BADARGS; + } + + key_id = WH_DEVCTX_TO_KEYID(key->devCtx); + if (WH_KEYID_ISERASED(key_id)) { + uint8_t keyLabel[] = "TempMlKemDecaps"; + whNvmFlags flags = WH_NVM_FLAGS_USAGE_DERIVE; + ret = wh_Client_MlKemImportKeyDma(ctx, key, &key_id, flags, + sizeof(keyLabel), keyLabel); + if (ret == WH_ERROR_OK) { + evict = 1; + } + } + + if (ret == WH_ERROR_OK) { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO_DMA; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_MlKemDecapsDmaRequest*) + _createCryptoRequestWithSubtype(dataPtr, WC_PK_TYPE_PQC_KEM_DECAPS, + WC_PQC_KEM_TYPE_KYBER, + ctx->cryptoAffinity); + + if (req_len <= WOLFHSM_CFG_COMM_DATA_LEN) { + if (evict != 0) { + options |= WH_MESSAGE_CRYPTO_MLKEM_DECAPS_OPTIONS_EVICT; + } + + memset(req, 0, sizeof(*req)); + req->options = options; + req->level = key->type; + req->keyId = key_id; + + req->ct.sz = ct_len; + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)ct, (void**)&ctAddr, req->ct.sz, + WH_DMA_OPER_CLIENT_READ_PRE, (whDmaFlags){0}); + if (ret == WH_ERROR_OK) { + req->ct.addr = ctAddr; + } + + if (ret == WH_ERROR_OK) { + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + } + if (ret == WH_ERROR_OK) { + evict = 0; + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, WC_PK_TYPE_PQC_KEM_DECAPS, + (uint8_t**)&res); + if (ret >= 0) { + /* ss is inline in response, not via DMA */ + uint8_t* resp_ss = (uint8_t*)(res + 1); + word32 max_resp = WOLFHSM_CFG_COMM_DATA_LEN - + (word32)((uint8_t*)resp_ss - dataPtr); + if (res->ssLen > max_resp || + res->ssLen > *inout_ss_len) { + ret = WH_ERROR_BADARGS; + } + else { + memcpy(ss, resp_ss, res->ssLen); + *inout_ss_len = res->ssLen; + } + } + } + + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)ct, (void**)&ctAddr, ct_len, + WH_DMA_OPER_CLIENT_READ_POST, (whDmaFlags){0}); + } + else { + ret = WH_ERROR_BADARGS; + } + } + + if (evict != 0) { + (void)wh_Client_KeyEvict(ctx, key_id); + } + + if (ret != WH_ERROR_OK) { + wc_ForceZero(ss, *inout_ss_len); + } + + return ret; +} +#endif /* WOLFHSM_CFG_DMA */ +#endif /* WOLFSSL_HAVE_MLKEM */ + #endif /* !WOLFHSM_CFG_NO_CRYPTO && WOLFHSM_CFG_ENABLE_CLIENT */ diff --git a/src/wh_client_cryptocb.c b/src/wh_client_cryptocb.c index 48c620d68..afb270c2a 100644 --- a/src/wh_client_cryptocb.c +++ b/src/wh_client_cryptocb.c @@ -47,6 +47,7 @@ #include "wolfssl/wolfcrypt/ecc.h" #include "wolfssl/wolfcrypt/sha256.h" #include "wolfssl/wolfcrypt/sha512.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfhsm/wh_crypto.h" #include "wolfhsm/wh_client_crypto.h" @@ -54,6 +55,15 @@ #include "wolfhsm/wh_message_crypto.h" +#if defined(WOLFSSL_HAVE_MLKEM) +static int _handlePqcKemKeyGen(whClientContext* ctx, wc_CryptoInfo* info, + int useDma); +static int _handlePqcEncaps(whClientContext* ctx, wc_CryptoInfo* info, + int useDma); +static int _handlePqcDecaps(whClientContext* ctx, wc_CryptoInfo* info, + int useDma); +#endif /* WOLFSSL_HAVE_MLKEM */ + #if defined(HAVE_DILITHIUM) || defined(HAVE_FALCON) static int _handlePqcSigKeyGen(whClientContext* ctx, wc_CryptoInfo* info, int useDma); @@ -422,6 +432,21 @@ int wh_Client_CryptoCb(int devId, wc_CryptoInfo* info, void* inCtx) #endif /* HAVE_ED25519 */ #endif /* HAVE_CURVE25519 */ +#if defined(WOLFSSL_HAVE_MLKEM) + case WC_PK_TYPE_PQC_KEM_KEYGEN: + ret = _handlePqcKemKeyGen(ctx, info, 0); + break; + + case WC_PK_TYPE_PQC_KEM_ENCAPS: + ret = _handlePqcEncaps(ctx, info, 0); + break; + + case WC_PK_TYPE_PQC_KEM_DECAPS: + ret = _handlePqcDecaps(ctx, info, 0); + break; + +#endif /* WOLFSSL_HAVE_MLKEM */ + #if defined(HAVE_DILITHIUM) || defined(HAVE_FALCON) case WC_PK_TYPE_PQC_SIG_KEYGEN: ret = _handlePqcSigKeyGen(ctx, info, 0); @@ -600,6 +625,167 @@ int wh_Client_CryptoCb(int devId, wc_CryptoInfo* info, void* inCtx) return ret; } +#if defined(WOLFSSL_HAVE_MLKEM) +static int _handlePqcKemKeyGen(whClientContext* ctx, wc_CryptoInfo* info, + int useDma) +{ + int ret = CRYPTOCB_UNAVAILABLE; + + /* Extract info parameters */ + int size = info->pk.pqc_kem_kg.size; + void* key = info->pk.pqc_kem_kg.key; + int type = info->pk.pqc_kem_kg.type; + +#ifndef WOLFHSM_CFG_DMA + if (useDma) { + /* TODO: proper error code? */ + return WC_HW_E; + } +#endif + + (void)size; + + switch (type) { + case WC_PQC_KEM_TYPE_KYBER: { + int level = ((MlKemKey*)key)->type; +#ifdef WOLFHSM_CFG_DMA + if (useDma) { + ret = wh_Client_MlKemMakeExportKeyDma(ctx, level, key); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = wh_Client_MlKemMakeExportKey(ctx, level, key); + } + } break; + + default: + ret = CRYPTOCB_UNAVAILABLE; + break; + } + + if (ret == WH_ERROR_BADARGS) { + ret = BAD_FUNC_ARG; + } + else if (ret == WH_ERROR_NOTIMPL) { + ret = CRYPTOCB_UNAVAILABLE; + } + + return ret; +} + +static int _handlePqcEncaps(whClientContext* ctx, wc_CryptoInfo* info, + int useDma) +{ + int ret = CRYPTOCB_UNAVAILABLE; + + /* Extract info parameters */ + byte* ciphertext = info->pk.pqc_encaps.ciphertext; + word32 ciphertextLen = info->pk.pqc_encaps.ciphertextLen; + byte* sharedSecret = info->pk.pqc_encaps.sharedSecret; + word32 sharedSecLen = info->pk.pqc_encaps.sharedSecretLen; + void* key = info->pk.pqc_encaps.key; + int type = info->pk.pqc_encaps.type; + +#ifndef WOLFHSM_CFG_DMA + if (useDma) { + /* TODO: proper error code? */ + return WC_HW_E; + } +#endif + + switch (type) { + case WC_PQC_KEM_TYPE_KYBER: +#ifdef WOLFHSM_CFG_DMA + if (useDma) { + ret = wh_Client_MlKemEncapsulateDma(ctx, key, ciphertext, + &ciphertextLen, + sharedSecret, &sharedSecLen); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = wh_Client_MlKemEncapsulate(ctx, key, ciphertext, + &ciphertextLen, sharedSecret, + &sharedSecLen); + } + if (ret == WH_ERROR_OK) { + info->pk.pqc_encaps.ciphertextLen = ciphertextLen; + info->pk.pqc_encaps.sharedSecretLen = sharedSecLen; + } + break; + + default: + ret = CRYPTOCB_UNAVAILABLE; + break; + } + + if (ret == WH_ERROR_BADARGS) { + ret = BAD_FUNC_ARG; + } + else if (ret == WH_ERROR_NOTIMPL) { + ret = CRYPTOCB_UNAVAILABLE; + } + + return ret; +} + +static int _handlePqcDecaps(whClientContext* ctx, wc_CryptoInfo* info, + int useDma) +{ + int ret = CRYPTOCB_UNAVAILABLE; + + /* Extract info parameters */ + const byte* ciphertext = info->pk.pqc_decaps.ciphertext; + word32 ciphertextLen = info->pk.pqc_decaps.ciphertextLen; + byte* sharedSecret = info->pk.pqc_decaps.sharedSecret; + word32 sharedSecLen = info->pk.pqc_decaps.sharedSecretLen; + void* key = info->pk.pqc_decaps.key; + int type = info->pk.pqc_decaps.type; + +#ifndef WOLFHSM_CFG_DMA + if (useDma) { + /* TODO: proper error code? */ + return WC_HW_E; + } +#endif + + switch (type) { + case WC_PQC_KEM_TYPE_KYBER: +#ifdef WOLFHSM_CFG_DMA + if (useDma) { + ret = wh_Client_MlKemDecapsulateDma( + ctx, key, ciphertext, ciphertextLen, sharedSecret, + &sharedSecLen); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = wh_Client_MlKemDecapsulate(ctx, key, ciphertext, + ciphertextLen, sharedSecret, + &sharedSecLen); + } + if (ret == WH_ERROR_OK) { + info->pk.pqc_decaps.sharedSecretLen = sharedSecLen; + } + break; + + default: + ret = CRYPTOCB_UNAVAILABLE; + break; + } + + if (ret == WH_ERROR_BADARGS) { + ret = BAD_FUNC_ARG; + } + else if (ret == WH_ERROR_NOTIMPL) { + ret = CRYPTOCB_UNAVAILABLE; + } + + return ret; +} +#endif /* WOLFSSL_HAVE_MLKEM */ + #if defined(HAVE_FALCON) || defined(HAVE_DILITHIUM) static int _handlePqcSigKeyGen(whClientContext* ctx, wc_CryptoInfo* info, int useDma) @@ -864,6 +1050,17 @@ int wh_Client_CryptoCbDma(int devId, wc_CryptoInfo* info, void* inCtx) case WC_ALGO_TYPE_PK: { switch (info->pk.type) { +#if defined(WOLFSSL_HAVE_MLKEM) + case WC_PK_TYPE_PQC_KEM_KEYGEN: + ret = _handlePqcKemKeyGen(ctx, info, 1); + break; + case WC_PK_TYPE_PQC_KEM_ENCAPS: + ret = _handlePqcEncaps(ctx, info, 1); + break; + case WC_PK_TYPE_PQC_KEM_DECAPS: + ret = _handlePqcDecaps(ctx, info, 1); + break; +#endif /* WOLFSSL_HAVE_MLKEM */ #if defined(HAVE_DILITHIUM) || defined(HAVE_FALCON) case WC_PK_TYPE_PQC_SIG_KEYGEN: ret = _handlePqcSigKeyGen(ctx, info, 1); diff --git a/src/wh_crypto.c b/src/wh_crypto.c index d43c73ff8..3b9f6ced5 100644 --- a/src/wh_crypto.c +++ b/src/wh_crypto.c @@ -44,6 +44,8 @@ #include "wolfssl/wolfcrypt/ecc.h" #include "wolfssl/wolfcrypt/ed25519.h" #include "wolfssl/wolfcrypt/dilithium.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" +#include "wolfssl/wolfcrypt/memory.h" #include "wolfhsm/wh_error.h" #include "wolfhsm/wh_utils.h" @@ -379,6 +381,110 @@ int wh_Crypto_MlDsaDeserializeKeyDer(const uint8_t* buffer, uint16_t size, } #endif /* HAVE_DILITHIUM */ +#ifdef WOLFSSL_HAVE_MLKEM +int wh_Crypto_MlKemSerializeKey(MlKemKey* key, uint16_t max_size, + uint8_t* buffer, uint16_t* out_size) +{ + int ret = WH_ERROR_OK; + word32 keySize; + + if ((key == NULL) || (buffer == NULL) || (out_size == NULL)) { + return WH_ERROR_BADARGS; + } + + /* Try to encode the private key first. wc_MlKemKey_PrivateKeySize() + * returns the size regardless of whether a private key is present, so we + * must attempt encoding and check the return to detect public-only keys. */ + ret = wc_MlKemKey_PrivateKeySize(key, &keySize); + if (ret == WH_ERROR_OK) { + if (keySize > max_size) { + return WH_ERROR_BADARGS; + } + ret = wc_MlKemKey_EncodePrivateKey(key, buffer, keySize); + } + if (ret != WH_ERROR_OK) { + /* Private key encoding failed - try public key only */ + ret = wc_MlKemKey_PublicKeySize(key, &keySize); + if (ret == WH_ERROR_OK) { + if (keySize > max_size) { + return WH_ERROR_BADARGS; + } + ret = wc_MlKemKey_EncodePublicKey(key, buffer, keySize); + } + } + + if (ret == WH_ERROR_OK) { + *out_size = (uint16_t)keySize; + } + else { + /* Clear buffer to avoid leaking partial key material on error */ + wc_ForceZero(buffer, keySize); + } + + return ret; +} + +int wh_Crypto_MlKemDeserializeKey(const uint8_t* buffer, uint16_t size, + MlKemKey* key) +{ + static const int levels[] = { + WC_ML_KEM_512, + WC_ML_KEM_768, + WC_ML_KEM_1024 + }; + int ret; + int origLevel; + int origDevId; + void* origHeap; + word32 i; + + if ((buffer == NULL) || (key == NULL) || (size == 0)) { + return WH_ERROR_BADARGS; + } + + /* Save original key properties so we can restore on failure */ + origLevel = key->type; + origDevId = key->devId; + origHeap = key->heap; + + /* First, try decoding with the level already set in the key */ + ret = wc_MlKemKey_DecodePrivateKey(key, buffer, size); + if (ret == WH_ERROR_OK) { + return ret; + } + ret = wc_MlKemKey_DecodePublicKey(key, buffer, size); + if (ret == WH_ERROR_OK) { + return ret; + } + + /* Current level didn't work, try other levels in place */ + for (i = 0; i < (word32)(sizeof(levels) / sizeof(levels[0])); i++) { + if (levels[i] == origLevel) { + continue; + } + wc_MlKemKey_Free(key); + ret = wc_MlKemKey_Init(key, levels[i], origHeap, origDevId); + if (ret != WH_ERROR_OK) { + continue; + } + ret = wc_MlKemKey_DecodePrivateKey(key, buffer, size); + if (ret == WH_ERROR_OK) { + return ret; + } + ret = wc_MlKemKey_DecodePublicKey(key, buffer, size); + if (ret == WH_ERROR_OK) { + return ret; + } + } + + /* None of the levels worked, restore original level and devId. We return an + * in ret anyway. So we ignore the return value of wc_MlKemKey_Init(). */ + wc_MlKemKey_Free(key); + (void)wc_MlKemKey_Init(key, origLevel, origHeap, origDevId); + return ret; +} +#endif /* WOLFSSL_HAVE_MLKEM */ + #ifdef WOLFSSL_CMAC void wh_Crypto_CmacAesSaveStateToMsg(whMessageCrypto_CmacAesState* state, const Cmac* cmac) diff --git a/src/wh_message_crypto.c b/src/wh_message_crypto.c index a4e6d697b..560aa8e90 100644 --- a/src/wh_message_crypto.c +++ b/src/wh_message_crypto.c @@ -831,6 +831,91 @@ int wh_MessageCrypto_TranslateMlDsaVerifyResponse( return 0; } +/* ML-KEM Key Generation Request translation */ +int wh_MessageCrypto_TranslateMlKemKeyGenRequest( + uint16_t magic, const whMessageCrypto_MlKemKeyGenRequest* src, + whMessageCrypto_MlKemKeyGenRequest* dest) +{ + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + WH_T32(magic, dest, src, level); + WH_T32(magic, dest, src, keyId); + WH_T32(magic, dest, src, flags); + WH_T32(magic, dest, src, access); + if (src != dest) { + memcpy(dest->label, src->label, sizeof(src->label)); + } + return 0; +} + +/* ML-KEM Key Generation Response translation */ +int wh_MessageCrypto_TranslateMlKemKeyGenResponse( + uint16_t magic, const whMessageCrypto_MlKemKeyGenResponse* src, + whMessageCrypto_MlKemKeyGenResponse* dest) +{ + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + WH_T32(magic, dest, src, keyId); + WH_T32(magic, dest, src, len); + return 0; +} + +/* ML-KEM Encapsulation Request translation */ +int wh_MessageCrypto_TranslateMlKemEncapsRequest( + uint16_t magic, const whMessageCrypto_MlKemEncapsRequest* src, + whMessageCrypto_MlKemEncapsRequest* dest) +{ + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + WH_T32(magic, dest, src, options); + WH_T32(magic, dest, src, level); + WH_T32(magic, dest, src, keyId); + return 0; +} + +/* ML-KEM Encapsulation Response translation */ +int wh_MessageCrypto_TranslateMlKemEncapsResponse( + uint16_t magic, const whMessageCrypto_MlKemEncapsResponse* src, + whMessageCrypto_MlKemEncapsResponse* dest) +{ + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + WH_T32(magic, dest, src, ctSz); + WH_T32(magic, dest, src, ssSz); + return 0; +} + +/* ML-KEM Decapsulation Request translation */ +int wh_MessageCrypto_TranslateMlKemDecapsRequest( + uint16_t magic, const whMessageCrypto_MlKemDecapsRequest* src, + whMessageCrypto_MlKemDecapsRequest* dest) +{ + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + WH_T32(magic, dest, src, options); + WH_T32(magic, dest, src, level); + WH_T32(magic, dest, src, keyId); + WH_T32(magic, dest, src, ctSz); + return 0; +} + +/* ML-KEM Decapsulation Response translation */ +int wh_MessageCrypto_TranslateMlKemDecapsResponse( + uint16_t magic, const whMessageCrypto_MlKemDecapsResponse* src, + whMessageCrypto_MlKemDecapsResponse* dest) +{ + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + WH_T32(magic, dest, src, ssSz); + return 0; +} + /* * DMA Messages */ @@ -1143,6 +1228,143 @@ int wh_MessageCrypto_TranslateMlDsaVerifyDmaResponse( return 0; } +/* ML-KEM DMA Key Generation Request translation */ +int wh_MessageCrypto_TranslateMlKemKeyGenDmaRequest( + uint16_t magic, const whMessageCrypto_MlKemKeyGenDmaRequest* src, + whMessageCrypto_MlKemKeyGenDmaRequest* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaBuffer(magic, &src->key, &dest->key); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, level); + WH_T32(magic, dest, src, flags); + WH_T32(magic, dest, src, keyId); + WH_T32(magic, dest, src, access); + WH_T32(magic, dest, src, labelSize); + if (src != dest) { + memcpy(dest->label, src->label, sizeof(src->label)); + } + + return 0; +} + +/* ML-KEM DMA Key Generation Response translation */ +int wh_MessageCrypto_TranslateMlKemKeyGenDmaResponse( + uint16_t magic, const whMessageCrypto_MlKemKeyGenDmaResponse* src, + whMessageCrypto_MlKemKeyGenDmaResponse* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaAddrStatus(magic, &src->dmaAddrStatus, + &dest->dmaAddrStatus); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, keyId); + WH_T32(magic, dest, src, keySize); + return 0; +} + +/* ML-KEM DMA Encapsulation Request translation */ +int wh_MessageCrypto_TranslateMlKemEncapsDmaRequest( + uint16_t magic, const whMessageCrypto_MlKemEncapsDmaRequest* src, + whMessageCrypto_MlKemEncapsDmaRequest* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaBuffer(magic, &src->ct, &dest->ct); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, options); + WH_T32(magic, dest, src, level); + WH_T32(magic, dest, src, keyId); + return 0; +} + +/* ML-KEM DMA Encapsulation Response translation */ +int wh_MessageCrypto_TranslateMlKemEncapsDmaResponse( + uint16_t magic, const whMessageCrypto_MlKemEncapsDmaResponse* src, + whMessageCrypto_MlKemEncapsDmaResponse* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaAddrStatus(magic, &src->dmaAddrStatus, + &dest->dmaAddrStatus); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, ctLen); + WH_T32(magic, dest, src, ssLen); + return 0; +} + +/* ML-KEM DMA Decapsulation Request translation */ +int wh_MessageCrypto_TranslateMlKemDecapsDmaRequest( + uint16_t magic, const whMessageCrypto_MlKemDecapsDmaRequest* src, + whMessageCrypto_MlKemDecapsDmaRequest* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaBuffer(magic, &src->ct, &dest->ct); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, options); + WH_T32(magic, dest, src, level); + WH_T32(magic, dest, src, keyId); + return 0; +} + +/* ML-KEM DMA Decapsulation Response translation */ +int wh_MessageCrypto_TranslateMlKemDecapsDmaResponse( + uint16_t magic, const whMessageCrypto_MlKemDecapsDmaResponse* src, + whMessageCrypto_MlKemDecapsDmaResponse* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaAddrStatus(magic, &src->dmaAddrStatus, + &dest->dmaAddrStatus); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, ssLen); + return 0; +} + /* Ed25519 DMA Sign Request translation */ int wh_MessageCrypto_TranslateEd25519SignDmaRequest( uint16_t magic, const whMessageCrypto_Ed25519SignDmaRequest* src, diff --git a/src/wh_server_crypto.c b/src/wh_server_crypto.c index 33f7b1fb0..c0358fcc2 100644 --- a/src/wh_server_crypto.c +++ b/src/wh_server_crypto.c @@ -44,6 +44,7 @@ #include "wolfssl/wolfcrypt/sha512.h" #include "wolfssl/wolfcrypt/cmac.h" #include "wolfssl/wolfcrypt/dilithium.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfssl/wolfcrypt/hmac.h" #include "wolfssl/wolfcrypt/kdf.h" @@ -194,6 +195,32 @@ static int _HandleMlDsaCheckPrivKey(whServerContext* ctx, uint16_t magic, uint16_t* outSize); #endif /* HAVE_DILITHIUM */ +#ifdef WOLFSSL_HAVE_MLKEM +static int _HandleMlKemKeyGen(whServerContext* ctx, uint16_t magic, int devId, + const void* cryptoDataIn, uint16_t inSize, + void* cryptoDataOut, uint16_t* outSize); +static int _HandleMlKemEncaps(whServerContext* ctx, uint16_t magic, int devId, + const void* cryptoDataIn, uint16_t inSize, + void* cryptoDataOut, uint16_t* outSize); +static int _HandleMlKemDecaps(whServerContext* ctx, uint16_t magic, int devId, + const void* cryptoDataIn, uint16_t inSize, + void* cryptoDataOut, uint16_t* outSize); +#ifdef WOLFHSM_CFG_DMA +static int _HandleMlKemKeyGenDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize); +static int _HandleMlKemEncapsDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize); +static int _HandleMlKemDecapsDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize); +#endif /* WOLFHSM_CFG_DMA */ +#endif /* WOLFSSL_HAVE_MLKEM */ + /** Public server crypto functions */ #ifndef NO_RSA @@ -798,6 +825,60 @@ int wh_Server_MlDsaKeyCacheExport(whServerContext* ctx, whKeyId keyId, } #endif /* HAVE_DILITHIUM */ +#ifdef WOLFSSL_HAVE_MLKEM +int wh_Server_MlKemKeyCacheImport(whServerContext* ctx, MlKemKey* key, + whKeyId keyId, whNvmFlags flags, + uint16_t label_len, uint8_t* label) +{ + int ret = WH_ERROR_OK; + uint8_t* cacheBuf; + whNvmMetadata* cacheMeta; + uint16_t keySize = WC_ML_KEM_MAX_PRIVATE_KEY_SIZE; + + if ((ctx == NULL) || (key == NULL) || (WH_KEYID_ISERASED(keyId)) || + ((label != NULL) && (label_len > sizeof(cacheMeta->label)))) { + return WH_ERROR_BADARGS; + } + + ret = wh_Server_KeystoreGetCacheSlotChecked(ctx, keyId, keySize, &cacheBuf, + &cacheMeta); + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_MlKemSerializeKey(key, keySize, cacheBuf, &keySize); + } + + if (ret == WH_ERROR_OK) { + cacheMeta->id = keyId; + cacheMeta->len = keySize; + cacheMeta->flags = flags; + cacheMeta->access = WH_NVM_ACCESS_ANY; + if ((label != NULL) && (label_len > 0)) { + memcpy(cacheMeta->label, label, label_len); + } + } + + return ret; +} + +int wh_Server_MlKemKeyCacheExport(whServerContext* ctx, whKeyId keyId, + MlKemKey* key) +{ + uint8_t* cacheBuf; + whNvmMetadata* cacheMeta; + int ret = WH_ERROR_OK; + + if ((ctx == NULL) || (key == NULL) || (WH_KEYID_ISERASED(keyId))) { + return WH_ERROR_BADARGS; + } + + ret = wh_Server_KeystoreFreshenKey(ctx, keyId, &cacheBuf, &cacheMeta); + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_MlKemDeserializeKey(cacheBuf, cacheMeta->len, key); + WH_DEBUG_SERVER_VERBOSE("keyId:%u, ret:%d\n", keyId, ret); + } + return ret; +} +#endif /* WOLFSSL_HAVE_MLKEM */ + /** Request/Response Handling functions */ @@ -4438,62 +4519,40 @@ static int _HandleMlDsaCheckPrivKey(whServerContext* ctx, uint16_t magic, } #endif /* HAVE_DILITHIUM */ -#if defined(HAVE_DILITHIUM) || defined(HAVE_FALCON) -static int _HandlePqcSigAlgorithm(whServerContext* ctx, uint16_t magic, - int devId, const void* cryptoDataIn, - uint16_t cryptoInSize, void* cryptoDataOut, - uint16_t* cryptoOutSize, uint32_t pkAlgoType, - uint32_t pqAlgoType) +#ifdef WOLFSSL_HAVE_MLKEM +static int _IsMlKemLevelSupported(int level) { - int ret = WH_ERROR_NOHANDLER; + int ret = 0; - /* Dispatch the appropriate algorithm handler based on the requested PK type - * and the algorithm type. */ - switch (pqAlgoType) { -#ifdef HAVE_DILITHIUM - case WC_PQC_SIG_TYPE_DILITHIUM: { - switch (pkAlgoType) { - case WC_PK_TYPE_PQC_SIG_KEYGEN: - ret = _HandleMlDsaKeyGen(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - cryptoOutSize); - break; - case WC_PK_TYPE_PQC_SIG_SIGN: - ret = _HandleMlDsaSign(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - cryptoOutSize); - break; - case WC_PK_TYPE_PQC_SIG_VERIFY: - ret = _HandleMlDsaVerify(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - cryptoOutSize); - break; - case WC_PK_TYPE_PQC_SIG_CHECK_PRIV_KEY: - ret = _HandleMlDsaCheckPrivKey( - ctx, magic, devId, cryptoDataIn, cryptoInSize, - cryptoDataOut, cryptoOutSize); - break; - default: - ret = WH_ERROR_NOHANDLER; - break; - } - } break; -#endif /* HAVE_DILITHIUM */ + switch (level) { +#ifndef WOLFSSL_NO_ML_KEM_512 + case WC_ML_KEM_512: + ret = 1; + break; +#endif +#ifndef WOLFSSL_NO_ML_KEM_768 + case WC_ML_KEM_768: + ret = 1; + break; +#endif +#ifndef WOLFSSL_NO_ML_KEM_1024 + case WC_ML_KEM_1024: + ret = 1; + break; +#endif default: - ret = WH_ERROR_NOHANDLER; + ret = 0; break; } return ret; } -#endif -#if defined(HAVE_KYBER) -static int _HandlePqcKemAlgorithm(whServerContext* ctx, uint16_t magic, - int devId, const void* cryptoDataIn, - uint16_t inSize, void* cryptoDataOut, - uint16_t* outSize) +static int _HandleMlKemKeyGen(whServerContext* ctx, uint16_t magic, int devId, + const void* cryptoDataIn, uint16_t inSize, + void* cryptoDataOut, uint16_t* outSize) { +#ifdef WOLFSSL_MLKEM_NO_MAKE_KEY (void)ctx; (void)magic; (void)devId; @@ -4501,116 +4560,488 @@ static int _HandlePqcKemAlgorithm(whServerContext* ctx, uint16_t magic, (void)inSize; (void)cryptoDataOut; (void)outSize; - /* Placeholder for KEM algorithm handling */ return WH_ERROR_NOHANDLER; -} -#endif +#else + int ret = WH_ERROR_OK; + MlKemKey key[1]; + whMessageCrypto_MlKemKeyGenRequest req; + whMessageCrypto_MlKemKeyGenResponse res; + uint16_t res_size = 0; + uint8_t* res_out; + uint16_t max_size; + whKeyId key_id; + uint16_t label_size = WH_NVM_LABEL_LEN; -int wh_Server_HandleCryptoRequest(whServerContext* ctx, uint16_t magic, - uint16_t action, uint16_t seq, - uint16_t req_size, const void* req_packet, - uint16_t* out_resp_size, void* resp_packet) -{ - int ret = 0; - int devId = INVALID_DEVID; - whMessageCrypto_GenericRequestHeader rqstHeader = {0}; - whMessageCrypto_GenericResponseHeader respHeader = {0}; + if (inSize < sizeof(whMessageCrypto_MlKemKeyGenRequest)) { + return WH_ERROR_BADARGS; + } - const void* cryptoDataIn = - (uint8_t*)req_packet + sizeof(whMessageCrypto_GenericRequestHeader); - void* cryptoDataOut = - (uint8_t*)resp_packet + sizeof(whMessageCrypto_GenericResponseHeader); + ret = wh_MessageCrypto_TranslateMlKemKeyGenRequest( + magic, (whMessageCrypto_MlKemKeyGenRequest*)cryptoDataIn, &req); + if (ret != 0) { + return ret; + } - /* Input and output sizes for data passed to crypto handlers. cryptoOutSize - * should be set by the crypto handler as an output parameter */ - uint16_t cryptoInSize = - req_size - sizeof(whMessageCrypto_GenericResponseHeader); - uint16_t cryptoOutSize = 0; + key_id = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, ctx->comm->client_id, + req.keyId); + res_out = (uint8_t*)cryptoDataOut + sizeof(whMessageCrypto_MlKemKeyGenResponse); + max_size = (uint16_t)(WOLFHSM_CFG_COMM_DATA_LEN - + (res_out - (uint8_t*)cryptoDataOut)); - if ((ctx == NULL) || (ctx->crypto == NULL) || (req_packet == NULL) || - (resp_packet == NULL) || (out_resp_size == NULL)) { + if (!_IsMlKemLevelSupported((int)req.level)) { return WH_ERROR_BADARGS; } - /* Validate req_size to prevent integer underflow */ - if (req_size < sizeof(whMessageCrypto_GenericResponseHeader)) { - return WH_ERROR_BADARGS; + ret = wc_MlKemKey_Init(key, (int)req.level, NULL, devId); + if (ret == 0) { + ret = wc_MlKemKey_MakeKey(key, ctx->crypto->rng); + if (ret == 0) { + if ((req.flags & WH_NVM_FLAGS_EPHEMERAL) != 0) { + key_id = WH_KEYID_ERASED; + ret = wh_Crypto_MlKemSerializeKey(key, max_size, res_out, + &res_size); + } + else { + if (WH_KEYID_ISERASED(key_id)) { + ret = wh_Server_KeystoreGetUniqueId(ctx, &key_id); + } + if (ret == WH_ERROR_OK) { + ret = wh_Server_MlKemKeyCacheImport(ctx, key, key_id, + req.flags, label_size, + req.label); + } + } + } + wc_MlKemKey_Free(key); } - /* Translate the request message to get the algo type */ - wh_MessageCrypto_TranslateGenericRequestHeader( - magic, (whMessageCrypto_GenericRequestHeader*)req_packet, &rqstHeader); + if (ret == WH_ERROR_OK) { + res.keyId = wh_KeyId_TranslateToClient(key_id); + res.len = res_size; + (void)wh_MessageCrypto_TranslateMlKemKeyGenResponse( + magic, &res, (whMessageCrypto_MlKemKeyGenResponse*)cryptoDataOut); + *outSize = sizeof(whMessageCrypto_MlKemKeyGenResponse) + res_size; + } - /* Compute devId from the per-message affinity field */ - devId = (rqstHeader.affinity == WH_CRYPTO_AFFINITY_HW && - ctx->devId != INVALID_DEVID) - ? ctx->devId - : INVALID_DEVID; + return ret; +#endif /* WOLFSSL_MLKEM_NO_MAKE_KEY */ +} - WH_DEBUG_SERVER_VERBOSE("HandleCryptoRequest. Action:%u\n", action); - WH_DEBUG_VERBOSE_HEXDUMP("[server] Crypto Request:\n", (const uint8_t*)req_packet, - req_size); - switch (action) { - case WC_ALGO_TYPE_CIPHER: - switch (rqstHeader.algoType) { -#ifndef NO_AES -#ifdef WOLFSSL_AES_COUNTER - case WC_CIPHER_AES_CTR: - ret = _HandleAesCtr(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - &cryptoOutSize); - break; -#endif /* WOLFSSL_AES_COUNTER */ -#ifdef HAVE_AES_ECB - case WC_CIPHER_AES_ECB: - ret = _HandleAesEcb(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - &cryptoOutSize); - break; -#endif /* HAVE_AES_ECB */ -#ifdef HAVE_AES_CBC - case WC_CIPHER_AES_CBC: - ret = _HandleAesCbc(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - &cryptoOutSize); - break; -#endif /* HAVE_AES_CBC */ -#ifdef HAVE_AESGCM - case WC_CIPHER_AES_GCM: - ret = _HandleAesGcm(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - &cryptoOutSize); - break; -#endif /* HAVE_AESGCM */ -#endif /* !NO_AES */ - default: - ret = NOT_COMPILED_IN; - break; - } - break; - case WC_ALGO_TYPE_PK: { - WH_DEBUG_SERVER_VERBOSE("PK type:%d\n", rqstHeader.algoType); - switch (rqstHeader.algoType) { -#ifndef NO_RSA -#ifdef WOLFSSL_KEY_GEN - case WC_PK_TYPE_RSA_KEYGEN: - ret = _HandleRsaKeyGen(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - &cryptoOutSize); - break; -#endif /* WOLFSSL_KEY_GEN */ - case WC_PK_TYPE_RSA: - ret = _HandleRsaFunction(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - &cryptoOutSize); - break; +static int _HandleMlKemEncaps(whServerContext* ctx, uint16_t magic, int devId, + const void* cryptoDataIn, uint16_t inSize, + void* cryptoDataOut, uint16_t* outSize) +{ +#ifdef WOLFSSL_MLKEM_NO_ENCAPSULATE + (void)ctx; + (void)magic; + (void)devId; + (void)cryptoDataIn; + (void)inSize; + (void)cryptoDataOut; + (void)outSize; + return WH_ERROR_NOHANDLER; +#else + int ret = WH_ERROR_OK; + MlKemKey key[1]; + whMessageCrypto_MlKemEncapsRequest req; + whMessageCrypto_MlKemEncapsResponse res; + whKeyId key_id; + uint8_t* res_ct; + uint8_t* res_ss; + word32 ct_len; + word32 ss_len; + word32 max_out; + int evict = 0; + int keyInited = 0; - case WC_PK_TYPE_RSA_GET_SIZE: - ret = _HandleRsaGetSize(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - &cryptoOutSize); - break; -#endif /* !NO_RSA */ + if (inSize < sizeof(whMessageCrypto_MlKemEncapsRequest)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateMlKemEncapsRequest( + magic, (whMessageCrypto_MlKemEncapsRequest*)cryptoDataIn, &req); + if (ret != 0) { + return ret; + } + + key_id = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, ctx->comm->client_id, + req.keyId); + evict = !!(req.options & WH_MESSAGE_CRYPTO_MLKEM_ENCAPS_OPTIONS_EVICT); + + if (!WH_KEYID_ISERASED(key_id)) { + ret = wh_Server_KeystoreFindEnforceKeyUsage(ctx, key_id, + WH_NVM_FLAGS_USAGE_DERIVE); + if (ret != WH_ERROR_OK) { + goto cleanup; + } + } + + if (!_IsMlKemLevelSupported((int)req.level)) { + ret = WH_ERROR_BADARGS; + goto cleanup; + } + + ret = wc_MlKemKey_Init(key, (int)req.level, NULL, devId); + if (ret == 0) { + keyInited = 1; + ret = wh_Server_MlKemKeyCacheExport(ctx, key_id, key); + } + + /* Verify the exported key matches the requested level */ + if (ret == WH_ERROR_OK && key->type != (int)req.level) { + ret = WH_ERROR_BADARGS; + } + + if (ret == WH_ERROR_OK) { + ret = wc_MlKemKey_CipherTextSize(key, &ct_len); + if (ret == WH_ERROR_OK) { + ret = wc_MlKemKey_SharedSecretSize(key, &ss_len); + } + } + + if (ret == WH_ERROR_OK) { + res_ct = (uint8_t*)cryptoDataOut + sizeof(whMessageCrypto_MlKemEncapsResponse); + res_ss = res_ct + ct_len; + max_out = (word32)(WOLFHSM_CFG_COMM_DATA_LEN - + ((uint8_t*)res_ct - (uint8_t*)cryptoDataOut)); + if (ct_len + ss_len > max_out) { + ret = WH_ERROR_BADARGS; + } + } + + if (ret == WH_ERROR_OK) { + ret = wc_MlKemKey_Encapsulate(key, res_ct, res_ss, ctx->crypto->rng); + if (ret == WH_ERROR_OK) { + res.ctSz = ct_len; + res.ssSz = ss_len; + (void)wh_MessageCrypto_TranslateMlKemEncapsResponse( + magic, &res, (whMessageCrypto_MlKemEncapsResponse*)cryptoDataOut); + *outSize = sizeof(whMessageCrypto_MlKemEncapsResponse) + ct_len + ss_len; + } + else { + /* Zero sensitive data on failure */ + wc_ForceZero(res_ss, ss_len); + } + } + + if (keyInited) { + wc_MlKemKey_Free(key); + } +cleanup: + if (evict != 0) { + (void)wh_Server_KeystoreEvictKey(ctx, key_id); + } + return ret; +#endif /* WOLFSSL_MLKEM_NO_ENCAPSULATE */ +} + +static int _HandleMlKemDecaps(whServerContext* ctx, uint16_t magic, int devId, + const void* cryptoDataIn, uint16_t inSize, + void* cryptoDataOut, uint16_t* outSize) +{ +#ifdef WOLFSSL_MLKEM_NO_DECAPSULATE + (void)ctx; + (void)magic; + (void)devId; + (void)cryptoDataIn; + (void)inSize; + (void)cryptoDataOut; + (void)outSize; + return WH_ERROR_NOHANDLER; +#else + int ret = WH_ERROR_OK; + MlKemKey key[1]; + whMessageCrypto_MlKemDecapsRequest req; + whMessageCrypto_MlKemDecapsResponse res; + whKeyId key_id; + byte* req_ct; + byte* res_ss; + uint32_t available; + word32 ss_len; + word32 max_out; + int evict = 0; + int keyInited = 0; + + if (inSize < sizeof(whMessageCrypto_MlKemDecapsRequest)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateMlKemDecapsRequest( + magic, (whMessageCrypto_MlKemDecapsRequest*)cryptoDataIn, &req); + if (ret != 0) { + return ret; + } + + key_id = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, ctx->comm->client_id, + req.keyId); + evict = !!(req.options & WH_MESSAGE_CRYPTO_MLKEM_DECAPS_OPTIONS_EVICT); + + if (!WH_KEYID_ISERASED(key_id)) { + ret = wh_Server_KeystoreFindEnforceKeyUsage(ctx, key_id, + WH_NVM_FLAGS_USAGE_DERIVE); + if (ret != WH_ERROR_OK) { + goto cleanup; + } + } + + if (!_IsMlKemLevelSupported((int)req.level)) { + ret = WH_ERROR_BADARGS; + goto cleanup; + } + + available = inSize - sizeof(whMessageCrypto_MlKemDecapsRequest); + if (req.ctSz > available) { + ret = WH_ERROR_BADARGS; + goto cleanup; + } + req_ct = (byte*)cryptoDataIn + sizeof(whMessageCrypto_MlKemDecapsRequest); + + ret = wc_MlKemKey_Init(key, (int)req.level, NULL, devId); + if (ret == WH_ERROR_OK) { + keyInited = 1; + ret = wh_Server_MlKemKeyCacheExport(ctx, key_id, key); + } + + /* Verify the exported key matches the requested level */ + if (ret == WH_ERROR_OK && key->type != (int)req.level) { + ret = WH_ERROR_BADARGS; + } + + if (ret == WH_ERROR_OK) { + ret = wc_MlKemKey_SharedSecretSize(key, &ss_len); + } + + if (ret == WH_ERROR_OK) { + res_ss = (byte*)cryptoDataOut + sizeof(whMessageCrypto_MlKemDecapsResponse); + max_out = (word32)(WOLFHSM_CFG_COMM_DATA_LEN - + ((uint8_t*)res_ss - (uint8_t*)cryptoDataOut)); + if (ss_len > max_out) { + ret = WH_ERROR_BADARGS; + } + } + + if (ret == WH_ERROR_OK) { + ret = wc_MlKemKey_Decapsulate(key, res_ss, req_ct, req.ctSz); + if (ret == WH_ERROR_OK) { + res.ssSz = ss_len; + (void)wh_MessageCrypto_TranslateMlKemDecapsResponse( + magic, &res, (whMessageCrypto_MlKemDecapsResponse*)cryptoDataOut); + *outSize = sizeof(whMessageCrypto_MlKemDecapsResponse) + ss_len; + } + else { + /* Zero sensitive data on failure */ + wc_ForceZero(res_ss, ss_len); + } + } + + if (keyInited) { + wc_MlKemKey_Free(key); + } +cleanup: + if (evict != 0) { + (void)wh_Server_KeystoreEvictKey(ctx, key_id); + } + return ret; +#endif /* WOLFSSL_MLKEM_NO_DECAPSULATE */ +} +#endif /* WOLFSSL_HAVE_MLKEM */ + +#if defined(HAVE_DILITHIUM) || defined(HAVE_FALCON) +static int _HandlePqcSigAlgorithm(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t cryptoInSize, void* cryptoDataOut, + uint16_t* cryptoOutSize, uint32_t pkAlgoType, + uint32_t pqAlgoType) +{ + int ret = WH_ERROR_NOHANDLER; + + /* Dispatch the appropriate algorithm handler based on the requested PK type + * and the algorithm type. */ + switch (pqAlgoType) { +#ifdef HAVE_DILITHIUM + case WC_PQC_SIG_TYPE_DILITHIUM: { + switch (pkAlgoType) { + case WC_PK_TYPE_PQC_SIG_KEYGEN: + ret = _HandleMlDsaKeyGen(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_SIG_SIGN: + ret = _HandleMlDsaSign(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_SIG_VERIFY: + ret = _HandleMlDsaVerify(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_SIG_CHECK_PRIV_KEY: + ret = _HandleMlDsaCheckPrivKey( + ctx, magic, devId, cryptoDataIn, cryptoInSize, + cryptoDataOut, cryptoOutSize); + break; + default: + ret = WH_ERROR_NOHANDLER; + break; + } + } break; +#endif /* HAVE_DILITHIUM */ + default: + ret = WH_ERROR_NOHANDLER; + break; + } + + return ret; +} +#endif + +#if defined(WOLFSSL_HAVE_MLKEM) +static int _HandlePqcKemAlgorithm(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t cryptoInSize, void* cryptoDataOut, + uint16_t* cryptoOutSize, uint32_t pkAlgoType, + uint32_t pqAlgoType) +{ + int ret = WH_ERROR_NOHANDLER; + + switch (pqAlgoType) { + case WC_PQC_KEM_TYPE_KYBER: { + switch (pkAlgoType) { + case WC_PK_TYPE_PQC_KEM_KEYGEN: + ret = _HandleMlKemKeyGen(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_KEM_ENCAPS: + ret = _HandleMlKemEncaps(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_KEM_DECAPS: + ret = _HandleMlKemDecaps(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + default: + ret = WH_ERROR_NOHANDLER; + break; + } + } break; + default: + ret = WH_ERROR_NOHANDLER; + break; + } + + return ret; +} +#endif + +int wh_Server_HandleCryptoRequest(whServerContext* ctx, uint16_t magic, + uint16_t action, uint16_t seq, + uint16_t req_size, const void* req_packet, + uint16_t* out_resp_size, void* resp_packet) +{ + int ret = 0; + int devId = INVALID_DEVID; + whMessageCrypto_GenericRequestHeader rqstHeader = {0}; + whMessageCrypto_GenericResponseHeader respHeader = {0}; + + const void* cryptoDataIn = + (uint8_t*)req_packet + sizeof(whMessageCrypto_GenericRequestHeader); + void* cryptoDataOut = + (uint8_t*)resp_packet + sizeof(whMessageCrypto_GenericResponseHeader); + + /* Input and output sizes for data passed to crypto handlers. cryptoOutSize + * should be set by the crypto handler as an output parameter */ + uint16_t cryptoInSize = + req_size - sizeof(whMessageCrypto_GenericResponseHeader); + uint16_t cryptoOutSize = 0; + + if ((ctx == NULL) || (ctx->crypto == NULL) || (req_packet == NULL) || + (resp_packet == NULL) || (out_resp_size == NULL)) { + return WH_ERROR_BADARGS; + } + + /* Validate req_size to prevent integer underflow */ + if (req_size < sizeof(whMessageCrypto_GenericResponseHeader)) { + return WH_ERROR_BADARGS; + } + + /* Translate the request message to get the algo type */ + wh_MessageCrypto_TranslateGenericRequestHeader( + magic, (whMessageCrypto_GenericRequestHeader*)req_packet, &rqstHeader); + + /* Compute devId from the per-message affinity field */ + devId = (rqstHeader.affinity == WH_CRYPTO_AFFINITY_HW && + ctx->devId != INVALID_DEVID) + ? ctx->devId + : INVALID_DEVID; + + WH_DEBUG_SERVER_VERBOSE("HandleCryptoRequest. Action:%u\n", action); + WH_DEBUG_VERBOSE_HEXDUMP("[server] Crypto Request:\n", (const uint8_t*)req_packet, + req_size); + switch (action) { + case WC_ALGO_TYPE_CIPHER: + switch (rqstHeader.algoType) { +#ifndef NO_AES +#ifdef WOLFSSL_AES_COUNTER + case WC_CIPHER_AES_CTR: + ret = _HandleAesCtr(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + &cryptoOutSize); + break; +#endif /* WOLFSSL_AES_COUNTER */ +#ifdef HAVE_AES_ECB + case WC_CIPHER_AES_ECB: + ret = _HandleAesEcb(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + &cryptoOutSize); + break; +#endif /* HAVE_AES_ECB */ +#ifdef HAVE_AES_CBC + case WC_CIPHER_AES_CBC: + ret = _HandleAesCbc(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + &cryptoOutSize); + break; +#endif /* HAVE_AES_CBC */ +#ifdef HAVE_AESGCM + case WC_CIPHER_AES_GCM: + ret = _HandleAesGcm(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + &cryptoOutSize); + break; +#endif /* HAVE_AESGCM */ +#endif /* !NO_AES */ + default: + ret = NOT_COMPILED_IN; + break; + } + break; + case WC_ALGO_TYPE_PK: { + WH_DEBUG_SERVER_VERBOSE("PK type:%d\n", rqstHeader.algoType); + switch (rqstHeader.algoType) { +#ifndef NO_RSA +#ifdef WOLFSSL_KEY_GEN + case WC_PK_TYPE_RSA_KEYGEN: + ret = _HandleRsaKeyGen(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + &cryptoOutSize); + break; +#endif /* WOLFSSL_KEY_GEN */ + case WC_PK_TYPE_RSA: + ret = _HandleRsaFunction(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + &cryptoOutSize); + break; + + case WC_PK_TYPE_RSA_GET_SIZE: + ret = _HandleRsaGetSize(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + &cryptoOutSize); + break; +#endif /* !NO_RSA */ #ifdef HAVE_ECC case WC_PK_TYPE_EC_KEYGEN: @@ -4684,13 +5115,15 @@ int wh_Server_HandleCryptoRequest(whServerContext* ctx, uint16_t magic, break; #endif -#if defined(HAVE_KYBER) +#if defined(WOLFSSL_HAVE_MLKEM) case WC_PK_TYPE_PQC_KEM_KEYGEN: case WC_PK_TYPE_PQC_KEM_ENCAPS: case WC_PK_TYPE_PQC_KEM_DECAPS: ret = _HandlePqcKemAlgorithm(ctx, magic, devId, cryptoDataIn, cryptoInSize, - cryptoDataOut, &cryptoOutSize); + cryptoDataOut, &cryptoOutSize, + rqstHeader.algoType, + rqstHeader.algoSubType); break; #endif @@ -5385,29 +5818,163 @@ static int _HandleMlDsaKeyGenDma(whServerContext* ctx, uint16_t magic, } } } - wc_MlDsaKey_Free(key); - } - } + wc_MlDsaKey_Free(key); + } + } + + if (ret == WH_ERROR_ACCESS) { + res.dmaAddrStatus.badAddr = req.key; + } + + /* Translate the response */ + (void)wh_MessageCrypto_TranslateMlDsaKeyGenDmaResponse( + magic, &res, (whMessageCrypto_MlDsaKeyGenDmaResponse*)cryptoDataOut); + + *outSize = sizeof(res); + + return ret; +#endif /* WOLFSSL_DILITHIUM_NO_MAKE_KEY */ +} + +static int _HandleMlDsaSignDma(whServerContext* ctx, uint16_t magic, int devId, + const void* cryptoDataIn, uint16_t inSize, + void* cryptoDataOut, uint16_t* outSize) +{ +#ifdef WOLFSSL_DILITHIUM_NO_SIGN + (void)ctx; + (void)magic; + (void)cryptoDataIn; + (void)inSize; + (void)cryptoDataOut; + (void)outSize; + return WH_ERROR_NOHANDLER; +#else + int ret = 0; + MlDsaKey key[1]; + void* msgAddr = NULL; + void* sigAddr = NULL; + word32 sigLen = 0; + + whMessageCrypto_MlDsaSignDmaRequest req; + whMessageCrypto_MlDsaSignDmaResponse res; + + if (inSize < sizeof(whMessageCrypto_MlDsaSignDmaRequest)) { + return WH_ERROR_BADARGS; + } + + /* Translate the request */ + ret = wh_MessageCrypto_TranslateMlDsaSignDmaRequest( + magic, (whMessageCrypto_MlDsaSignDmaRequest*)cryptoDataIn, &req); + if (ret != WH_ERROR_OK) { + return ret; + } + + /* Transaction state */ + whKeyId key_id; + int evict = 0; + + + /* Get key ID and evict flag */ + key_id = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + ctx->comm->client_id, req.keyId); + evict = !!(req.options & WH_MESSAGE_CRYPTO_MLDSA_SIGN_OPTIONS_EVICT); + + /* Extract context from inline data after the struct */ + uint32_t contextSz = req.contextSz; + uint32_t preHashType = req.preHashType; + byte* req_context = NULL; + if (contextSz > WH_CRYPTO_MLDSA_MAX_CTX_LEN) { + return WH_ERROR_BADARGS; + } + if (contextSz > 0) { + if (inSize < sizeof(whMessageCrypto_MlDsaSignDmaRequest) + contextSz) { + return WH_ERROR_BADARGS; + } + req_context = (uint8_t*)(cryptoDataIn) + + sizeof(whMessageCrypto_MlDsaSignDmaRequest); + } + + /* Initialize key */ + ret = wc_MlDsaKey_Init(key, NULL, devId); + if (ret == 0) { + /* Export key from cache */ + /* TODO: sanity check security level against key pulled from cache? */ + ret = wh_Server_MlDsaKeyCacheExport(ctx, key_id, key); + if (ret == 0) { + /* Process client message buffer address */ + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.msg.addr, &msgAddr, req.msg.sz, + WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + + if (ret == 0) { + /* Process client signature buffer address */ + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.sig.addr, &sigAddr, req.sig.sz, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whServerDmaFlags){0}); + + if (ret == 0) { + /* Sign the message using appropriate FIPS 204 API */ + sigLen = req.sig.sz; + if (preHashType != WC_HASH_TYPE_NONE) { + ret = wc_MlDsaKey_SignCtxHash( + key, req_context, (byte)contextSz, + sigAddr, &sigLen, msgAddr, req.msg.sz, + preHashType, ctx->crypto->rng); + } + else { + ret = wc_MlDsaKey_SignCtx( + key, req_context, (byte)contextSz, + sigAddr, &sigLen, msgAddr, req.msg.sz, + ctx->crypto->rng); + } + } + + if (sigAddr != NULL) { + /* Post-write processing of signature buffer */ + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.sig.addr, &sigAddr, sigLen, + WH_DMA_OPER_CLIENT_WRITE_POST, + (whServerDmaFlags){0}); + } + if (msgAddr != NULL) { + /* Post-read processing of message buffer */ + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.msg.addr, &msgAddr, + req.msg.sz, WH_DMA_OPER_CLIENT_READ_POST, + (whServerDmaFlags){0}); + } + } - if (ret == WH_ERROR_ACCESS) { - res.dmaAddrStatus.badAddr = req.key; + /* Evict key if requested */ + if (evict) { + /* User requested to evict from cache, even if the call failed + */ + (void)wh_Server_KeystoreEvictKey(ctx, key_id); + } + } + wc_MlDsaKey_Free(key); } - /* Translate the response */ - (void)wh_MessageCrypto_TranslateMlDsaKeyGenDmaResponse( - magic, &res, (whMessageCrypto_MlDsaKeyGenDmaResponse*)cryptoDataOut); + if (ret == 0) { + /* Set response signature length */ + res.sigLen = sigLen; + *outSize = sizeof(res); - *outSize = sizeof(res); + /* Translate the response */ + (void)wh_MessageCrypto_TranslateMlDsaSignDmaResponse( + magic, &res, (whMessageCrypto_MlDsaSignDmaResponse*)cryptoDataOut); + } return ret; -#endif /* WOLFSSL_DILITHIUM_NO_MAKE_KEY */ +#endif /* WOLFSSL_DILITHIUM_NO_SIGN */ } -static int _HandleMlDsaSignDma(whServerContext* ctx, uint16_t magic, int devId, - const void* cryptoDataIn, uint16_t inSize, - void* cryptoDataOut, uint16_t* outSize) +static int _HandleMlDsaVerifyDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize) { -#ifdef WOLFSSL_DILITHIUM_NO_SIGN +#ifdef WOLFSSL_DILITHIUM_NO_VERIFY (void)ctx; (void)magic; (void)cryptoDataIn; @@ -5418,20 +5985,20 @@ static int _HandleMlDsaSignDma(whServerContext* ctx, uint16_t magic, int devId, #else int ret = 0; MlDsaKey key[1]; - void* msgAddr = NULL; - void* sigAddr = NULL; - word32 sigLen = 0; + void* msgAddr = NULL; + void* sigAddr = NULL; + int verified = 0; - whMessageCrypto_MlDsaSignDmaRequest req; - whMessageCrypto_MlDsaSignDmaResponse res; + whMessageCrypto_MlDsaVerifyDmaRequest req; + whMessageCrypto_MlDsaVerifyDmaResponse res; - if (inSize < sizeof(whMessageCrypto_MlDsaSignDmaRequest)) { + if (inSize < sizeof(whMessageCrypto_MlDsaVerifyDmaRequest)) { return WH_ERROR_BADARGS; } /* Translate the request */ - ret = wh_MessageCrypto_TranslateMlDsaSignDmaRequest( - magic, (whMessageCrypto_MlDsaSignDmaRequest*)cryptoDataIn, &req); + ret = wh_MessageCrypto_TranslateMlDsaVerifyDmaRequest( + magic, (whMessageCrypto_MlDsaVerifyDmaRequest*)cryptoDataIn, &req); if (ret != WH_ERROR_OK) { return ret; } @@ -5440,11 +6007,10 @@ static int _HandleMlDsaSignDma(whServerContext* ctx, uint16_t magic, int devId, whKeyId key_id; int evict = 0; - /* Get key ID and evict flag */ key_id = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, ctx->comm->client_id, req.keyId); - evict = !!(req.options & WH_MESSAGE_CRYPTO_MLDSA_SIGN_OPTIONS_EVICT); + evict = !!(req.options & WH_MESSAGE_CRYPTO_MLDSA_VERIFY_OPTIONS_EVICT); /* Extract context from inline data after the struct */ uint32_t contextSz = req.contextSz; @@ -5454,19 +6020,27 @@ static int _HandleMlDsaSignDma(whServerContext* ctx, uint16_t magic, int devId, return WH_ERROR_BADARGS; } if (contextSz > 0) { - if (inSize < sizeof(whMessageCrypto_MlDsaSignDmaRequest) + contextSz) { + if (inSize < sizeof(whMessageCrypto_MlDsaVerifyDmaRequest) + contextSz) { return WH_ERROR_BADARGS; } req_context = (uint8_t*)(cryptoDataIn) + - sizeof(whMessageCrypto_MlDsaSignDmaRequest); + sizeof(whMessageCrypto_MlDsaVerifyDmaRequest); } /* Initialize key */ ret = wc_MlDsaKey_Init(key, NULL, devId); + if (ret != 0) { + return ret; + } + + /* Export key from cache */ + ret = wh_Server_MlDsaKeyCacheExport(ctx, key_id, key); if (ret == 0) { - /* Export key from cache */ - /* TODO: sanity check security level against key pulled from cache? */ - ret = wh_Server_MlDsaKeyCacheExport(ctx, key_id, key); + /* Process client signature buffer address */ + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.sig.addr, &sigAddr, req.sig.sz, + WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + if (ret == 0) { /* Process client message buffer address */ ret = wh_Server_DmaProcessClientAddress( @@ -5474,205 +6048,364 @@ static int _HandleMlDsaSignDma(whServerContext* ctx, uint16_t magic, int devId, WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); if (ret == 0) { - /* Process client signature buffer address */ - ret = wh_Server_DmaProcessClientAddress( + /* Verify the signature using appropriate FIPS 204 API */ + if (preHashType != WC_HASH_TYPE_NONE) { + ret = wc_MlDsaKey_VerifyCtxHash( + key, sigAddr, req.sig.sz, req_context, (byte)contextSz, + msgAddr, req.msg.sz, preHashType, &verified); + } + else { + ret = wc_MlDsaKey_VerifyCtx( + key, sigAddr, req.sig.sz, req_context, (byte)contextSz, + msgAddr, req.msg.sz, &verified); + } + } + + if (sigAddr != NULL) { + /* Post-read processing of signature buffer */ + (void)wh_Server_DmaProcessClientAddress( ctx, (uintptr_t)req.sig.addr, &sigAddr, req.sig.sz, - WH_DMA_OPER_CLIENT_WRITE_PRE, (whServerDmaFlags){0}); + WH_DMA_OPER_CLIENT_READ_POST, (whServerDmaFlags){0}); + } - if (ret == 0) { - /* Sign the message using appropriate FIPS 204 API */ - sigLen = req.sig.sz; - if (preHashType != WC_HASH_TYPE_NONE) { - ret = wc_MlDsaKey_SignCtxHash( - key, req_context, (byte)contextSz, - sigAddr, &sigLen, msgAddr, req.msg.sz, - preHashType, ctx->crypto->rng); + if (msgAddr != NULL) { + /* Post-read processing of message buffer */ + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.msg.addr, &msgAddr, + req.msg.sz, WH_DMA_OPER_CLIENT_READ_POST, + (whServerDmaFlags){0}); + } + } + + /* Evict key if requested */ + if (evict) { + /* User requested to evict from cache, even if the call failed */ + (void)wh_Server_KeystoreEvictKey(ctx, key_id); + } + } + + if (ret == 0) { + /* Set verification result */ + res.verifyResult = verified; + + /* Translate the response */ + (void)wh_MessageCrypto_TranslateMlDsaVerifyDmaResponse( + magic, &res, + (whMessageCrypto_MlDsaVerifyDmaResponse*)cryptoDataOut); + + *outSize = sizeof(res); + } + + wc_MlDsaKey_Free(key); + return ret; +#endif /* WOLFSSL_DILITHIUM_NO_VERIFY */ +} + +static int _HandleMlDsaCheckPrivKeyDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize) +{ + (void)ctx; + (void)magic; + (void)devId; + (void)cryptoDataIn; + (void)inSize; + (void)cryptoDataOut; + (void)outSize; + return WH_ERROR_NOHANDLER; +} +#endif /* HAVE_DILITHIUM */ + +#if defined(HAVE_DILITHIUM) || defined(HAVE_FALCON) +static int _HandlePqcSigAlgorithmDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t cryptoInSize, void* cryptoDataOut, + uint16_t* cryptoOutSize, + uint32_t pkAlgoType, uint32_t pqAlgoType) +{ + int ret = WH_ERROR_NOHANDLER; + + /* Dispatch the appropriate algorithm handler based on the requested PK type + * and the algorithm type. */ + switch (pqAlgoType) { +#ifdef HAVE_DILITHIUM + case WC_PQC_SIG_TYPE_DILITHIUM: { + switch (pkAlgoType) { + case WC_PK_TYPE_PQC_SIG_KEYGEN: + ret = _HandleMlDsaKeyGenDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_SIG_SIGN: + ret = _HandleMlDsaSignDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_SIG_VERIFY: + ret = _HandleMlDsaVerifyDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_SIG_CHECK_PRIV_KEY: + ret = _HandleMlDsaCheckPrivKeyDma( + ctx, magic, devId, cryptoDataIn, cryptoInSize, + cryptoDataOut, cryptoOutSize); + break; + default: + ret = WH_ERROR_NOHANDLER; + break; + } + } break; +#endif /* HAVE_DILITHIUM */ + default: + ret = WH_ERROR_NOHANDLER; + break; + } + + return ret; +} +#endif /* HAVE_DILITHIUM || HAVE_FALCON */ + +#if defined(WOLFSSL_HAVE_MLKEM) +static int _HandleMlKemKeyGenDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize) +{ +#ifdef WOLFSSL_MLKEM_NO_MAKE_KEY + (void)ctx; + (void)magic; + (void)devId; + (void)cryptoDataIn; + (void)inSize; + (void)cryptoDataOut; + (void)outSize; + return WH_ERROR_NOHANDLER; +#else + int ret = WH_ERROR_OK; + MlKemKey key[1]; + void* clientOutAddr = NULL; + uint16_t keySize = 0; + whMessageCrypto_MlKemKeyGenDmaRequest req; + whMessageCrypto_MlKemKeyGenDmaResponse res; + + memset(&res, 0, sizeof(res)); + + if (inSize < sizeof(whMessageCrypto_MlKemKeyGenDmaRequest)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateMlKemKeyGenDmaRequest( + magic, (whMessageCrypto_MlKemKeyGenDmaRequest*)cryptoDataIn, &req); + if (ret != WH_ERROR_OK) { + return ret; + } + + if (!_IsMlKemLevelSupported((int)req.level)) { + ret = WH_ERROR_BADARGS; + } + else { + ret = wc_MlKemKey_Init(key, (int)req.level, NULL, devId); + if (ret == WH_ERROR_OK) { + ret = wc_MlKemKey_MakeKey(key, ctx->crypto->rng); + if (ret == WH_ERROR_OK) { + if ((req.flags & WH_NVM_FLAGS_EPHEMERAL) != 0) { + ret = wh_Server_DmaProcessClientAddress( + ctx, req.key.addr, &clientOutAddr, req.key.sz, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whServerDmaFlags){0}); + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_MlKemSerializeKey( + key, req.key.sz, (uint8_t*)clientOutAddr, &keySize); + if (ret == WH_ERROR_OK) { + res.keyId = WH_KEYID_ERASED; + res.keySize = keySize; + } } - else { - ret = wc_MlDsaKey_SignCtx( - key, req_context, (byte)contextSz, - sigAddr, &sigLen, msgAddr, req.msg.sz, - ctx->crypto->rng); + if (ret == WH_ERROR_OK) { + ret = wh_Server_DmaProcessClientAddress( + ctx, req.key.addr, &clientOutAddr, keySize, + WH_DMA_OPER_CLIENT_WRITE_POST, + (whServerDmaFlags){0}); } } + else { + whKeyId keyId = wh_KeyId_TranslateFromClient( + WH_KEYTYPE_CRYPTO, ctx->comm->client_id, req.keyId); - if (sigAddr != NULL) { - /* Post-write processing of signature buffer */ - (void)wh_Server_DmaProcessClientAddress( - ctx, (uintptr_t)req.sig.addr, &sigAddr, sigLen, - WH_DMA_OPER_CLIENT_WRITE_POST, - (whServerDmaFlags){0}); - } - if (msgAddr != NULL) { - /* Post-read processing of message buffer */ - (void)wh_Server_DmaProcessClientAddress( - ctx, (uintptr_t)req.msg.addr, &msgAddr, - req.msg.sz, WH_DMA_OPER_CLIENT_READ_POST, - (whServerDmaFlags){0}); - } - } - - /* Evict key if requested */ - if (evict) { - /* User requested to evict from cache, even if the call failed - */ - (void)wh_Server_KeystoreEvictKey(ctx, key_id); + if (WH_KEYID_ISERASED(keyId)) { + ret = wh_Server_KeystoreGetUniqueId(ctx, &keyId); + } + if (ret == WH_ERROR_OK) { + ret = wh_Server_MlKemKeyCacheImport( + ctx, key, keyId, req.flags, req.labelSize, + req.label); + if (ret == WH_ERROR_OK) { + res.keyId = wh_KeyId_TranslateToClient(keyId); + res.keySize = keySize; + } + } + } } + wc_MlKemKey_Free(key); } - wc_MlDsaKey_Free(key); } - if (ret == 0) { - /* Set response signature length */ - res.sigLen = sigLen; - *outSize = sizeof(res); - - /* Translate the response */ - (void)wh_MessageCrypto_TranslateMlDsaSignDmaResponse( - magic, &res, (whMessageCrypto_MlDsaSignDmaResponse*)cryptoDataOut); + if (ret == WH_ERROR_ACCESS) { + res.dmaAddrStatus.badAddr = req.key; } + (void)wh_MessageCrypto_TranslateMlKemKeyGenDmaResponse( + magic, &res, (whMessageCrypto_MlKemKeyGenDmaResponse*)cryptoDataOut); + *outSize = sizeof(res); return ret; -#endif /* WOLFSSL_DILITHIUM_NO_SIGN */ +#endif } -static int _HandleMlDsaVerifyDma(whServerContext* ctx, uint16_t magic, +static int _HandleMlKemEncapsDma(whServerContext* ctx, uint16_t magic, int devId, const void* cryptoDataIn, uint16_t inSize, void* cryptoDataOut, uint16_t* outSize) { -#ifdef WOLFSSL_DILITHIUM_NO_VERIFY +#ifdef WOLFSSL_MLKEM_NO_ENCAPSULATE (void)ctx; (void)magic; + (void)devId; (void)cryptoDataIn; (void)inSize; (void)cryptoDataOut; (void)outSize; return WH_ERROR_NOHANDLER; #else - int ret = 0; - MlDsaKey key[1]; - void* msgAddr = NULL; - void* sigAddr = NULL; - int verified = 0; + int ret = WH_ERROR_OK; + MlKemKey key[1]; + void* ctAddr = NULL; + word32 ctLen = 0; + word32 ssLen = 0; + whKeyId key_id; + int evict = 0; + int keyInited = 0; + uint8_t* res_ss; + word32 max_ss; + whMessageCrypto_MlKemEncapsDmaRequest req; + whMessageCrypto_MlKemEncapsDmaResponse res; - whMessageCrypto_MlDsaVerifyDmaRequest req; - whMessageCrypto_MlDsaVerifyDmaResponse res; + memset(&res, 0, sizeof(res)); - if (inSize < sizeof(whMessageCrypto_MlDsaVerifyDmaRequest)) { + if (inSize < sizeof(whMessageCrypto_MlKemEncapsDmaRequest)) { return WH_ERROR_BADARGS; } - /* Translate the request */ - ret = wh_MessageCrypto_TranslateMlDsaVerifyDmaRequest( - magic, (whMessageCrypto_MlDsaVerifyDmaRequest*)cryptoDataIn, &req); + ret = wh_MessageCrypto_TranslateMlKemEncapsDmaRequest( + magic, (whMessageCrypto_MlKemEncapsDmaRequest*)cryptoDataIn, &req); if (ret != WH_ERROR_OK) { return ret; } - /* Transaction state */ - whKeyId key_id; - int evict = 0; - - /* Get key ID and evict flag */ key_id = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, ctx->comm->client_id, req.keyId); - evict = !!(req.options & WH_MESSAGE_CRYPTO_MLDSA_VERIFY_OPTIONS_EVICT); + evict = !!(req.options & WH_MESSAGE_CRYPTO_MLKEM_ENCAPS_OPTIONS_EVICT); - /* Extract context from inline data after the struct */ - uint32_t contextSz = req.contextSz; - uint32_t preHashType = req.preHashType; - byte* req_context = NULL; - if (contextSz > WH_CRYPTO_MLDSA_MAX_CTX_LEN) { - return WH_ERROR_BADARGS; - } - if (contextSz > 0) { - if (inSize < sizeof(whMessageCrypto_MlDsaVerifyDmaRequest) + contextSz) { - return WH_ERROR_BADARGS; + if (!WH_KEYID_ISERASED(key_id)) { + ret = wh_Server_KeystoreFindEnforceKeyUsage(ctx, key_id, + WH_NVM_FLAGS_USAGE_DERIVE); + if (ret != WH_ERROR_OK) { + goto cleanup; } - req_context = (uint8_t*)(cryptoDataIn) + - sizeof(whMessageCrypto_MlDsaVerifyDmaRequest); } - /* Initialize key */ - ret = wc_MlDsaKey_Init(key, NULL, devId); - if (ret != 0) { - return ret; + if (!_IsMlKemLevelSupported((int)req.level)) { + ret = WH_ERROR_BADARGS; + goto cleanup; } - /* Export key from cache */ - ret = wh_Server_MlDsaKeyCacheExport(ctx, key_id, key); - if (ret == 0) { - /* Process client signature buffer address */ - ret = wh_Server_DmaProcessClientAddress( - ctx, (uintptr_t)req.sig.addr, &sigAddr, req.sig.sz, - WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + ret = wc_MlKemKey_Init(key, (int)req.level, NULL, devId); + if (ret == WH_ERROR_OK) { + keyInited = 1; + ret = wh_Server_MlKemKeyCacheExport(ctx, key_id, key); + } - if (ret == 0) { - /* Process client message buffer address */ - ret = wh_Server_DmaProcessClientAddress( - ctx, (uintptr_t)req.msg.addr, &msgAddr, req.msg.sz, - WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + /* Verify the exported key matches the requested level */ + if (ret == WH_ERROR_OK && key->type != (int)req.level) { + ret = WH_ERROR_BADARGS; + } - if (ret == 0) { - /* Verify the signature using appropriate FIPS 204 API */ - if (preHashType != WC_HASH_TYPE_NONE) { - ret = wc_MlDsaKey_VerifyCtxHash( - key, sigAddr, req.sig.sz, req_context, (byte)contextSz, - msgAddr, req.msg.sz, preHashType, &verified); - } - else { - ret = wc_MlDsaKey_VerifyCtx( - key, sigAddr, req.sig.sz, req_context, (byte)contextSz, - msgAddr, req.msg.sz, &verified); - } - } + if (ret == WH_ERROR_OK) { + ret = wc_MlKemKey_CipherTextSize(key, &ctLen); + } + if (ret == WH_ERROR_OK && ctLen > req.ct.sz) { + ret = WH_ERROR_BADARGS; + goto cleanup_key; + } + if (ret == WH_ERROR_OK) { + ret = wc_MlKemKey_SharedSecretSize(key, &ssLen); + } - if (sigAddr != NULL) { - /* Post-read processing of signature buffer */ - (void)wh_Server_DmaProcessClientAddress( - ctx, (uintptr_t)req.sig.addr, &sigAddr, req.sig.sz, - WH_DMA_OPER_CLIENT_READ_POST, (whServerDmaFlags){0}); - } + /* Validate that the inline shared secret fits in the comm buffer */ + if (ret == WH_ERROR_OK) { + res_ss = (uint8_t*)cryptoDataOut + + sizeof(whMessageCrypto_MlKemEncapsDmaResponse); + max_ss = (word32)(WOLFHSM_CFG_COMM_DATA_LEN - + ((uint8_t*)res_ss - (uint8_t*)cryptoDataOut)); + if (ssLen > max_ss) { + ret = WH_ERROR_BADARGS; + } + } - if (msgAddr != NULL) { - /* Post-read processing of message buffer */ - (void)wh_Server_DmaProcessClientAddress( - ctx, (uintptr_t)req.msg.addr, &msgAddr, - req.msg.sz, WH_DMA_OPER_CLIENT_READ_POST, - (whServerDmaFlags){0}); - } + if (ret == WH_ERROR_OK) { + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.ct.addr, &ctAddr, ctLen, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whServerDmaFlags){0}); + if (ret == WH_ERROR_ACCESS) { + res.dmaAddrStatus.badAddr = req.ct; } + } - /* Evict key if requested */ - if (evict) { - /* User requested to evict from cache, even if the call failed */ - (void)wh_Server_KeystoreEvictKey(ctx, key_id); + if (ret == WH_ERROR_OK) { + /* Shared secret goes inline in response, not via DMA */ + res_ss = (uint8_t*)cryptoDataOut + + sizeof(whMessageCrypto_MlKemEncapsDmaResponse); + ret = wc_MlKemKey_Encapsulate(key, (byte*)ctAddr, res_ss, + ctx->crypto->rng); + if (ret != WH_ERROR_OK) { + /* Zero sensitive data on failure */ + wc_ForceZero(res_ss, ssLen); } } - if (ret == 0) { - /* Set verification result */ - res.verifyResult = verified; + if (ctAddr != NULL) { + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.ct.addr, &ctAddr, ctLen, + WH_DMA_OPER_CLIENT_WRITE_POST, (whServerDmaFlags){0}); + } - /* Translate the response */ - (void)wh_MessageCrypto_TranslateMlDsaVerifyDmaResponse( - magic, &res, - (whMessageCrypto_MlDsaVerifyDmaResponse*)cryptoDataOut); + if (ret == WH_ERROR_OK) { + res.ctLen = ctLen; + res.ssLen = ssLen; + } - *outSize = sizeof(res); +cleanup_key: + if (keyInited) { + wc_MlKemKey_Free(key); + } +cleanup: + if (evict != 0) { + (void)wh_Server_KeystoreEvictKey(ctx, key_id); } - wc_MlDsaKey_Free(key); + (void)wh_MessageCrypto_TranslateMlKemEncapsDmaResponse( + magic, &res, (whMessageCrypto_MlKemEncapsDmaResponse*)cryptoDataOut); + *outSize = sizeof(res) + ssLen; return ret; -#endif /* WOLFSSL_DILITHIUM_NO_VERIFY */ +#endif } -static int _HandleMlDsaCheckPrivKeyDma(whServerContext* ctx, uint16_t magic, - int devId, const void* cryptoDataIn, - uint16_t inSize, void* cryptoDataOut, - uint16_t* outSize) +static int _HandleMlKemDecapsDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize) { +#ifdef WOLFSSL_MLKEM_NO_DECAPSULATE (void)ctx; (void)magic; (void)devId; @@ -5681,11 +6414,121 @@ static int _HandleMlDsaCheckPrivKeyDma(whServerContext* ctx, uint16_t magic, (void)cryptoDataOut; (void)outSize; return WH_ERROR_NOHANDLER; +#else + int ret = WH_ERROR_OK; + MlKemKey key[1]; + void* ctAddr = NULL; + word32 ssLen = 0; + whKeyId key_id; + int evict = 0; + int keyInited = 0; + uint8_t* res_ss; + word32 max_ss; + whMessageCrypto_MlKemDecapsDmaRequest req; + whMessageCrypto_MlKemDecapsDmaResponse res; + + memset(&res, 0, sizeof(res)); + + if (inSize < sizeof(whMessageCrypto_MlKemDecapsDmaRequest)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateMlKemDecapsDmaRequest( + magic, (whMessageCrypto_MlKemDecapsDmaRequest*)cryptoDataIn, &req); + if (ret != WH_ERROR_OK) { + return ret; + } + + key_id = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + ctx->comm->client_id, req.keyId); + evict = !!(req.options & WH_MESSAGE_CRYPTO_MLKEM_DECAPS_OPTIONS_EVICT); + + if (!WH_KEYID_ISERASED(key_id)) { + ret = wh_Server_KeystoreFindEnforceKeyUsage(ctx, key_id, + WH_NVM_FLAGS_USAGE_DERIVE); + if (ret != WH_ERROR_OK) { + goto cleanup; + } + } + + if (!_IsMlKemLevelSupported((int)req.level)) { + ret = WH_ERROR_BADARGS; + goto cleanup; + } + + ret = wc_MlKemKey_Init(key, (int)req.level, NULL, devId); + if (ret == WH_ERROR_OK) { + keyInited = 1; + ret = wh_Server_MlKemKeyCacheExport(ctx, key_id, key); + } + + /* Verify the exported key matches the requested level */ + if (ret == WH_ERROR_OK && key->type != (int)req.level) { + ret = WH_ERROR_BADARGS; + } + + if (ret == WH_ERROR_OK) { + ret = wc_MlKemKey_SharedSecretSize(key, &ssLen); + } + + /* Validate that the inline shared secret fits in the comm buffer */ + if (ret == WH_ERROR_OK) { + res_ss = (uint8_t*)cryptoDataOut + + sizeof(whMessageCrypto_MlKemDecapsDmaResponse); + max_ss = (word32)(WOLFHSM_CFG_COMM_DATA_LEN - + ((uint8_t*)res_ss - (uint8_t*)cryptoDataOut)); + if (ssLen > max_ss) { + ret = WH_ERROR_BADARGS; + } + } + + if (ret == WH_ERROR_OK) { + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.ct.addr, &ctAddr, req.ct.sz, + WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + if (ret == WH_ERROR_ACCESS) { + res.dmaAddrStatus.badAddr = req.ct; + } + } + + if (ret == WH_ERROR_OK) { + /* Shared secret goes inline in response, not via DMA */ + res_ss = (uint8_t*)cryptoDataOut + + sizeof(whMessageCrypto_MlKemDecapsDmaResponse); + ret = wc_MlKemKey_Decapsulate(key, res_ss, (const byte*)ctAddr, + (word32)req.ct.sz); + if (ret != WH_ERROR_OK) { + /* Zero sensitive data on failure */ + wc_ForceZero(res_ss, ssLen); + } + } + + if (ctAddr != NULL) { + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.ct.addr, &ctAddr, req.ct.sz, + WH_DMA_OPER_CLIENT_READ_POST, (whServerDmaFlags){0}); + } + + if (ret == WH_ERROR_OK) { + res.ssLen = ssLen; + } + + if (keyInited) { + wc_MlKemKey_Free(key); + } +cleanup: + if (evict != 0) { + (void)wh_Server_KeystoreEvictKey(ctx, key_id); + } + + (void)wh_MessageCrypto_TranslateMlKemDecapsDmaResponse( + magic, &res, (whMessageCrypto_MlKemDecapsDmaResponse*)cryptoDataOut); + *outSize = sizeof(res) + ssLen; + return ret; +#endif } -#endif /* HAVE_DILITHIUM */ -#if defined(HAVE_DILITHIUM) || defined(HAVE_FALCON) -static int _HandlePqcSigAlgorithmDma(whServerContext* ctx, uint16_t magic, +static int _HandlePqcKemAlgorithmDma(whServerContext* ctx, uint16_t magic, int devId, const void* cryptoDataIn, uint16_t cryptoInSize, void* cryptoDataOut, uint16_t* cryptoOutSize, @@ -5693,38 +6536,29 @@ static int _HandlePqcSigAlgorithmDma(whServerContext* ctx, uint16_t magic, { int ret = WH_ERROR_NOHANDLER; - /* Dispatch the appropriate algorithm handler based on the requested PK type - * and the algorithm type. */ switch (pqAlgoType) { -#ifdef HAVE_DILITHIUM - case WC_PQC_SIG_TYPE_DILITHIUM: { + case WC_PQC_KEM_TYPE_KYBER: { switch (pkAlgoType) { - case WC_PK_TYPE_PQC_SIG_KEYGEN: - ret = _HandleMlDsaKeyGenDma(ctx, magic, devId, cryptoDataIn, + case WC_PK_TYPE_PQC_KEM_KEYGEN: + ret = _HandleMlKemKeyGenDma(ctx, magic, devId, cryptoDataIn, cryptoInSize, cryptoDataOut, cryptoOutSize); break; - case WC_PK_TYPE_PQC_SIG_SIGN: - ret = _HandleMlDsaSignDma(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - cryptoOutSize); - break; - case WC_PK_TYPE_PQC_SIG_VERIFY: - ret = _HandleMlDsaVerifyDma(ctx, magic, devId, cryptoDataIn, + case WC_PK_TYPE_PQC_KEM_ENCAPS: + ret = _HandleMlKemEncapsDma(ctx, magic, devId, cryptoDataIn, cryptoInSize, cryptoDataOut, cryptoOutSize); break; - case WC_PK_TYPE_PQC_SIG_CHECK_PRIV_KEY: - ret = _HandleMlDsaCheckPrivKeyDma( - ctx, magic, devId, cryptoDataIn, cryptoInSize, - cryptoDataOut, cryptoOutSize); + case WC_PK_TYPE_PQC_KEM_DECAPS: + ret = _HandleMlKemDecapsDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); break; default: ret = WH_ERROR_NOHANDLER; break; } } break; -#endif /* HAVE_DILITHIUM */ default: ret = WH_ERROR_NOHANDLER; break; @@ -5732,7 +6566,7 @@ static int _HandlePqcSigAlgorithmDma(whServerContext* ctx, uint16_t magic, return ret; } -#endif /* HAVE_DILITHIUM || HAVE_FALCON */ +#endif /* WOLFSSL_HAVE_MLKEM */ #if defined(WOLFSSL_CMAC) && !defined(NO_AES) && defined(WOLFSSL_AES_DIRECT) static int _HandleCmacDma(whServerContext* ctx, uint16_t magic, int devId, @@ -6105,6 +6939,16 @@ int wh_Server_HandleCryptoDmaRequest(whServerContext* ctx, uint16_t magic, rqstHeader.algoSubType); break; #endif /* HAVE_DILITHIUM || HAVE_FALCON */ +#if defined(WOLFSSL_HAVE_MLKEM) + case WC_PK_TYPE_PQC_KEM_KEYGEN: + case WC_PK_TYPE_PQC_KEM_ENCAPS: + case WC_PK_TYPE_PQC_KEM_DECAPS: + ret = _HandlePqcKemAlgorithmDma( + ctx, magic, devId, cryptoDataIn, cryptoInSize, + cryptoDataOut, &cryptoOutSize, rqstHeader.algoType, + rqstHeader.algoSubType); + break; +#endif /* WOLFSSL_HAVE_MLKEM */ #ifdef HAVE_ED25519 case WC_PK_TYPE_ED25519_SIGN: ret = _HandleEd25519SignDma(ctx, magic, devId, cryptoDataIn, diff --git a/src/wh_server_keystore.c b/src/wh_server_keystore.c index e3724474e..ec7a6789a 100644 --- a/src/wh_server_keystore.c +++ b/src/wh_server_keystore.c @@ -1237,8 +1237,8 @@ static int _HandleKeyWrapRequest(whServerContext* server, } /* Translate the server key id passed in from the client */ - serverKeyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, - server->comm->client_id, + serverKeyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + server->comm->client_id, req->serverKeyId); /* Store the wrapped key in the response data */ @@ -1304,8 +1304,8 @@ static int _HandleKeyUnwrapAndExportRequest( wrappedKey = reqData; /* Translate the server key id passed in from the client */ - serverKeyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, - server->comm->client_id, + serverKeyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + server->comm->client_id, req->serverKeyId); /* Ensure the cipher type in the response matches the request */ @@ -1424,8 +1424,8 @@ static int _HandleKeyUnwrapAndCacheRequest( wrappedKey = reqData; /* Translate the server key id passed in from the client */ - serverKeyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, - server->comm->client_id, + serverKeyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + server->comm->client_id, req->serverKeyId); /* Ensure the cipher type in the response matches the request */ @@ -1535,8 +1535,8 @@ static int _HandleDataWrapRequest(whServerContext* server, memcpy(data, reqData, req->dataSz); /* Translate the server key id passed in from the client */ - serverKeyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, - server->comm->client_id, + serverKeyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + server->comm->client_id, req->serverKeyId); /* Ensure the cipher type in the response matches the request */ @@ -1806,7 +1806,7 @@ int wh_Server_HandleKeyRequest(whServerContext* server, uint16_t magic, } if (ret == WH_ERROR_OK) { - resp.len = req.key.sz; + resp.len = meta->len; memcpy(resp.label, meta->label, sizeof(meta->label)); } diff --git a/test/config/user_settings.h b/test/config/user_settings.h index e86389345..e6fdf0629 100644 --- a/test/config/user_settings.h +++ b/test/config/user_settings.h @@ -139,6 +139,10 @@ #define WOLFSSL_SHAKE128 #define WOLFSSL_SHAKE256 +/* ML-KEM Options */ +#define WOLFSSL_HAVE_MLKEM +#define WOLFSSL_WC_MLKEM + /* Ed25519 Options */ #define HAVE_ED25519 diff --git a/test/wh_test_crypto.c b/test/wh_test_crypto.c index d33fa192e..189f549de 100644 --- a/test/wh_test_crypto.c +++ b/test/wh_test_crypto.c @@ -32,6 +32,7 @@ #include "wolfssl/wolfcrypt/types.h" #include "wolfssl/wolfcrypt/kdf.h" #include "wolfssl/wolfcrypt/ed25519.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfhsm/wh_error.h" @@ -58,6 +59,7 @@ #endif #include "wolfhsm/wh_transport_mem.h" +#include "wolfhsm/wh_crypto.h" #include "wh_test_common.h" @@ -7849,6 +7851,712 @@ int whTestCrypto_MlDsaVerifyOnlyDma(whClientContext* ctx, int devId, #endif /* HAVE_DILITHIUM */ +#ifdef WOLFSSL_HAVE_MLKEM +#if !defined(WOLFSSL_MLKEM_NO_MAKE_KEY) && \ + !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) && \ + !defined(WOLFSSL_MLKEM_NO_DECAPSULATE) +static int whTestCrypto_MlKemGetLevels(int* levels, int maxLevels) +{ + int count = 0; + +#ifndef WOLFSSL_NO_ML_KEM_512 + if (count < maxLevels) { + levels[count++] = WC_ML_KEM_512; + } +#endif +#ifndef WOLFSSL_NO_ML_KEM_768 + if (count < maxLevels) { + levels[count++] = WC_ML_KEM_768; + } +#endif +#ifndef WOLFSSL_NO_ML_KEM_1024 + if (count < maxLevels) { + levels[count++] = WC_ML_KEM_1024; + } +#endif + + return count; +} + +static int whTestCrypto_MlKemWolfCrypt(whClientContext* ctx, int devId, + WC_RNG* rng) +{ + int ret = 0; + int levels[3]; + int levelCnt = 0; + int i; + byte ct[WC_ML_KEM_MAX_CIPHER_TEXT_SIZE]; + byte ssEnc[WC_ML_KEM_SS_SZ]; + byte ssDec[WC_ML_KEM_SS_SZ]; + word32 ctLen; + word32 ssEncLen; + word32 ssDecLen; + + (void)ctx; + + levelCnt = + whTestCrypto_MlKemGetLevels(levels, (int)(sizeof(levels) / sizeof(levels[0]))); + + for (i = 0; (ret == 0) && (i < levelCnt); i++) { + MlKemKey key[1]; + int keyInited = 0; + + ctLen = sizeof(ct); + ssEncLen = sizeof(ssEnc); + ssDecLen = sizeof(ssDec); + memset(ct, 0, sizeof(ct)); + memset(ssEnc, 0, sizeof(ssEnc)); + memset(ssDec, 0, sizeof(ssDec)); + + ret = wc_MlKemKey_Init(key, levels[i], NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT("Failed to init ML-KEM key level=%d ret=%d\n", + levels[i], ret); + break; + } + keyInited = 1; + + ret = wc_MlKemKey_MakeKey(key, rng); + if (ret != 0) { + WH_ERROR_PRINT("Failed to make ML-KEM key level=%d ret=%d\n", + levels[i], ret); + } + if (ret == 0) { + ret = wc_MlKemKey_CipherTextSize(key, &ctLen); + if (ret != 0) { + WH_ERROR_PRINT("Failed to get ML-KEM ct size level=%d ret=%d\n", + levels[i], ret); + } + } + if (ret == 0) { + ret = wc_MlKemKey_SharedSecretSize(key, &ssEncLen); + if (ret != 0) { + WH_ERROR_PRINT("Failed to get ML-KEM ss size level=%d ret=%d\n", + levels[i], ret); + } + else { + ssDecLen = ssEncLen; + } + } + if (ret == 0) { + ret = wc_MlKemKey_Encapsulate(key, ct, ssEnc, rng); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM encapsulate level=%d ret=%d\n", + levels[i], ret); + } + } + if (ret == 0) { + ret = wc_MlKemKey_Decapsulate(key, ssDec, ct, ctLen); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM decapsulate level=%d ret=%d\n", + levels[i], ret); + } + else if ((ssEncLen != ssDecLen) || + (memcmp(ssEnc, ssDec, ssEncLen) != 0)) { + WH_ERROR_PRINT("ML-KEM shared secret mismatch level=%d\n", + levels[i]); + ret = -1; + } + } + + if (keyInited) { + wc_MlKemKey_Free(key); + } + } + + if (ret == 0) { + WH_TEST_PRINT("ML-KEM DEVID=0x%X SUCCESS\n", devId); + } + + return ret; +} + +static int whTestCrypto_MlKemClient(whClientContext* ctx, int devId, WC_RNG* rng) +{ + int ret = 0; + int levels[3]; + int levelCnt = 0; + int i; + byte ct[WC_ML_KEM_MAX_CIPHER_TEXT_SIZE]; + byte ssEnc[WC_ML_KEM_SS_SZ]; + byte ssDec[WC_ML_KEM_SS_SZ]; + byte ssWrong[WC_ML_KEM_SS_SZ]; + byte usageCt[WC_ML_KEM_MAX_CIPHER_TEXT_SIZE]; + byte usageSs[WC_ML_KEM_SS_SZ]; + word32 ctLen; + word32 ssEncLen; + word32 ssDecLen; + word32 ssWrongLen; + word32 usageCtLen; + word32 usageSsLen; + const uint8_t usageLabel[] = "mlkem-no-derive"; + + (void)rng; + + levelCnt = + whTestCrypto_MlKemGetLevels(levels, (int)(sizeof(levels) / sizeof(levels[0]))); + + for (i = 0; (ret == 0) && (i < levelCnt); i++) { + MlKemKey key[1]; + MlKemKey wrongKey[1]; + MlKemKey usageKey[1]; + int keyInited = 0; + int wrongInited = 0; + int usageInited = 0; + whKeyId usageKeyId = WH_KEYID_ERASED; + int usageKeyCached = 0; + + ctLen = sizeof(ct); + ssEncLen = sizeof(ssEnc); + ssDecLen = sizeof(ssDec); + ssWrongLen = sizeof(ssWrong); + usageCtLen = sizeof(usageCt); + usageSsLen = sizeof(usageSs); + memset(ct, 0, sizeof(ct)); + memset(ssEnc, 0, sizeof(ssEnc)); + memset(ssDec, 0, sizeof(ssDec)); + memset(ssWrong, 0, sizeof(ssWrong)); + memset(usageCt, 0, sizeof(usageCt)); + memset(usageSs, 0, sizeof(usageSs)); + + ret = wc_MlKemKey_Init(key, levels[i], NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT("Failed to init ML-KEM client key level=%d ret=%d\n", + levels[i], ret); + break; + } + keyInited = 1; + + ret = wc_MlKemKey_Init(wrongKey, levels[i], NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT( + "Failed to init ML-KEM wrong key level=%d ret=%d\n", + levels[i], ret); + } + else { + wrongInited = 1; + } + + if (ret == 0) { + ret = wh_Client_MlKemMakeExportKey(ctx, levels[i], key); + if (ret != 0) { + WH_ERROR_PRINT( + "Failed ML-KEM make export key level=%d ret=%d\n", + levels[i], ret); + } + } + if (ret == 0) { + ret = wh_Client_MlKemMakeExportKey(ctx, levels[i], wrongKey); + if (ret != 0) { + WH_ERROR_PRINT( + "Failed ML-KEM make wrong export key level=%d ret=%d\n", + levels[i], ret); + } + } + + if (ret == 0) { + ret = wh_Client_MlKemEncapsulate(ctx, key, ct, &ctLen, ssEnc, + &ssEncLen); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM encapsulate level=%d ret=%d\n", + levels[i], ret); + } + } + if (ret == 0) { + ret = wh_Client_MlKemDecapsulate(ctx, key, ct, ctLen, ssDec, + &ssDecLen); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM decapsulate level=%d ret=%d\n", + levels[i], ret); + } + else if ((ssEncLen != ssDecLen) || + (memcmp(ssEnc, ssDec, ssEncLen) != 0)) { + WH_ERROR_PRINT("ML-KEM client shared secret mismatch level=%d\n", + levels[i]); + ret = -1; + } + } + if (ret == 0) { + ret = wh_Client_MlKemDecapsulate(ctx, wrongKey, ct, ctLen, ssWrong, + &ssWrongLen); + if (ret != 0) { + WH_ERROR_PRINT( + "Failed ML-KEM wrong-key decapsulate level=%d ret=%d\n", + levels[i], ret); + } + else if ((ssWrongLen == ssEncLen) && + (memcmp(ssWrong, ssEnc, ssEncLen) == 0)) { + WH_ERROR_PRINT( + "ML-KEM wrong-key decaps unexpectedly matched level=%d\n", + levels[i]); + ret = -1; + } + } + + if (ret == 0) { + ret = wh_Client_MlKemMakeCacheKey( + ctx, levels[i], &usageKeyId, WH_NVM_FLAGS_NONE, + (uint16_t)strlen((const char*)usageLabel), (uint8_t*)usageLabel); + if (ret != 0) { + WH_ERROR_PRINT( + "Failed ML-KEM cache key without derive level=%d ret=%d\n", + levels[i], ret); + } + else { + usageKeyCached = 1; + } + } + if (ret == 0) { + ret = wc_MlKemKey_Init(usageKey, levels[i], NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT("Failed init ML-KEM usage key level=%d ret=%d\n", + levels[i], ret); + } + else { + usageInited = 1; + } + } + if (ret == 0) { + ret = wh_Client_MlKemSetKeyId(usageKey, usageKeyId); + if (ret != 0) { + WH_ERROR_PRINT( + "Failed set ML-KEM usage key ID level=%d ret=%d\n", + levels[i], ret); + } + } + if (ret == 0) { + ret = wh_Client_MlKemEncapsulate(ctx, usageKey, usageCt, &usageCtLen, + usageSs, &usageSsLen); + if (ret == WH_ERROR_USAGE) { + ret = 0; + } + else { + WH_ERROR_PRINT("Expected WH_ERROR_USAGE for ML-KEM derive " + "policy encaps level=%d got=%d\n", + levels[i], ret); + ret = WH_ERROR_ABORTED; + } + } + /* Negative test: decapsulate with key lacking derive usage */ + if (ret == 0) { + byte dummyCt[WC_ML_KEM_MAX_CIPHER_TEXT_SIZE] = {0}; + word32 dummyCtLen = sizeof(dummyCt); + ret = wh_Client_MlKemDecapsulate(ctx, usageKey, dummyCt, + dummyCtLen, usageSs, + &usageSsLen); + if (ret == WH_ERROR_USAGE) { + ret = 0; + } + else { + WH_ERROR_PRINT("Expected WH_ERROR_USAGE for ML-KEM derive " + "policy decaps level=%d got=%d\n", + levels[i], ret); + ret = WH_ERROR_ABORTED; + } + } + + if (usageKeyCached) { + int evictRet = wh_Client_KeyEvict(ctx, usageKeyId); + if ((evictRet != 0) && (ret == 0)) { + WH_ERROR_PRINT("Failed ML-KEM usage key evict level=%d ret=%d\n", + levels[i], evictRet); + ret = evictRet; + } + usageKeyCached = 0; + } + if (usageInited) { + wc_MlKemKey_Free(usageKey); + usageInited = 0; + } + + /* Positive test: cached key WITH derive usage should succeed */ + if (ret == 0) { + const uint8_t deriveLabel[] = "mlkem-derive-ok"; + byte deriveCt[WC_ML_KEM_MAX_CIPHER_TEXT_SIZE]; + byte deriveSsEnc[WC_ML_KEM_SS_SZ]; + byte deriveSsDec[WC_ML_KEM_SS_SZ]; + word32 deriveCtLen = sizeof(deriveCt); + word32 deriveSsEncLen = sizeof(deriveSsEnc); + word32 deriveSsDecLen = sizeof(deriveSsDec); + + ret = wh_Client_MlKemMakeCacheKey( + ctx, levels[i], &usageKeyId, WH_NVM_FLAGS_USAGE_DERIVE, + (uint16_t)strlen((const char*)deriveLabel), + (uint8_t*)deriveLabel); + if (ret != 0) { + WH_ERROR_PRINT( + "Failed ML-KEM cache key with derive level=%d ret=%d\n", + levels[i], ret); + } + else { + usageKeyCached = 1; + } + if (ret == 0) { + ret = wc_MlKemKey_Init(usageKey, levels[i], NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT( + "Failed init ML-KEM derive key level=%d ret=%d\n", + levels[i], ret); + } + else { + usageInited = 1; + } + } + if (ret == 0) { + ret = wh_Client_MlKemSetKeyId(usageKey, usageKeyId); + } + if (ret == 0) { + ret = wh_Client_MlKemEncapsulate(ctx, usageKey, deriveCt, + &deriveCtLen, deriveSsEnc, + &deriveSsEncLen); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM encaps with derive key " + "level=%d ret=%d\n", + levels[i], ret); + } + } + if (ret == 0) { + ret = wh_Client_MlKemDecapsulate(ctx, usageKey, deriveCt, + deriveCtLen, deriveSsDec, + &deriveSsDecLen); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM decaps with derive key " + "level=%d ret=%d\n", + levels[i], ret); + } + else if ((deriveSsEncLen != deriveSsDecLen) || + (memcmp(deriveSsEnc, deriveSsDec, + deriveSsEncLen) != 0)) { + WH_ERROR_PRINT("ML-KEM derive key shared secret mismatch " + "level=%d\n", + levels[i]); + ret = -1; + } + } + if (usageKeyCached) { + int evictRet = wh_Client_KeyEvict(ctx, usageKeyId); + if ((evictRet != 0) && (ret == 0)) { + WH_ERROR_PRINT("Failed ML-KEM derive key evict level=%d " + "ret=%d\n", + levels[i], evictRet); + ret = evictRet; + } + } + if (usageInited) { + wc_MlKemKey_Free(usageKey); + } + } + + if (wrongInited) { + wc_MlKemKey_Free(wrongKey); + } + if (keyInited) { + wc_MlKemKey_Free(key); + } + } + + if (ret == 0) { + WH_TEST_PRINT("ML-KEM Client Non-DMA API SUCCESS\n"); + } + + return ret; +} + +#ifdef WOLFHSM_CFG_DMA +static int whTestCrypto_MlKemDmaClient(whClientContext* ctx, int devId, + WC_RNG* rng) +{ + int ret = 0; + int levels[3]; + int levelCnt = 0; + int i; + byte ct[WC_ML_KEM_MAX_CIPHER_TEXT_SIZE]; + byte ssEnc[WC_ML_KEM_SS_SZ]; + byte ssDec[WC_ML_KEM_SS_SZ]; + byte ssWrong[WC_ML_KEM_SS_SZ]; + byte keyBuf1[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE]; + byte keyBuf2[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE]; + word32 ctLen; + word32 ssEncLen; + word32 ssDecLen; + word32 ssWrongLen; + uint16_t keyBuf1Len; + uint16_t keyBuf2Len; + whKeyId keyId; + const uint8_t cacheLabel[] = "mlkem-dma-cache"; + + (void)rng; + + levelCnt = + whTestCrypto_MlKemGetLevels(levels, (int)(sizeof(levels) / sizeof(levels[0]))); + + for (i = 0; (ret == 0) && (i < levelCnt); i++) { + MlKemKey key[1]; + MlKemKey importedKey[1]; + MlKemKey wrongKey[1]; + int keyInited = 0; + int importedInited = 0; + int wrongInited = 0; + int keyCached = 0; + + ctLen = sizeof(ct); + ssEncLen = sizeof(ssEnc); + ssDecLen = sizeof(ssDec); + ssWrongLen = sizeof(ssWrong); + keyBuf1Len = sizeof(keyBuf1); + keyBuf2Len = sizeof(keyBuf2); + keyId = WH_KEYID_ERASED; + + memset(ct, 0, sizeof(ct)); + memset(ssEnc, 0, sizeof(ssEnc)); + memset(ssDec, 0, sizeof(ssDec)); + memset(ssWrong, 0, sizeof(ssWrong)); + memset(keyBuf1, 0, sizeof(keyBuf1)); + memset(keyBuf2, 0, sizeof(keyBuf2)); + + ret = wc_MlKemKey_Init(key, levels[i], NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT("Failed init ML-KEM DMA key level=%d ret=%d\n", + levels[i], ret); + break; + } + keyInited = 1; + + ret = wc_MlKemKey_Init(importedKey, levels[i], NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT("Failed init ML-KEM DMA imported key level=%d " + "ret=%d\n", + levels[i], ret); + } + else { + importedInited = 1; + } + + if (ret == 0) { + ret = wc_MlKemKey_Init(wrongKey, levels[i], NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT("Failed init ML-KEM DMA wrong key level=%d " + "ret=%d\n", + levels[i], ret); + } + else { + wrongInited = 1; + } + } + + if (ret == 0) { + ret = wh_Client_MlKemMakeExportKeyDma(ctx, levels[i], key); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM DMA keygen level=%d ret=%d\n", + levels[i], ret); + } + } + if (ret == 0) { + ret = wh_Client_MlKemMakeExportKeyDma(ctx, levels[i], wrongKey); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM DMA wrong keygen level=%d ret=%d\n", + levels[i], ret); + } + } + + if (ret == 0) { + ret = wh_Crypto_MlKemSerializeKey(key, keyBuf1Len, keyBuf1, + &keyBuf1Len); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM DMA serialize key level=%d " + "ret=%d\n", + levels[i], ret); + } + } + if (ret == 0) { + ret = wh_Client_MlKemImportKeyDma( + ctx, key, &keyId, WH_NVM_FLAGS_NONE, + (uint16_t)strlen((const char*)cacheLabel), (uint8_t*)cacheLabel); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM DMA import key level=%d ret=%d\n", + levels[i], ret); + } + else { + keyCached = 1; + } + } + if (ret == 0) { + ret = wh_Client_MlKemExportKeyDma( + ctx, keyId, importedKey, + (uint16_t)strlen((const char*)cacheLabel), (uint8_t*)cacheLabel); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM DMA export key level=%d ret=%d\n", + levels[i], ret); + } + } + if (ret == 0) { + ret = wh_Crypto_MlKemSerializeKey(importedKey, keyBuf2Len, keyBuf2, + &keyBuf2Len); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM DMA serialize imported key " + "level=%d ret=%d\n", + levels[i], ret); + } + else if ((keyBuf1Len != keyBuf2Len) || + (memcmp(keyBuf1, keyBuf2, keyBuf1Len) != 0)) { + WH_ERROR_PRINT("ML-KEM DMA imported key mismatch level=%d\n", + levels[i]); + ret = -1; + } + } + + if (ret == 0) { + ret = wh_Client_MlKemEncapsulateDma(ctx, key, ct, &ctLen, ssEnc, + &ssEncLen); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM DMA encapsulate level=%d ret=%d\n", + levels[i], ret); + } + } + if (ret == 0) { + ret = wh_Client_MlKemDecapsulateDma(ctx, key, ct, ctLen, ssDec, + &ssDecLen); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM DMA decapsulate level=%d ret=%d\n", + levels[i], ret); + } + else if ((ssEncLen != ssDecLen) || + (memcmp(ssEnc, ssDec, ssEncLen) != 0)) { + WH_ERROR_PRINT("ML-KEM DMA shared secret mismatch level=%d\n", + levels[i]); + ret = -1; + } + } + if (ret == 0) { + ret = wh_Client_MlKemDecapsulateDma(ctx, wrongKey, ct, ctLen, + ssWrong, &ssWrongLen); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM DMA wrong-key decaps level=%d " + "ret=%d\n", + levels[i], ret); + } + else if ((ssWrongLen == ssEncLen) && + (memcmp(ssWrong, ssEnc, ssEncLen) == 0)) { + WH_ERROR_PRINT("ML-KEM DMA wrong-key decaps unexpectedly " + "matched level=%d\n", + levels[i]); + ret = -1; + } + } + + /* Usage policy enforcement: key without derive should be denied */ + if (ret == 0) { + MlKemKey usageKey[1]; + whKeyId usageKeyId = WH_KEYID_ERASED; + int usageInited = 0; + int usageKeyCached = 0; + const uint8_t usageLabel[] = "mlkem-dma-nouse"; + + ret = wh_Client_MlKemMakeCacheKey( + ctx, levels[i], &usageKeyId, WH_NVM_FLAGS_NONE, + (uint16_t)strlen((const char*)usageLabel), + (uint8_t*)usageLabel); + if (ret != 0) { + WH_ERROR_PRINT("Failed ML-KEM DMA cache key without derive " + "level=%d ret=%d\n", + levels[i], ret); + } + else { + usageKeyCached = 1; + } + if (ret == 0) { + ret = wc_MlKemKey_Init(usageKey, levels[i], NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT("Failed init ML-KEM DMA usage key " + "level=%d ret=%d\n", + levels[i], ret); + } + else { + usageInited = 1; + } + } + if (ret == 0) { + ret = wh_Client_MlKemSetKeyId(usageKey, usageKeyId); + } + if (ret == 0) { + word32 tmpCtLen = sizeof(ct); + word32 tmpSsLen = sizeof(ssEnc); + ret = wh_Client_MlKemEncapsulateDma(ctx, usageKey, ct, + &tmpCtLen, ssEnc, + &tmpSsLen); + if (ret == WH_ERROR_USAGE) { + ret = 0; /* Expected */ + } + else { + WH_ERROR_PRINT("Expected WH_ERROR_USAGE for ML-KEM DMA " + "derive policy encaps level=%d got=%d\n", + levels[i], ret); + ret = WH_ERROR_ABORTED; + } + } + /* Negative test: DMA decapsulate with key lacking derive usage */ + if (ret == 0) { + byte dummyCt[WC_ML_KEM_MAX_CIPHER_TEXT_SIZE] = {0}; + word32 dummySsLen = sizeof(ssEnc); + ret = wh_Client_MlKemDecapsulateDma( + ctx, usageKey, dummyCt, + sizeof(dummyCt), ssEnc, &dummySsLen); + if (ret == WH_ERROR_USAGE) { + ret = 0; /* Expected */ + } + else { + WH_ERROR_PRINT("Expected WH_ERROR_USAGE for ML-KEM DMA " + "derive policy decaps level=%d got=%d\n", + levels[i], ret); + ret = WH_ERROR_ABORTED; + } + } + if (usageKeyCached) { + int evictRet = wh_Client_KeyEvict(ctx, usageKeyId); + if ((evictRet != 0) && (ret == 0)) { + WH_ERROR_PRINT("Failed ML-KEM DMA usage key evict " + "level=%d ret=%d\n", + levels[i], evictRet); + ret = evictRet; + } + } + if (usageInited) { + wc_MlKemKey_Free(usageKey); + } + } + + if (keyCached) { + int evictRet = wh_Client_KeyEvict(ctx, keyId); + if ((evictRet != 0) && (ret == 0)) { + WH_ERROR_PRINT("Failed ML-KEM DMA evict cached key level=%d " + "ret=%d\n", + levels[i], evictRet); + ret = evictRet; + } + } + if (wrongInited) { + wc_MlKemKey_Free(wrongKey); + } + if (importedInited) { + wc_MlKemKey_Free(importedKey); + } + if (keyInited) { + wc_MlKemKey_Free(key); + } + } + + if (ret == 0) { + WH_TEST_PRINT("ML-KEM Client DMA API SUCCESS\n"); + } + + return ret; +} +#endif /* WOLFHSM_CFG_DMA */ +#endif /* !defined(WOLFSSL_MLKEM_NO_MAKE_KEY) && \ + !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) && \ + !defined(WOLFSSL_MLKEM_NO_DECAPSULATE) */ +#endif /* WOLFSSL_HAVE_MLKEM */ + /* Test key usage policy enforcement for various crypto operations */ int whTest_CryptoKeyUsagePolicies(whClientContext* client, WC_RNG* rng) { @@ -8825,6 +9533,38 @@ int whTest_CryptoClientConfig(whClientConfig* config) #endif /* HAVE_DILITHIUM */ +#ifdef WOLFSSL_HAVE_MLKEM +#if !defined(WOLFSSL_MLKEM_NO_MAKE_KEY) && \ + !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) && \ + !defined(WOLFSSL_MLKEM_NO_DECAPSULATE) + i = 0; + while (ret == WH_ERROR_OK && i < WH_NUM_DEVIDS) { +#ifdef WOLFHSM_CFG_TEST_CLIENT_LARGE_DATA_DMA_ONLY + if (WH_DEV_IDS_ARRAY[i] != WH_DEV_ID_DMA) { + i++; + continue; + } +#endif /* WOLFHSM_CFG_TEST_CLIENT_LARGE_DATA_DMA_ONLY */ + ret = whTestCrypto_MlKemWolfCrypt(client, WH_DEV_IDS_ARRAY[i], rng); + if (ret == WH_ERROR_OK) { + i++; + } + } + + if (ret == 0) { + ret = whTestCrypto_MlKemClient(client, WH_DEV_ID, rng); + } + +#ifdef WOLFHSM_CFG_DMA + if (ret == 0) { + ret = whTestCrypto_MlKemDmaClient(client, WH_DEV_ID_DMA, rng); + } +#endif /* WOLFHSM_CFG_DMA */ +#endif /* !defined(WOLFSSL_MLKEM_NO_MAKE_KEY) && \ + !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) && \ + !defined(WOLFSSL_MLKEM_NO_DECAPSULATE) */ +#endif /* WOLFSSL_HAVE_MLKEM */ + #ifdef WOLFHSM_CFG_DEBUG_VERBOSE if (ret == 0) { (void)whTest_ShowNvmAvailable(client); diff --git a/wolfhsm/wh_client_crypto.h b/wolfhsm/wh_client_crypto.h index 044fb8c05..4b069a694 100644 --- a/wolfhsm/wh_client_crypto.h +++ b/wolfhsm/wh_client_crypto.h @@ -51,6 +51,7 @@ #include "wolfssl/wolfcrypt/ecc.h" #include "wolfssl/wolfcrypt/ed25519.h" #include "wolfssl/wolfcrypt/dilithium.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfssl/wolfcrypt/hmac.h" /** @@ -2040,5 +2041,204 @@ int wh_Client_MlDsaCheckPrivKeyDma(whClientContext* ctx, MlDsaKey* key, #endif /* HAVE_DILITHIUM */ +#ifdef WOLFSSL_HAVE_MLKEM + +/** + * @brief Associate a ML-KEM key with a specific key ID. + * + * Sets the device context of a ML-KEM key to the specified key ID. On the + * server side, this key ID is used to reference the key stored in the HSM. + * + * @param[in] key Pointer to the ML-KEM key structure. + * @param[in] keyId Key ID to be associated with the ML-KEM key. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemSetKeyId(MlKemKey* key, whKeyId keyId); + +/** + * @brief Retrieve the key ID associated with a ML-KEM key. + * + * @param[in] key Pointer to the ML-KEM key structure. + * @param[out] outId Pointer to store the retrieved key ID. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemGetKeyId(MlKemKey* key, whKeyId* outId); + +/** + * @brief Import a ML-KEM key to the server key cache. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key to import. + * @param[in,out] inout_keyId Pointer to key ID to use/receive. + * @param[in] flags Flags to control key persistence. + * @param[in] label_len Length of optional label in bytes. + * @param[in] label Optional label to associate with the key. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemImportKey(whClientContext* ctx, MlKemKey* key, + whKeyId* inout_keyId, whNvmFlags flags, + uint16_t label_len, uint8_t* label); + +/** + * @brief Export a ML-KEM key from the server key cache. + * + * @param[in] ctx Pointer to the client context. + * @param[in] keyId Key ID of the key to export. + * @param[out] key Pointer to the ML-KEM key structure to populate. + * @param[in] label_len Length of optional label in bytes. + * @param[out] label Optional label buffer to receive the key label. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemExportKey(whClientContext* ctx, whKeyId keyId, MlKemKey* key, + uint16_t label_len, uint8_t* label); + +/** + * @brief Generate a ML-KEM key pair and return it as an ephemeral key. + * + * The key pair is generated on the server, serialized, and returned to the + * client without being cached. + * + * @param[in] ctx Pointer to the client context. + * @param[in] level ML-KEM security level (WC_ML_KEM_512/768/1024). + * @param[out] key Pointer to the ML-KEM key to populate with the generated key. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemMakeExportKey(whClientContext* ctx, int level, + MlKemKey* key); + +/** + * @brief Generate a ML-KEM key pair and cache it on the server. + * + * @param[in] ctx Pointer to the client context. + * @param[in] level ML-KEM security level (WC_ML_KEM_512/768/1024). + * @param[in,out] inout_key_id Pointer to key ID to use/receive. + * @param[in] flags Flags to control key persistence and usage. + * @param[in] label_len Length of optional label in bytes. + * @param[in] label Optional label to associate with the key. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemMakeCacheKey(whClientContext* ctx, int level, + whKeyId* inout_key_id, whNvmFlags flags, + uint16_t label_len, uint8_t* label); + +/** + * @brief Perform ML-KEM encapsulation using a server-cached public key. + * + * Generates a shared secret and ciphertext using the public key identified by + * the key ID stored in the provided MlKemKey. If the key is not yet cached, + * it will be auto-imported and evicted after use. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key (must have key ID set). + * @param[out] ct Buffer to receive the ciphertext. + * @param[in,out] inout_ct_len On input, size of ct buffer; on output, actual + * ciphertext length. + * @param[out] ss Buffer to receive the shared secret. + * @param[in,out] inout_ss_len On input, size of ss buffer; on output, actual + * shared secret length. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemEncapsulate(whClientContext* ctx, MlKemKey* key, + byte* ct, word32* inout_ct_len, + byte* ss, word32* inout_ss_len); + +/** + * @brief Perform ML-KEM decapsulation using a server-cached private key. + * + * Recovers the shared secret from the ciphertext using the private key + * identified by the key ID stored in the provided MlKemKey. If the key is not + * yet cached, it will be auto-imported and evicted after use. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key (must have key ID set). + * @param[in] ct Pointer to the ciphertext. + * @param[in] ct_len Length of the ciphertext in bytes. + * @param[out] ss Buffer to receive the shared secret. + * @param[in,out] inout_ss_len On input, size of ss buffer; on output, actual + * shared secret length. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemDecapsulate(whClientContext* ctx, MlKemKey* key, + const byte* ct, word32 ct_len, byte* ss, + word32* inout_ss_len); + +#ifdef WOLFHSM_CFG_DMA + +/** + * @brief Import a ML-KEM key using DMA. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key to import. + * @param[in,out] inout_keyId Pointer to store/provide the key ID. + * @param[in] flags NVM flags for key storage. + * @param[in] label_len Length of the key label in bytes. + * @param[in] label Pointer to the key label. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemImportKeyDma(whClientContext* ctx, MlKemKey* key, + whKeyId* inout_keyId, whNvmFlags flags, + uint16_t label_len, uint8_t* label); + +/** + * @brief Export a ML-KEM key from the server using DMA. + * + * @param[in] ctx Pointer to the client context. + * @param[in] keyId Key ID of the key to export. + * @param[out] key Pointer to the ML-KEM key structure to populate. + * @param[in] label_len Length of the key label in bytes. + * @param[out] label Pointer to the key label buffer. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemExportKeyDma(whClientContext* ctx, whKeyId keyId, + MlKemKey* key, uint16_t label_len, + uint8_t* label); + +/** + * @brief Generate an ephemeral ML-KEM key pair using DMA. + * + * @param[in] ctx Pointer to the client context. + * @param[in] level ML-KEM security level (WC_ML_KEM_512/768/1024). + * @param[out] key Pointer to the ML-KEM key to populate. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemMakeExportKeyDma(whClientContext* ctx, int level, + MlKemKey* key); + +/** + * @brief Perform ML-KEM encapsulation using DMA. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key (must have key ID set). + * @param[out] ct Buffer to receive the ciphertext. + * @param[in,out] inout_ct_len On input, size of ct buffer; on output, actual + * ciphertext length. + * @param[out] ss Buffer to receive the shared secret. + * @param[in,out] inout_ss_len On input, size of ss buffer; on output, actual + * shared secret length. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemEncapsulateDma(whClientContext* ctx, MlKemKey* key, + byte* ct, word32* inout_ct_len, byte* ss, + word32* inout_ss_len); + +/** + * @brief Perform ML-KEM decapsulation using DMA. + * + * @param[in] ctx Pointer to the client context. + * @param[in] key Pointer to the ML-KEM key (must have key ID set). + * @param[in] ct Pointer to the ciphertext. + * @param[in] ct_len Length of the ciphertext in bytes. + * @param[out] ss Buffer to receive the shared secret. + * @param[in,out] inout_ss_len On input, size of ss buffer; on output, actual + * shared secret length. + * @return int Returns 0 on success or a negative error code on failure. + */ +int wh_Client_MlKemDecapsulateDma(whClientContext* ctx, MlKemKey* key, + const byte* ct, word32 ct_len, byte* ss, + word32* inout_ss_len); +#endif /* WOLFHSM_CFG_DMA */ + +#endif /* WOLFSSL_HAVE_MLKEM */ + #endif /* !WOLFHSM_CFG_NO_CRYPTO */ #endif /* !WOLFHSM_WH_CLIENT_CRYPTO_H_ */ diff --git a/wolfhsm/wh_crypto.h b/wolfhsm/wh_crypto.h index 7254d783c..6f5cfc9c3 100644 --- a/wolfhsm/wh_crypto.h +++ b/wolfhsm/wh_crypto.h @@ -43,6 +43,7 @@ #include "wolfssl/wolfcrypt/ecc.h" #include "wolfssl/wolfcrypt/ed25519.h" #include "wolfssl/wolfcrypt/dilithium.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfhsm/wh_message_crypto.h" @@ -118,6 +119,16 @@ int wh_Crypto_MlDsaDeserializeKeyDer(const uint8_t* buffer, uint16_t size, MlDsaKey* key); #endif /* HAVE_DILITHIUM */ +#ifdef WOLFSSL_HAVE_MLKEM +/* Store a MlKemKey to a byte sequence */ +int wh_Crypto_MlKemSerializeKey(MlKemKey* key, uint16_t max_size, + uint8_t* buffer, uint16_t* out_size); +/* Restore a MlKemKey from a byte sequence. Tries the level already set in the + * key first, then probes other supported ML-KEM levels if needed. */ +int wh_Crypto_MlKemDeserializeKey(const uint8_t* buffer, uint16_t size, + MlKemKey* key); +#endif /* WOLFSSL_HAVE_MLKEM */ + #endif /* !WOLFHSM_CFG_NO_CRYPTO */ #endif /* WOLFHSM_WH_CRYPTO_H_ */ diff --git a/wolfhsm/wh_message_crypto.h b/wolfhsm/wh_message_crypto.h index ac3417743..3305a5fd5 100644 --- a/wolfhsm/wh_message_crypto.h +++ b/wolfhsm/wh_message_crypto.h @@ -1012,6 +1012,92 @@ int wh_MessageCrypto_TranslateMlDsaVerifyResponse( uint16_t magic, const whMessageCrypto_MlDsaVerifyResponse* src, whMessageCrypto_MlDsaVerifyResponse* dest); +/* + * ML-KEM + */ + +/* ML-KEM Key Generation Request */ +typedef struct { + uint32_t level; + uint32_t keyId; + uint32_t flags; + uint32_t access; + uint8_t label[WH_NVM_LABEL_LEN]; +} whMessageCrypto_MlKemKeyGenRequest; + +/* ML-KEM Key Generation Response */ +typedef struct { + uint32_t keyId; + uint32_t len; + /* Data follows: + * uint8_t out[len]; + */ +} whMessageCrypto_MlKemKeyGenResponse; + +int wh_MessageCrypto_TranslateMlKemKeyGenRequest( + uint16_t magic, const whMessageCrypto_MlKemKeyGenRequest* src, + whMessageCrypto_MlKemKeyGenRequest* dest); + +int wh_MessageCrypto_TranslateMlKemKeyGenResponse( + uint16_t magic, const whMessageCrypto_MlKemKeyGenResponse* src, + whMessageCrypto_MlKemKeyGenResponse* dest); + +/* ML-KEM Encapsulation Request */ +typedef struct { + uint32_t options; +#define WH_MESSAGE_CRYPTO_MLKEM_ENCAPS_OPTIONS_EVICT (1 << 0) + uint32_t level; + uint32_t keyId; + uint8_t WH_PAD[4]; +} whMessageCrypto_MlKemEncapsRequest; + +/* ML-KEM Encapsulation Response */ +typedef struct { + uint32_t ctSz; + uint32_t ssSz; + /* Data follows: + * uint8_t ct[ctSz]; + * uint8_t ss[ssSz]; + */ +} whMessageCrypto_MlKemEncapsResponse; + +int wh_MessageCrypto_TranslateMlKemEncapsRequest( + uint16_t magic, const whMessageCrypto_MlKemEncapsRequest* src, + whMessageCrypto_MlKemEncapsRequest* dest); + +int wh_MessageCrypto_TranslateMlKemEncapsResponse( + uint16_t magic, const whMessageCrypto_MlKemEncapsResponse* src, + whMessageCrypto_MlKemEncapsResponse* dest); + +/* ML-KEM Decapsulation Request */ +typedef struct { + uint32_t options; +#define WH_MESSAGE_CRYPTO_MLKEM_DECAPS_OPTIONS_EVICT (1 << 0) + uint32_t level; + uint32_t keyId; + uint32_t ctSz; + /* Data follows: + * uint8_t ct[ctSz]; + */ +} whMessageCrypto_MlKemDecapsRequest; + +/* ML-KEM Decapsulation Response */ +typedef struct { + uint32_t ssSz; + uint8_t WH_PAD[4]; + /* Data follows: + * uint8_t ss[ssSz]; + */ +} whMessageCrypto_MlKemDecapsResponse; + +int wh_MessageCrypto_TranslateMlKemDecapsRequest( + uint16_t magic, const whMessageCrypto_MlKemDecapsRequest* src, + whMessageCrypto_MlKemDecapsRequest* dest); + +int wh_MessageCrypto_TranslateMlKemDecapsResponse( + uint16_t magic, const whMessageCrypto_MlKemDecapsResponse* src, + whMessageCrypto_MlKemDecapsResponse* dest); + /* * DMA-based crypto messages @@ -1337,6 +1423,93 @@ int wh_MessageCrypto_TranslateMlDsaVerifyDmaResponse( uint16_t magic, const whMessageCrypto_MlDsaVerifyDmaResponse* src, whMessageCrypto_MlDsaVerifyDmaResponse* dest); +/* ML-KEM DMA Key Generation Request */ +typedef struct { + whMessageCrypto_DmaBuffer key; + uint32_t level; + uint32_t flags; + uint32_t keyId; + uint32_t access; /* Key access permissions */ + uint32_t labelSize; + uint8_t label[WH_NVM_LABEL_LEN]; + uint8_t WH_PAD2[4]; /* Pad to 8-byte alignment */ +} whMessageCrypto_MlKemKeyGenDmaRequest; + +/* ML-KEM DMA Key Generation Response */ +typedef struct { + whMessageCrypto_DmaAddrStatus dmaAddrStatus; + uint32_t keyId; + uint32_t keySize; +} whMessageCrypto_MlKemKeyGenDmaResponse; + +/* ML-KEM DMA Encapsulation Request + * Note: The shared secret is transferred inline in the response (not via DMA) + * since it is always WC_ML_KEM_SS_SZ (32) bytes, similar to AES key handling. + */ +typedef struct { + whMessageCrypto_DmaBuffer ct; + uint32_t options; + uint32_t level; + uint32_t keyId; + uint8_t WH_PAD[4]; /* Pad to 8-byte alignment */ +} whMessageCrypto_MlKemEncapsDmaRequest; + +/* ML-KEM DMA Encapsulation Response */ +typedef struct { + whMessageCrypto_DmaAddrStatus dmaAddrStatus; + uint32_t ctLen; + uint32_t ssLen; + /* Data follows: + * uint8_t ss[ssLen]; + */ +} whMessageCrypto_MlKemEncapsDmaResponse; + +/* ML-KEM DMA Decapsulation Request + * Note: The shared secret is transferred inline in the response (not via DMA). + */ +typedef struct { + whMessageCrypto_DmaBuffer ct; + uint32_t options; + uint32_t level; + uint32_t keyId; + uint8_t WH_PAD[4]; /* Pad to 8-byte alignment */ +} whMessageCrypto_MlKemDecapsDmaRequest; + +/* ML-KEM DMA Decapsulation Response */ +typedef struct { + whMessageCrypto_DmaAddrStatus dmaAddrStatus; + uint32_t ssLen; + uint8_t WH_PAD[4]; /* Pad to 8-byte alignment */ + /* Data follows: + * uint8_t ss[ssLen]; + */ +} whMessageCrypto_MlKemDecapsDmaResponse; + +/* ML-KEM DMA translation functions */ +int wh_MessageCrypto_TranslateMlKemKeyGenDmaRequest( + uint16_t magic, const whMessageCrypto_MlKemKeyGenDmaRequest* src, + whMessageCrypto_MlKemKeyGenDmaRequest* dest); + +int wh_MessageCrypto_TranslateMlKemKeyGenDmaResponse( + uint16_t magic, const whMessageCrypto_MlKemKeyGenDmaResponse* src, + whMessageCrypto_MlKemKeyGenDmaResponse* dest); + +int wh_MessageCrypto_TranslateMlKemEncapsDmaRequest( + uint16_t magic, const whMessageCrypto_MlKemEncapsDmaRequest* src, + whMessageCrypto_MlKemEncapsDmaRequest* dest); + +int wh_MessageCrypto_TranslateMlKemEncapsDmaResponse( + uint16_t magic, const whMessageCrypto_MlKemEncapsDmaResponse* src, + whMessageCrypto_MlKemEncapsDmaResponse* dest); + +int wh_MessageCrypto_TranslateMlKemDecapsDmaRequest( + uint16_t magic, const whMessageCrypto_MlKemDecapsDmaRequest* src, + whMessageCrypto_MlKemDecapsDmaRequest* dest); + +int wh_MessageCrypto_TranslateMlKemDecapsDmaResponse( + uint16_t magic, const whMessageCrypto_MlKemDecapsDmaResponse* src, + whMessageCrypto_MlKemDecapsDmaResponse* dest); + /* Ed25519 DMA Sign Request */ typedef struct { whMessageCrypto_DmaBuffer msg; /* Message buffer */ diff --git a/wolfhsm/wh_server_crypto.h b/wolfhsm/wh_server_crypto.h index 655f67690..af945ce92 100644 --- a/wolfhsm/wh_server_crypto.h +++ b/wolfhsm/wh_server_crypto.h @@ -37,6 +37,7 @@ #include "wolfssl/wolfcrypt/curve25519.h" #include "wolfssl/wolfcrypt/ecc.h" #include "wolfssl/wolfcrypt/ed25519.h" +#include "wolfssl/wolfcrypt/wc_mlkem.h" #include "wolfssl/wolfcrypt/aes.h" #include "wolfssl/wolfcrypt/sha256.h" #include "wolfssl/wolfcrypt/cmac.h" @@ -103,6 +104,16 @@ int wh_Server_MlDsaKeyCacheExport(whServerContext* ctx, whKeyId keyId, MlDsaKey* key); #endif /* HAVE_DILITHIUM */ +#ifdef WOLFSSL_HAVE_MLKEM +/* Store a MlKemKey into a server key cache with optional metadata */ +int wh_Server_MlKemKeyCacheImport(whServerContext* ctx, MlKemKey* key, + whKeyId keyId, whNvmFlags flags, + uint16_t label_len, uint8_t* label); +/* Restore a MlKemKey from a server key cache */ +int wh_Server_MlKemKeyCacheExport(whServerContext* ctx, whKeyId keyId, + MlKemKey* key); +#endif /* WOLFSSL_HAVE_MLKEM */ + #ifdef HAVE_HKDF /* Store HKDF output into a server key cache with optional metadata */ int wh_Server_HkdfKeyCacheImport(whServerContext* ctx, const uint8_t* keyData, From 1bb250ce782e928b3024959607f46816f0877116 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20Frauenschl=C3=A4ger?= Date: Mon, 20 Apr 2026 21:00:19 +0200 Subject: [PATCH 2/4] Add new structs to padding test --- test/wh_test_check_struct_padding.c | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/wh_test_check_struct_padding.c b/test/wh_test_check_struct_padding.c index dfc6d4913..a0e549132 100644 --- a/test/wh_test_check_struct_padding.c +++ b/test/wh_test_check_struct_padding.c @@ -120,6 +120,18 @@ whMessageCrypto_Sha512Request hashSha512Req; whMessageCrypto_Sha2Response hashSha2Res; whMessageCrypto_HkdfRequest hkdfReq; whMessageCrypto_HkdfResponse hkdfRes; +whMessageCrypto_MlDsaKeyGenRequest pkMldsaKeygenReq; +whMessageCrypto_MlDsaKeyGenResponse pkMldsaKeygenRes; +whMessageCrypto_MlDsaSignRequest pkMldsaSignReq; +whMessageCrypto_MlDsaSignResponse pkMldsaSignRes; +whMessageCrypto_MlDsaVerifyRequest pkMldsaVerifyReq; +whMessageCrypto_MlDsaVerifyResponse pkMldsaVerifyRes; +whMessageCrypto_MlKemKeyGenRequest pkMlkemKeygenReq; +whMessageCrypto_MlKemKeyGenResponse pkMlkemKeygenRes; +whMessageCrypto_MlKemEncapsRequest pkMlkemEncapsReq; +whMessageCrypto_MlKemEncapsResponse pkMlkemEncapsRes; +whMessageCrypto_MlKemDecapsRequest pkMlkemDecapsReq; +whMessageCrypto_MlKemDecapsResponse pkMlkemDecapsRes; /* DMA crypto messages */ #if defined(WOLFHSM_CFG_DMA) @@ -134,6 +146,12 @@ whMessageCrypto_MlDsaVerifyDmaRequest pqMldsaVerifyDmaReq; whMessageCrypto_MlDsaVerifyDmaResponse pqMldsaVerifyDmaRes; whMessageCrypto_CmacAesDmaRequest cmacDmaReq; whMessageCrypto_CmacAesDmaResponse cmacDmaRes; +whMessageCrypto_MlKemKeyGenDmaRequest pkMlkemKeygenDmaReq; +whMessageCrypto_MlKemKeyGenDmaResponse pkMlkemKeygenDmaRes; +whMessageCrypto_MlKemEncapsDmaRequest pkMlkemEncapsDmaReq; +whMessageCrypto_MlKemEncapsDmaResponse pkMlkemEncapsDmaRes; +whMessageCrypto_MlKemDecapsDmaRequest pkMlkemDecapsDmaReq; +whMessageCrypto_MlKemDecapsDmaResponse pkMlkemDecapsDmaRes; #endif /* WOLFHSM_CFG_DMA */ #endif /* !WOLFHSM_CFG_NO_CRYPTO */ From 7bdea16f08060d21a5452a9107f542adbd185a4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20Frauenschl=C3=A4ger?= Date: Tue, 28 Apr 2026 09:59:21 +0200 Subject: [PATCH 3/4] Remove dynamic memory allocation --- src/wh_client_crypto.c | 105 +++++++++-------------------------------- 1 file changed, 22 insertions(+), 83 deletions(-) diff --git a/src/wh_client_crypto.c b/src/wh_client_crypto.c index dd6b1f250..c749ab175 100644 --- a/src/wh_client_crypto.c +++ b/src/wh_client_crypto.c @@ -8010,35 +8010,19 @@ int wh_Client_MlKemImportKey(whClientContext* ctx, MlKemKey* key, { int ret = WH_ERROR_OK; whKeyId key_id = WH_KEYID_ERASED; - byte* buffer = NULL; - uint16_t buffer_len = 0; - word32 allocSz = 0; + byte buffer[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE]; + uint16_t buffer_len = sizeof(buffer); if ((ctx == NULL) || (key == NULL) || ((label_len != 0) && (label == NULL))) { return WH_ERROR_BADARGS; } - /* Use exact key size based on level to avoid over-allocation */ - ret = wc_MlKemKey_PrivateKeySize(key, &allocSz); - if (ret != 0) { - /* Fall back to public key size if no private key */ - ret = wc_MlKemKey_PublicKeySize(key, &allocSz); - } - if (ret != 0 || allocSz == 0) { - return WH_ERROR_BADARGS; - } - - buffer = (byte*)XMALLOC(allocSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); - if (buffer == NULL) { - return WH_ERROR_ABORTED; - } - if (inout_keyId != NULL) { key_id = *inout_keyId; } - ret = wh_Crypto_MlKemSerializeKey(key, (uint16_t)allocSz, buffer, + ret = wh_Crypto_MlKemSerializeKey(key, (uint16_t)buffer_len, buffer, &buffer_len); WH_DEBUG_CLIENT_VERBOSE("MlKemImportKey: serialize ret:%d, len:%u\n", ret, (unsigned int)buffer_len); @@ -8051,8 +8035,7 @@ int wh_Client_MlKemImportKey(whClientContext* ctx, MlKemKey* key, } WH_DEBUG_CLIENT_VERBOSE("MlKemImportKey: ret:%d keyId:%u\n", ret, key_id); - wc_ForceZero(buffer, allocSz); - XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + wc_ForceZero(buffer, buffer_len); return ret; } @@ -8060,19 +8043,13 @@ int wh_Client_MlKemExportKey(whClientContext* ctx, whKeyId keyId, MlKemKey* key, uint16_t label_len, uint8_t* label) { int ret = WH_ERROR_OK; - byte* buffer = NULL; - uint16_t buffer_len = WC_ML_KEM_MAX_PRIVATE_KEY_SIZE; + byte buffer[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE]; + uint16_t buffer_len = sizeof(buffer); if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) { return WH_ERROR_BADARGS; } - buffer = (byte*)XMALLOC(WC_ML_KEM_MAX_PRIVATE_KEY_SIZE, NULL, - DYNAMIC_TYPE_TMP_BUFFER); - if (buffer == NULL) { - return WH_ERROR_ABORTED; - } - ret = wh_Client_KeyExport(ctx, keyId, label, label_len, buffer, &buffer_len); WH_DEBUG_CLIENT_VERBOSE("MlKemExportKey: export ret:%d, len:%u\n", @@ -8082,8 +8059,7 @@ int wh_Client_MlKemExportKey(whClientContext* ctx, whKeyId keyId, MlKemKey* key, } WH_DEBUG_CLIENT_VERBOSE("MlKemExportKey: keyId:%x ret:%d\n", keyId, ret); - wc_ForceZero(buffer, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); - XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + wc_ForceZero(buffer, buffer_len); return ret; } @@ -8449,34 +8425,19 @@ int wh_Client_MlKemImportKeyDma(whClientContext* ctx, MlKemKey* key, { int ret = WH_ERROR_OK; whKeyId key_id = WH_KEYID_ERASED; - byte* buffer = NULL; - uint16_t buffer_len = 0; - word32 allocSz = 0; + byte buffer[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE]; + uint16_t buffer_len = sizeof(buffer); if ((ctx == NULL) || (key == NULL) || ((label_len != 0) && (label == NULL))) { return WH_ERROR_BADARGS; } - /* Use exact key size based on level to avoid over-allocation */ - ret = wc_MlKemKey_PrivateKeySize(key, &allocSz); - if (ret != 0) { - ret = wc_MlKemKey_PublicKeySize(key, &allocSz); - } - if (ret != 0 || allocSz == 0) { - return WH_ERROR_BADARGS; - } - - buffer = (byte*)XMALLOC(allocSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); - if (buffer == NULL) { - return WH_ERROR_ABORTED; - } - if (inout_keyId != NULL) { key_id = *inout_keyId; } - ret = wh_Crypto_MlKemSerializeKey(key, (uint16_t)allocSz, buffer, + ret = wh_Crypto_MlKemSerializeKey(key, (uint16_t)buffer_len, buffer, &buffer_len); WH_DEBUG_CLIENT_VERBOSE("MlKemImportKeyDma: serialize ret:%d, len:%u\n", ret, (unsigned int)buffer_len); @@ -8490,8 +8451,7 @@ int wh_Client_MlKemImportKeyDma(whClientContext* ctx, MlKemKey* key, WH_DEBUG_CLIENT_VERBOSE("MlKemImportKeyDma: ret:%d keyId:%u\n", ret, key_id); - wc_ForceZero(buffer, allocSz); - XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + wc_ForceZero(buffer, buffer_len); return ret; } @@ -8500,20 +8460,13 @@ int wh_Client_MlKemExportKeyDma(whClientContext* ctx, whKeyId keyId, uint8_t* label) { int ret = WH_ERROR_OK; - byte* buffer = NULL; - uint16_t buffer_len = WC_ML_KEM_MAX_PRIVATE_KEY_SIZE; + byte buffer[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE] = {0}; + uint16_t buffer_len = sizeof(buffer); if ((ctx == NULL) || WH_KEYID_ISERASED(keyId) || (key == NULL)) { return WH_ERROR_BADARGS; } - buffer = (byte*)XMALLOC(WC_ML_KEM_MAX_PRIVATE_KEY_SIZE, NULL, - DYNAMIC_TYPE_TMP_BUFFER); - if (buffer == NULL) { - return WH_ERROR_ABORTED; - } - memset(buffer, 0, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); - ret = wh_Client_KeyExportDma(ctx, keyId, buffer, buffer_len, label, label_len, &buffer_len); WH_DEBUG_CLIENT_VERBOSE("MlKemExportKeyDma: export ret:%d, len:%u\n", @@ -8524,8 +8477,7 @@ int wh_Client_MlKemExportKeyDma(whClientContext* ctx, whKeyId keyId, WH_DEBUG_CLIENT_VERBOSE("MlKemExportKeyDma: keyId:%x ret:%d\n", keyId, ret); - wc_ForceZero(buffer, WC_ML_KEM_MAX_PRIVATE_KEY_SIZE); - XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + wc_ForceZero(buffer, buffer_len); return ret; } @@ -8535,31 +8487,20 @@ static int _MlKemMakeKeyDma(whClientContext* ctx, int level, { int ret = WH_ERROR_OK; whKeyId key_id = WH_KEYID_ERASED; - byte* buffer = NULL; uint8_t* dataPtr = NULL; whMessageCrypto_MlKemKeyGenDmaRequest* req = NULL; whMessageCrypto_MlKemKeyGenDmaResponse* res = NULL; uintptr_t keyAddr = 0; - uint32_t allocSz = 0; - if ((ctx == NULL) || (key == NULL)) { - return WH_ERROR_BADARGS; - } + byte buffer[WC_ML_KEM_MAX_PRIVATE_KEY_SIZE]; + uint16_t buffer_len = sizeof(buffer); - ret = wc_MlKemKey_PrivateKeySize(key, &allocSz); - if (ret != 0) { + if ((ctx == NULL) || (key == NULL)) { return WH_ERROR_BADARGS; } - else { - buffer = (byte*)XMALLOC(allocSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); - if (buffer == NULL) { - return WH_ERROR_ABORTED; - } - } dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); if (dataPtr == NULL) { - XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); return WH_ERROR_BADARGS; } @@ -8577,7 +8518,6 @@ static int _MlKemMakeKeyDma(whClientContext* ctx, int level, sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); if (req_len > WOLFHSM_CFG_COMM_DATA_LEN) { - XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); return WH_ERROR_BADARGS; } @@ -8586,10 +8526,10 @@ static int _MlKemMakeKeyDma(whClientContext* ctx, int level, req->flags = flags; req->keyId = key_id; req->access = WH_NVM_ACCESS_ANY; - req->key.sz = allocSz; + req->key.sz = buffer_len; ret = wh_Client_DmaProcessClientAddress( - ctx, (uintptr_t)buffer, (void**)&keyAddr, allocSz, + ctx, (uintptr_t)buffer, (void**)&keyAddr, buffer_len, WH_DMA_OPER_CLIENT_WRITE_PRE, (whDmaFlags){0}); if (ret == WH_ERROR_OK) { req->key.addr = (uint64_t)(uintptr_t)keyAddr; @@ -8615,7 +8555,7 @@ static int _MlKemMakeKeyDma(whClientContext* ctx, int level, } (void)wh_Client_DmaProcessClientAddress( - ctx, (uintptr_t)buffer, (void**)&keyAddr, allocSz, + ctx, (uintptr_t)buffer, (void**)&keyAddr, buffer_len, WH_DMA_OPER_CLIENT_WRITE_POST, (whDmaFlags){0}); if (ret == WH_ERROR_OK) { @@ -8629,7 +8569,7 @@ static int _MlKemMakeKeyDma(whClientContext* ctx, int level, if (key != NULL) { wh_Client_MlKemSetKeyId(key, key_id); if ((flags & WH_NVM_FLAGS_EPHEMERAL) != 0) { - if (res->keySize > allocSz) { + if (res->keySize > buffer_len) { ret = WH_ERROR_BADARGS; } else { @@ -8641,8 +8581,7 @@ static int _MlKemMakeKeyDma(whClientContext* ctx, int level, } } - wc_ForceZero(buffer, allocSz); - XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER); + wc_ForceZero(buffer, buffer_len); return ret; } From e284b738b21ac2b6d50a886745fe344d246e94d8 Mon Sep 17 00:00:00 2001 From: Paul Adelsbach Date: Mon, 4 May 2026 12:12:14 -0700 Subject: [PATCH 4/4] Add support for LMS and XMSS --- src/wh_client_crypto.c | 768 +++++++++++++++ src/wh_client_cryptocb.c | 312 ++++++ src/wh_crypto.c | 279 ++++++ src/wh_message_crypto.c | 179 ++++ src/wh_server_crypto.c | 1416 +++++++++++++++++++++++++-- test/config/user_settings.h | 8 + test/wh_test_check_struct_padding.c | 8 + test/wh_test_crypto.c | 318 ++++++ wolfhsm/wh_client_crypto.h | 71 ++ wolfhsm/wh_crypto.h | 61 ++ wolfhsm/wh_message_crypto.h | 117 +++ wolfhsm/wh_server_crypto.h | 23 + 12 files changed, 3470 insertions(+), 90 deletions(-) diff --git a/src/wh_client_crypto.c b/src/wh_client_crypto.c index c749ab175..3b1b9018a 100644 --- a/src/wh_client_crypto.c +++ b/src/wh_client_crypto.c @@ -56,6 +56,12 @@ #include "wolfssl/wolfcrypt/ed25519.h" #include "wolfssl/wolfcrypt/dilithium.h" #include "wolfssl/wolfcrypt/wc_mlkem.h" +#if defined(WOLFSSL_HAVE_LMS) +#include "wolfssl/wolfcrypt/wc_lms.h" +#endif +#if defined(WOLFSSL_HAVE_XMSS) +#include "wolfssl/wolfcrypt/wc_xmss.h" +#endif #include "wolfssl/wolfcrypt/sha256.h" #include "wolfssl/wolfcrypt/sha512.h" #endif @@ -8837,4 +8843,766 @@ int wh_Client_MlKemDecapsulateDma(whClientContext* ctx, MlKemKey* key, #endif /* WOLFHSM_CFG_DMA */ #endif /* WOLFSSL_HAVE_MLKEM */ +#if defined(WOLFSSL_HAVE_LMS) || defined(WOLFSSL_HAVE_XMSS) +#ifdef WOLFHSM_CFG_DMA + +#ifdef WOLFSSL_HAVE_LMS + +int wh_Client_LmsSetKeyId(LmsKey* key, whKeyId keyId) +{ + if (key == NULL) { + return WH_ERROR_BADARGS; + } + key->devCtx = WH_KEYID_TO_DEVCTX(keyId); + return WH_ERROR_OK; +} + +int wh_Client_LmsGetKeyId(LmsKey* key, whKeyId* outId) +{ + if (key == NULL || outId == NULL) { + return WH_ERROR_BADARGS; + } + *outId = WH_DEVCTX_TO_KEYID(key->devCtx); + return WH_ERROR_OK; +} + +int wh_Client_LmsMakeKeyDma(whClientContext* ctx, LmsKey* key, + whKeyId* inout_key_id, whNvmFlags flags, + uint16_t label_len, uint8_t* label) +{ + int ret = WH_ERROR_OK; + whKeyId key_id = WH_KEYID_ERASED; + uint8_t* dataPtr; + whMessageCrypto_PqcStatefulSigKeyGenDmaRequest* req; + whMessageCrypto_PqcStatefulSigKeyGenDmaResponse* res; + word32 pubLen32 = 0; + uintptr_t pubAddr = 0; + + if ((ctx == NULL) || (key == NULL) || (key->params == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wc_LmsKey_GetPubLen(key, &pubLen32); + if (ret != 0) { + return WH_ERROR_BADARGS; + } + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_PqcStatefulSigKeyGenDmaRequest*) + _createCryptoRequestWithSubtype( + dataPtr, WC_PK_TYPE_PQC_STATEFUL_SIG_KEYGEN, + WC_PQC_STATEFUL_SIG_TYPE_LMS, ctx->cryptoAffinity); + + if (inout_key_id != NULL) { + key_id = *inout_key_id; + } + + { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO_DMA; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + if (req_len > WOLFHSM_CFG_COMM_DATA_LEN) { + return WH_ERROR_BADARGS; + } + + memset(req, 0, sizeof(*req)); + req->flags = flags; + req->keyId = key_id; + req->access = WH_NVM_ACCESS_ANY; + req->lmsLevels = key->params->levels; + req->lmsHeight = key->params->height; + req->lmsWinternitz = key->params->width; + req->pub.sz = pubLen32; + + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)key->pub, (void**)&pubAddr, pubLen32, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whDmaFlags){0}); + if (ret == WH_ERROR_OK) { + req->pub.addr = (uint64_t)(uintptr_t)pubAddr; + } + + if ((label != NULL) && (label_len > 0)) { + if (label_len > WH_NVM_LABEL_LEN) { + label_len = WH_NVM_LABEL_LEN; + } + memcpy(req->label, label, label_len); + req->labelSize = label_len; + } + + if (ret == WH_ERROR_OK) { + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + } + if (ret == WH_ERROR_OK) { + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)key->pub, (void**)&pubAddr, pubLen32, + WH_DMA_OPER_CLIENT_WRITE_POST, (whDmaFlags){0}); + + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, + WC_PK_TYPE_PQC_STATEFUL_SIG_KEYGEN, + (uint8_t**)&res); + if (ret >= 0) { + key_id = (whKeyId)res->keyId; + if (inout_key_id != NULL) { + *inout_key_id = key_id; + } + wh_Client_LmsSetKeyId(key, key_id); + } + } + } + + return ret; +} + +int wh_Client_LmsMakeExportKeyDma(whClientContext* ctx, LmsKey* key) +{ + return wh_Client_LmsMakeKeyDma(ctx, key, NULL, WH_NVM_FLAGS_EPHEMERAL, 0, + NULL); +} + +int wh_Client_LmsSignDma(whClientContext* ctx, const byte* msg, word32 msgSz, + byte* sig, word32* sigSz, LmsKey* key) +{ + int ret = WH_ERROR_OK; + uint8_t* dataPtr; + whMessageCrypto_PqcStatefulSigSignDmaRequest* req; + whMessageCrypto_PqcStatefulSigSignDmaResponse* res; + uintptr_t msgAddr = 0; + uintptr_t sigAddr = 0; + whKeyId key_id; + word32 sigCap; + + if ((ctx == NULL) || (key == NULL) || (msg == NULL) || (sig == NULL) || + (sigSz == NULL)) { + return WH_ERROR_BADARGS; + } + + sigCap = *sigSz; + key_id = WH_DEVCTX_TO_KEYID(key->devCtx); + if (WH_KEYID_ISERASED(key_id)) { + return WH_ERROR_BADARGS; + } + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_PqcStatefulSigSignDmaRequest*) + _createCryptoRequestWithSubtype( + dataPtr, WC_PK_TYPE_PQC_STATEFUL_SIG_SIGN, + WC_PQC_STATEFUL_SIG_TYPE_LMS, ctx->cryptoAffinity); + + { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO_DMA; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + if (req_len > WOLFHSM_CFG_COMM_DATA_LEN) { + return WH_ERROR_BADARGS; + } + + memset(req, 0, sizeof(*req)); + req->keyId = key_id; + req->options = 0; + req->msg.sz = msgSz; + req->sig.sz = sigCap; + + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)msg, (void**)&msgAddr, msgSz, + WH_DMA_OPER_CLIENT_READ_PRE, (whDmaFlags){0}); + if (ret == WH_ERROR_OK) { + req->msg.addr = (uint64_t)(uintptr_t)msgAddr; + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)sig, (void**)&sigAddr, sigCap, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whDmaFlags){0}); + } + if (ret == WH_ERROR_OK) { + req->sig.addr = (uint64_t)(uintptr_t)sigAddr; + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + } + if (ret == WH_ERROR_OK) { + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)msg, (void**)&msgAddr, msgSz, + WH_DMA_OPER_CLIENT_READ_POST, (whDmaFlags){0}); + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)sig, (void**)&sigAddr, sigCap, + WH_DMA_OPER_CLIENT_WRITE_POST, (whDmaFlags){0}); + + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, WC_PK_TYPE_PQC_STATEFUL_SIG_SIGN, + (uint8_t**)&res); + if (ret >= 0) { + if (res->sigLen > sigCap) { + ret = WH_ERROR_BADARGS; + } + else { + *sigSz = res->sigLen; + ret = WH_ERROR_OK; + } + } + } + } + + return ret; +} + +int wh_Client_LmsVerifyDma(whClientContext* ctx, const byte* sig, word32 sigSz, + const byte* msg, word32 msgSz, int* res, LmsKey* key) +{ + int ret = WH_ERROR_OK; + uint8_t* dataPtr; + whMessageCrypto_PqcStatefulSigVerifyDmaRequest* req; + whMessageCrypto_PqcStatefulSigVerifyDmaResponse* resp; + uintptr_t sigAddr = 0; + uintptr_t msgAddr = 0; + whKeyId key_id; + + if ((ctx == NULL) || (key == NULL) || (sig == NULL) || (msg == NULL) || + (res == NULL)) { + return WH_ERROR_BADARGS; + } + + key_id = WH_DEVCTX_TO_KEYID(key->devCtx); + if (WH_KEYID_ISERASED(key_id)) { + return WH_ERROR_BADARGS; + } + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_PqcStatefulSigVerifyDmaRequest*) + _createCryptoRequestWithSubtype( + dataPtr, WC_PK_TYPE_PQC_STATEFUL_SIG_VERIFY, + WC_PQC_STATEFUL_SIG_TYPE_LMS, ctx->cryptoAffinity); + + { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO_DMA; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + if (req_len > WOLFHSM_CFG_COMM_DATA_LEN) { + return WH_ERROR_BADARGS; + } + + memset(req, 0, sizeof(*req)); + req->keyId = key_id; + req->sig.sz = sigSz; + req->msg.sz = msgSz; + + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)sig, (void**)&sigAddr, sigSz, + WH_DMA_OPER_CLIENT_READ_PRE, (whDmaFlags){0}); + if (ret == WH_ERROR_OK) { + req->sig.addr = (uint64_t)(uintptr_t)sigAddr; + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)msg, (void**)&msgAddr, msgSz, + WH_DMA_OPER_CLIENT_READ_PRE, (whDmaFlags){0}); + } + if (ret == WH_ERROR_OK) { + req->msg.addr = (uint64_t)(uintptr_t)msgAddr; + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + } + if (ret == WH_ERROR_OK) { + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)sig, (void**)&sigAddr, sigSz, + WH_DMA_OPER_CLIENT_READ_POST, (whDmaFlags){0}); + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)msg, (void**)&msgAddr, msgSz, + WH_DMA_OPER_CLIENT_READ_POST, (whDmaFlags){0}); + + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, + WC_PK_TYPE_PQC_STATEFUL_SIG_VERIFY, + (uint8_t**)&resp); + if (ret >= 0) { + *res = (int)resp->res; + ret = WH_ERROR_OK; + } + } + } + + return ret; +} + +int wh_Client_LmsSigsLeftDma(whClientContext* ctx, LmsKey* key, + word32* sigsLeft) +{ + int ret = WH_ERROR_OK; + uint8_t* dataPtr; + whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest* req; + whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse* res; + whKeyId key_id; + + if ((ctx == NULL) || (key == NULL) || (sigsLeft == NULL)) { + return WH_ERROR_BADARGS; + } + + key_id = WH_DEVCTX_TO_KEYID(key->devCtx); + if (WH_KEYID_ISERASED(key_id)) { + return WH_ERROR_BADARGS; + } + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest*) + _createCryptoRequestWithSubtype( + dataPtr, WC_PK_TYPE_PQC_STATEFUL_SIG_SIGS_LEFT, + WC_PQC_STATEFUL_SIG_TYPE_LMS, ctx->cryptoAffinity); + + { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO_DMA; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + memset(req, 0, sizeof(*req)); + req->keyId = key_id; + + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + if (ret == WH_ERROR_OK) { + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, + WC_PK_TYPE_PQC_STATEFUL_SIG_SIGS_LEFT, + (uint8_t**)&res); + if (ret >= 0) { + *sigsLeft = res->sigsLeft; + ret = WH_ERROR_OK; + } + } + } + + return ret; +} + +#endif /* WOLFSSL_HAVE_LMS */ + +#ifdef WOLFSSL_HAVE_XMSS + +int wh_Client_XmssSetKeyId(XmssKey* key, whKeyId keyId) +{ + if (key == NULL) { + return WH_ERROR_BADARGS; + } + key->devCtx = WH_KEYID_TO_DEVCTX(keyId); + return WH_ERROR_OK; +} + +int wh_Client_XmssGetKeyId(XmssKey* key, whKeyId* outId) +{ + if (key == NULL || outId == NULL) { + return WH_ERROR_BADARGS; + } + *outId = WH_DEVCTX_TO_KEYID(key->devCtx); + return WH_ERROR_OK; +} + +/* The XMSS implementations mirror the LMS ones; the only differences are the + * subType passed to _createCryptoRequestWithSubtype and the key field names + * (key->pk instead of key->pub, key->params is XmssParams). */ +int wh_Client_XmssMakeKeyDma(whClientContext* ctx, XmssKey* key, + whKeyId* inout_key_id, whNvmFlags flags, + uint16_t label_len, uint8_t* label) +{ + int ret = WH_ERROR_OK; + whKeyId key_id = WH_KEYID_ERASED; + uint8_t* dataPtr; + whMessageCrypto_PqcStatefulSigKeyGenDmaRequest* req; + whMessageCrypto_PqcStatefulSigKeyGenDmaResponse* res; + word32 pubLen32 = 0; + uintptr_t pubAddr = 0; + + if ((ctx == NULL) || (key == NULL) || (key->params == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wc_XmssKey_GetPubLen(key, &pubLen32); + if (ret != 0) { + return WH_ERROR_BADARGS; + } + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_PqcStatefulSigKeyGenDmaRequest*) + _createCryptoRequestWithSubtype( + dataPtr, WC_PK_TYPE_PQC_STATEFUL_SIG_KEYGEN, + WC_PQC_STATEFUL_SIG_TYPE_XMSS, ctx->cryptoAffinity); + + if (inout_key_id != NULL) { + key_id = *inout_key_id; + } + + { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO_DMA; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + if (req_len > WOLFHSM_CFG_COMM_DATA_LEN) { + return WH_ERROR_BADARGS; + } + + memset(req, 0, sizeof(*req)); + req->flags = flags; + req->keyId = key_id; + req->access = WH_NVM_ACCESS_ANY; + req->pub.sz = pubLen32; + + { + const char* paramStr = NULL; + ret = wc_XmssKey_GetParamStr(key, ¶mStr); + if (ret != 0) { + return WH_ERROR_BADARGS; + } + if (XSTRLEN(paramStr) >= sizeof(req->xmssParamStr)) { + return WH_ERROR_BADARGS; + } + XSTRNCPY(req->xmssParamStr, paramStr, sizeof(req->xmssParamStr)); + req->xmssParamStr[sizeof(req->xmssParamStr) - 1] = '\0'; + } + + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)key->pk, (void**)&pubAddr, pubLen32, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whDmaFlags){0}); + if (ret == WH_ERROR_OK) { + req->pub.addr = (uint64_t)(uintptr_t)pubAddr; + } + + if ((label != NULL) && (label_len > 0)) { + if (label_len > WH_NVM_LABEL_LEN) { + label_len = WH_NVM_LABEL_LEN; + } + memcpy(req->label, label, label_len); + req->labelSize = label_len; + } + + if (ret == WH_ERROR_OK) { + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + } + if (ret == WH_ERROR_OK) { + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)key->pk, (void**)&pubAddr, pubLen32, + WH_DMA_OPER_CLIENT_WRITE_POST, (whDmaFlags){0}); + + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, + WC_PK_TYPE_PQC_STATEFUL_SIG_KEYGEN, + (uint8_t**)&res); + if (ret >= 0) { + key_id = (whKeyId)res->keyId; + if (inout_key_id != NULL) { + *inout_key_id = key_id; + } + wh_Client_XmssSetKeyId(key, key_id); + } + } + } + + return ret; +} + +int wh_Client_XmssMakeExportKeyDma(whClientContext* ctx, XmssKey* key) +{ + return wh_Client_XmssMakeKeyDma(ctx, key, NULL, WH_NVM_FLAGS_EPHEMERAL, 0, + NULL); +} + +int wh_Client_XmssSignDma(whClientContext* ctx, const byte* msg, word32 msgSz, + byte* sig, word32* sigSz, XmssKey* key) +{ + int ret = WH_ERROR_OK; + uint8_t* dataPtr; + whMessageCrypto_PqcStatefulSigSignDmaRequest* req; + whMessageCrypto_PqcStatefulSigSignDmaResponse* res; + uintptr_t msgAddr = 0; + uintptr_t sigAddr = 0; + whKeyId key_id; + word32 sigCap; + + if ((ctx == NULL) || (key == NULL) || (msg == NULL) || (sig == NULL) || + (sigSz == NULL)) { + return WH_ERROR_BADARGS; + } + + sigCap = *sigSz; + key_id = WH_DEVCTX_TO_KEYID(key->devCtx); + if (WH_KEYID_ISERASED(key_id)) { + return WH_ERROR_BADARGS; + } + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_PqcStatefulSigSignDmaRequest*) + _createCryptoRequestWithSubtype( + dataPtr, WC_PK_TYPE_PQC_STATEFUL_SIG_SIGN, + WC_PQC_STATEFUL_SIG_TYPE_XMSS, ctx->cryptoAffinity); + + { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO_DMA; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + if (req_len > WOLFHSM_CFG_COMM_DATA_LEN) { + return WH_ERROR_BADARGS; + } + + memset(req, 0, sizeof(*req)); + req->keyId = key_id; + req->options = 0; + req->msg.sz = msgSz; + req->sig.sz = sigCap; + + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)msg, (void**)&msgAddr, msgSz, + WH_DMA_OPER_CLIENT_READ_PRE, (whDmaFlags){0}); + if (ret == WH_ERROR_OK) { + req->msg.addr = (uint64_t)(uintptr_t)msgAddr; + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)sig, (void**)&sigAddr, sigCap, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whDmaFlags){0}); + } + if (ret == WH_ERROR_OK) { + req->sig.addr = (uint64_t)(uintptr_t)sigAddr; + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + } + if (ret == WH_ERROR_OK) { + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)msg, (void**)&msgAddr, msgSz, + WH_DMA_OPER_CLIENT_READ_POST, (whDmaFlags){0}); + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)sig, (void**)&sigAddr, sigCap, + WH_DMA_OPER_CLIENT_WRITE_POST, (whDmaFlags){0}); + + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, WC_PK_TYPE_PQC_STATEFUL_SIG_SIGN, + (uint8_t**)&res); + if (ret >= 0) { + if (res->sigLen > sigCap) { + ret = WH_ERROR_BADARGS; + } + else { + *sigSz = res->sigLen; + ret = WH_ERROR_OK; + } + } + } + } + + return ret; +} + +int wh_Client_XmssVerifyDma(whClientContext* ctx, const byte* sig, + word32 sigSz, const byte* msg, word32 msgSz, + int* res, XmssKey* key) +{ + int ret = WH_ERROR_OK; + uint8_t* dataPtr; + whMessageCrypto_PqcStatefulSigVerifyDmaRequest* req; + whMessageCrypto_PqcStatefulSigVerifyDmaResponse* resp; + uintptr_t sigAddr = 0; + uintptr_t msgAddr = 0; + whKeyId key_id; + + if ((ctx == NULL) || (key == NULL) || (sig == NULL) || (msg == NULL) || + (res == NULL)) { + return WH_ERROR_BADARGS; + } + + key_id = WH_DEVCTX_TO_KEYID(key->devCtx); + if (WH_KEYID_ISERASED(key_id)) { + return WH_ERROR_BADARGS; + } + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_PqcStatefulSigVerifyDmaRequest*) + _createCryptoRequestWithSubtype( + dataPtr, WC_PK_TYPE_PQC_STATEFUL_SIG_VERIFY, + WC_PQC_STATEFUL_SIG_TYPE_XMSS, ctx->cryptoAffinity); + + { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO_DMA; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + if (req_len > WOLFHSM_CFG_COMM_DATA_LEN) { + return WH_ERROR_BADARGS; + } + + memset(req, 0, sizeof(*req)); + req->keyId = key_id; + req->sig.sz = sigSz; + req->msg.sz = msgSz; + + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)sig, (void**)&sigAddr, sigSz, + WH_DMA_OPER_CLIENT_READ_PRE, (whDmaFlags){0}); + if (ret == WH_ERROR_OK) { + req->sig.addr = (uint64_t)(uintptr_t)sigAddr; + ret = wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)msg, (void**)&msgAddr, msgSz, + WH_DMA_OPER_CLIENT_READ_PRE, (whDmaFlags){0}); + } + if (ret == WH_ERROR_OK) { + req->msg.addr = (uint64_t)(uintptr_t)msgAddr; + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + } + if (ret == WH_ERROR_OK) { + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)sig, (void**)&sigAddr, sigSz, + WH_DMA_OPER_CLIENT_READ_POST, (whDmaFlags){0}); + (void)wh_Client_DmaProcessClientAddress( + ctx, (uintptr_t)msg, (void**)&msgAddr, msgSz, + WH_DMA_OPER_CLIENT_READ_POST, (whDmaFlags){0}); + + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, + WC_PK_TYPE_PQC_STATEFUL_SIG_VERIFY, + (uint8_t**)&resp); + if (ret >= 0) { + *res = (int)resp->res; + ret = WH_ERROR_OK; + } + } + } + + return ret; +} + +int wh_Client_XmssSigsLeftDma(whClientContext* ctx, XmssKey* key, + word32* sigsLeft) +{ + int ret = WH_ERROR_OK; + uint8_t* dataPtr; + whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest* req; + whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse* res; + whKeyId key_id; + + if ((ctx == NULL) || (key == NULL) || (sigsLeft == NULL)) { + return WH_ERROR_BADARGS; + } + + key_id = WH_DEVCTX_TO_KEYID(key->devCtx); + if (WH_KEYID_ISERASED(key_id)) { + return WH_ERROR_BADARGS; + } + + dataPtr = (uint8_t*)wh_CommClient_GetDataPtr(ctx->comm); + if (dataPtr == NULL) { + return WH_ERROR_BADARGS; + } + + req = (whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest*) + _createCryptoRequestWithSubtype( + dataPtr, WC_PK_TYPE_PQC_STATEFUL_SIG_SIGS_LEFT, + WC_PQC_STATEFUL_SIG_TYPE_XMSS, ctx->cryptoAffinity); + + { + uint16_t group = WH_MESSAGE_GROUP_CRYPTO_DMA; + uint16_t action = WC_ALGO_TYPE_PK; + uint16_t req_len = + sizeof(whMessageCrypto_GenericRequestHeader) + sizeof(*req); + + memset(req, 0, sizeof(*req)); + req->keyId = key_id; + + ret = wh_Client_SendRequest(ctx, group, action, req_len, + (uint8_t*)dataPtr); + if (ret == WH_ERROR_OK) { + do { + ret = wh_Client_RecvResponse(ctx, &group, &action, &req_len, + (uint8_t*)dataPtr); + } while (ret == WH_ERROR_NOTREADY); + } + if (ret == WH_ERROR_OK) { + ret = _getCryptoResponse(dataPtr, + WC_PK_TYPE_PQC_STATEFUL_SIG_SIGS_LEFT, + (uint8_t**)&res); + if (ret >= 0) { + *sigsLeft = res->sigsLeft; + ret = WH_ERROR_OK; + } + } + } + + return ret; +} + +#endif /* WOLFSSL_HAVE_XMSS */ + +#endif /* WOLFHSM_CFG_DMA */ +#endif /* WOLFSSL_HAVE_LMS || WOLFSSL_HAVE_XMSS */ + #endif /* !WOLFHSM_CFG_NO_CRYPTO && WOLFHSM_CFG_ENABLE_CLIENT */ diff --git a/src/wh_client_cryptocb.c b/src/wh_client_cryptocb.c index afb270c2a..f7712bd58 100644 --- a/src/wh_client_cryptocb.c +++ b/src/wh_client_cryptocb.c @@ -48,6 +48,12 @@ #include "wolfssl/wolfcrypt/sha256.h" #include "wolfssl/wolfcrypt/sha512.h" #include "wolfssl/wolfcrypt/wc_mlkem.h" +#if defined(WOLFSSL_HAVE_LMS) +#include "wolfssl/wolfcrypt/wc_lms.h" +#endif +#if defined(WOLFSSL_HAVE_XMSS) +#include "wolfssl/wolfcrypt/wc_xmss.h" +#endif #include "wolfhsm/wh_crypto.h" #include "wolfhsm/wh_client_crypto.h" @@ -64,6 +70,17 @@ static int _handlePqcDecaps(whClientContext* ctx, wc_CryptoInfo* info, int useDma); #endif /* WOLFSSL_HAVE_MLKEM */ +#if defined(WOLFSSL_HAVE_LMS) || defined(WOLFSSL_HAVE_XMSS) +static int _handlePqcStatefulSigKeyGen(whClientContext* ctx, + wc_CryptoInfo* info, int useDma); +static int _handlePqcStatefulSigSign(whClientContext* ctx, wc_CryptoInfo* info, + int useDma); +static int _handlePqcStatefulSigVerify(whClientContext* ctx, + wc_CryptoInfo* info, int useDma); +static int _handlePqcStatefulSigSigsLeft(whClientContext* ctx, + wc_CryptoInfo* info, int useDma); +#endif /* WOLFSSL_HAVE_LMS || WOLFSSL_HAVE_XMSS */ + #if defined(HAVE_DILITHIUM) || defined(HAVE_FALCON) static int _handlePqcSigKeyGen(whClientContext* ctx, wc_CryptoInfo* info, int useDma); @@ -447,6 +464,25 @@ int wh_Client_CryptoCb(int devId, wc_CryptoInfo* info, void* inCtx) #endif /* WOLFSSL_HAVE_MLKEM */ +#if defined(WOLFSSL_HAVE_LMS) || defined(WOLFSSL_HAVE_XMSS) + case WC_PK_TYPE_PQC_STATEFUL_SIG_KEYGEN: + ret = _handlePqcStatefulSigKeyGen(ctx, info, 0); + break; + + case WC_PK_TYPE_PQC_STATEFUL_SIG_SIGN: + ret = _handlePqcStatefulSigSign(ctx, info, 0); + break; + + case WC_PK_TYPE_PQC_STATEFUL_SIG_VERIFY: + ret = _handlePqcStatefulSigVerify(ctx, info, 0); + break; + + case WC_PK_TYPE_PQC_STATEFUL_SIG_SIGS_LEFT: + ret = _handlePqcStatefulSigSigsLeft(ctx, info, 0); + break; + +#endif /* WOLFSSL_HAVE_LMS || WOLFSSL_HAVE_XMSS */ + #if defined(HAVE_DILITHIUM) || defined(HAVE_FALCON) case WC_PK_TYPE_PQC_SIG_KEYGEN: ret = _handlePqcSigKeyGen(ctx, info, 0); @@ -786,6 +822,268 @@ static int _handlePqcDecaps(whClientContext* ctx, wc_CryptoInfo* info, } #endif /* WOLFSSL_HAVE_MLKEM */ +#if defined(WOLFSSL_HAVE_LMS) || defined(WOLFSSL_HAVE_XMSS) +static int _handlePqcStatefulSigKeyGen(whClientContext* ctx, + wc_CryptoInfo* info, int useDma) +{ + int ret = CRYPTOCB_UNAVAILABLE; + int type = info->pk.pqc_stateful_sig_kg.type; + +#ifndef WOLFHSM_CFG_DMA + (void)ctx; + if (useDma) { + return WC_HW_E; + } +#endif + + switch (type) { +#if defined(WOLFSSL_HAVE_LMS) && !defined(WOLFSSL_LMS_VERIFY_ONLY) + case WC_PQC_STATEFUL_SIG_TYPE_LMS: +#ifdef WOLFHSM_CFG_DMA + if (useDma) { + ret = wh_Client_LmsMakeExportKeyDma( + ctx, (LmsKey*)info->pk.pqc_stateful_sig_kg.key); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + /* Non-DMA transport not supported in v1; signatures exceed the + * default WOLFHSM_CFG_COMM_DATA_LEN. */ + ret = CRYPTOCB_UNAVAILABLE; + } + break; +#endif /* WOLFSSL_HAVE_LMS && !WOLFSSL_LMS_VERIFY_ONLY */ +#if defined(WOLFSSL_HAVE_XMSS) && !defined(WOLFSSL_XMSS_VERIFY_ONLY) + case WC_PQC_STATEFUL_SIG_TYPE_XMSS: +#ifdef WOLFHSM_CFG_DMA + if (useDma) { + ret = wh_Client_XmssMakeExportKeyDma( + ctx, (XmssKey*)info->pk.pqc_stateful_sig_kg.key); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = CRYPTOCB_UNAVAILABLE; + } + break; +#endif /* WOLFSSL_HAVE_XMSS && !WOLFSSL_XMSS_VERIFY_ONLY */ + + default: + ret = CRYPTOCB_UNAVAILABLE; + break; + } + + if (ret == WH_ERROR_BADARGS) { + ret = BAD_FUNC_ARG; + } + else if (ret == WH_ERROR_NOTIMPL) { + ret = CRYPTOCB_UNAVAILABLE; + } + + return ret; +} + +static int _handlePqcStatefulSigSign(whClientContext* ctx, wc_CryptoInfo* info, + int useDma) +{ + int ret = CRYPTOCB_UNAVAILABLE; + int type = info->pk.pqc_stateful_sig_sign.type; + +#ifndef WOLFHSM_CFG_DMA + (void)ctx; + if (useDma) { + return WC_HW_E; + } +#endif + + switch (type) { +#if defined(WOLFSSL_HAVE_LMS) && !defined(WOLFSSL_LMS_VERIFY_ONLY) + case WC_PQC_STATEFUL_SIG_TYPE_LMS: +#ifdef WOLFHSM_CFG_DMA + if (useDma) { + ret = wh_Client_LmsSignDma( + ctx, + info->pk.pqc_stateful_sig_sign.msg, + info->pk.pqc_stateful_sig_sign.msgSz, + info->pk.pqc_stateful_sig_sign.out, + info->pk.pqc_stateful_sig_sign.outSz, + (LmsKey*)info->pk.pqc_stateful_sig_sign.key); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = CRYPTOCB_UNAVAILABLE; + } + break; +#endif /* WOLFSSL_HAVE_LMS && !WOLFSSL_LMS_VERIFY_ONLY */ +#if defined(WOLFSSL_HAVE_XMSS) && !defined(WOLFSSL_XMSS_VERIFY_ONLY) + case WC_PQC_STATEFUL_SIG_TYPE_XMSS: +#ifdef WOLFHSM_CFG_DMA + if (useDma) { + ret = wh_Client_XmssSignDma( + ctx, + info->pk.pqc_stateful_sig_sign.msg, + info->pk.pqc_stateful_sig_sign.msgSz, + info->pk.pqc_stateful_sig_sign.out, + info->pk.pqc_stateful_sig_sign.outSz, + (XmssKey*)info->pk.pqc_stateful_sig_sign.key); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = CRYPTOCB_UNAVAILABLE; + } + break; +#endif /* WOLFSSL_HAVE_XMSS && !WOLFSSL_XMSS_VERIFY_ONLY */ + + default: + ret = CRYPTOCB_UNAVAILABLE; + break; + } + + if (ret == WH_ERROR_BADARGS) { + ret = BAD_FUNC_ARG; + } + else if (ret == WH_ERROR_NOTIMPL) { + ret = CRYPTOCB_UNAVAILABLE; + } + + return ret; +} + +static int _handlePqcStatefulSigVerify(whClientContext* ctx, + wc_CryptoInfo* info, int useDma) +{ + int ret = CRYPTOCB_UNAVAILABLE; + int type = info->pk.pqc_stateful_sig_verify.type; + +#ifndef WOLFHSM_CFG_DMA + (void)ctx; + if (useDma) { + return WC_HW_E; + } +#endif + + switch (type) { +#ifdef WOLFSSL_HAVE_LMS + case WC_PQC_STATEFUL_SIG_TYPE_LMS: +#ifdef WOLFHSM_CFG_DMA + if (useDma) { + ret = wh_Client_LmsVerifyDma( + ctx, + info->pk.pqc_stateful_sig_verify.sig, + info->pk.pqc_stateful_sig_verify.sigSz, + info->pk.pqc_stateful_sig_verify.msg, + info->pk.pqc_stateful_sig_verify.msgSz, + info->pk.pqc_stateful_sig_verify.res, + (LmsKey*)info->pk.pqc_stateful_sig_verify.key); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = CRYPTOCB_UNAVAILABLE; + } + break; +#endif /* WOLFSSL_HAVE_LMS */ +#ifdef WOLFSSL_HAVE_XMSS + case WC_PQC_STATEFUL_SIG_TYPE_XMSS: +#ifdef WOLFHSM_CFG_DMA + if (useDma) { + ret = wh_Client_XmssVerifyDma( + ctx, + info->pk.pqc_stateful_sig_verify.sig, + info->pk.pqc_stateful_sig_verify.sigSz, + info->pk.pqc_stateful_sig_verify.msg, + info->pk.pqc_stateful_sig_verify.msgSz, + info->pk.pqc_stateful_sig_verify.res, + (XmssKey*)info->pk.pqc_stateful_sig_verify.key); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = CRYPTOCB_UNAVAILABLE; + } + break; +#endif /* WOLFSSL_HAVE_XMSS */ + + default: + ret = CRYPTOCB_UNAVAILABLE; + break; + } + + if (ret == WH_ERROR_BADARGS) { + ret = BAD_FUNC_ARG; + } + else if (ret == WH_ERROR_NOTIMPL) { + ret = CRYPTOCB_UNAVAILABLE; + } + + return ret; +} + +static int _handlePqcStatefulSigSigsLeft(whClientContext* ctx, + wc_CryptoInfo* info, int useDma) +{ + int ret = CRYPTOCB_UNAVAILABLE; + int type = info->pk.pqc_stateful_sig_sigs_left.type; + +#ifndef WOLFHSM_CFG_DMA + (void)ctx; + if (useDma) { + return WC_HW_E; + } +#endif + + switch (type) { +#if defined(WOLFSSL_HAVE_LMS) && !defined(WOLFSSL_LMS_VERIFY_ONLY) + case WC_PQC_STATEFUL_SIG_TYPE_LMS: +#ifdef WOLFHSM_CFG_DMA + if (useDma) { + ret = wh_Client_LmsSigsLeftDma( + ctx, + (LmsKey*)info->pk.pqc_stateful_sig_sigs_left.key, + info->pk.pqc_stateful_sig_sigs_left.sigsLeft); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = CRYPTOCB_UNAVAILABLE; + } + break; +#endif /* WOLFSSL_HAVE_LMS && !WOLFSSL_LMS_VERIFY_ONLY */ +#if defined(WOLFSSL_HAVE_XMSS) && !defined(WOLFSSL_XMSS_VERIFY_ONLY) + case WC_PQC_STATEFUL_SIG_TYPE_XMSS: +#ifdef WOLFHSM_CFG_DMA + if (useDma) { + ret = wh_Client_XmssSigsLeftDma( + ctx, + (XmssKey*)info->pk.pqc_stateful_sig_sigs_left.key, + info->pk.pqc_stateful_sig_sigs_left.sigsLeft); + } + else +#endif /* WOLFHSM_CFG_DMA */ + { + ret = CRYPTOCB_UNAVAILABLE; + } + break; +#endif /* WOLFSSL_HAVE_XMSS && !WOLFSSL_XMSS_VERIFY_ONLY */ + + default: + ret = CRYPTOCB_UNAVAILABLE; + break; + } + + if (ret == WH_ERROR_BADARGS) { + ret = BAD_FUNC_ARG; + } + else if (ret == WH_ERROR_NOTIMPL) { + ret = CRYPTOCB_UNAVAILABLE; + } + + return ret; +} +#endif /* WOLFSSL_HAVE_LMS || WOLFSSL_HAVE_XMSS */ + #if defined(HAVE_FALCON) || defined(HAVE_DILITHIUM) static int _handlePqcSigKeyGen(whClientContext* ctx, wc_CryptoInfo* info, int useDma) @@ -1061,6 +1359,20 @@ int wh_Client_CryptoCbDma(int devId, wc_CryptoInfo* info, void* inCtx) ret = _handlePqcDecaps(ctx, info, 1); break; #endif /* WOLFSSL_HAVE_MLKEM */ +#if defined(WOLFSSL_HAVE_LMS) || defined(WOLFSSL_HAVE_XMSS) + case WC_PK_TYPE_PQC_STATEFUL_SIG_KEYGEN: + ret = _handlePqcStatefulSigKeyGen(ctx, info, 1); + break; + case WC_PK_TYPE_PQC_STATEFUL_SIG_SIGN: + ret = _handlePqcStatefulSigSign(ctx, info, 1); + break; + case WC_PK_TYPE_PQC_STATEFUL_SIG_VERIFY: + ret = _handlePqcStatefulSigVerify(ctx, info, 1); + break; + case WC_PK_TYPE_PQC_STATEFUL_SIG_SIGS_LEFT: + ret = _handlePqcStatefulSigSigsLeft(ctx, info, 1); + break; +#endif /* WOLFSSL_HAVE_LMS || WOLFSSL_HAVE_XMSS */ #if defined(HAVE_DILITHIUM) || defined(HAVE_FALCON) case WC_PK_TYPE_PQC_SIG_KEYGEN: ret = _handlePqcSigKeyGen(ctx, info, 1); diff --git a/src/wh_crypto.c b/src/wh_crypto.c index 3b9f6ced5..50c2bc7e3 100644 --- a/src/wh_crypto.c +++ b/src/wh_crypto.c @@ -45,6 +45,12 @@ #include "wolfssl/wolfcrypt/ed25519.h" #include "wolfssl/wolfcrypt/dilithium.h" #include "wolfssl/wolfcrypt/wc_mlkem.h" +#if defined(WOLFSSL_HAVE_LMS) +#include "wolfssl/wolfcrypt/wc_lms.h" +#endif +#if defined(WOLFSSL_HAVE_XMSS) +#include "wolfssl/wolfcrypt/wc_xmss.h" +#endif #include "wolfssl/wolfcrypt/memory.h" #include "wolfhsm/wh_error.h" @@ -485,6 +491,279 @@ int wh_Crypto_MlKemDeserializeKey(const uint8_t* buffer, uint16_t size, } #endif /* WOLFSSL_HAVE_MLKEM */ +#if defined(WOLFSSL_HAVE_LMS) || defined(WOLFSSL_HAVE_XMSS) +/* Stateful hash-based signature key serialization helpers. + * + * Slot blob layout: + * uint32_t magic; + * uint16_t pubLen; + * uint16_t privLen; + * uint16_t paramLen; + * uint16_t reserved; + * uint8_t paramDescriptor[paramLen]; + * uint8_t pub[pubLen]; + * uint8_t priv[privLen]; + * + * Native byte order: the blob is server-internal (NVM-stored) and never + * traverses the wire. */ + +#define WH_CRYPTO_STATEFUL_SIG_HEADER_SZ 12 /* magic + 4*uint16 */ + +static int _StatefulSigEncodeHeader(uint8_t* buffer, uint32_t magic, + uint16_t pubLen, uint16_t privLen, + uint16_t paramLen) +{ + uint16_t reserved = 0; + memcpy(buffer + 0, &magic, sizeof(magic)); + memcpy(buffer + 4, &pubLen, sizeof(pubLen)); + memcpy(buffer + 6, &privLen, sizeof(privLen)); + memcpy(buffer + 8, ¶mLen, sizeof(paramLen)); + memcpy(buffer + 10, &reserved, sizeof(reserved)); + return WH_ERROR_OK; +} + +static int _StatefulSigDecodeHeader(const uint8_t* buffer, uint16_t size, + uint32_t expectMagic, uint16_t* pubLen, + uint16_t* privLen, uint16_t* paramLen) +{ + uint32_t magic; + + if (size < WH_CRYPTO_STATEFUL_SIG_HEADER_SZ) { + return WH_ERROR_BADARGS; + } + memcpy(&magic, buffer + 0, sizeof(magic)); + if (magic != expectMagic) { + return WH_ERROR_BADARGS; + } + memcpy(pubLen, buffer + 4, sizeof(*pubLen)); + memcpy(privLen, buffer + 6, sizeof(*privLen)); + memcpy(paramLen, buffer + 8, sizeof(*paramLen)); + if ((uint32_t)WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + *paramLen + *pubLen + + *privLen > size) { + return WH_ERROR_BADARGS; + } + return WH_ERROR_OK; +} +#endif /* WOLFSSL_HAVE_LMS || WOLFSSL_HAVE_XMSS */ + +#ifdef WOLFSSL_HAVE_LMS +int wh_Crypto_LmsSerializeKey(LmsKey* key, uint16_t max_size, uint8_t* buffer, + uint16_t* out_size) +{ + word32 pubLen32 = 0; + uint16_t pubLen; + uint16_t privLen; + uint16_t paramLen = 3; /* levels, height, winternitz */ + uint32_t totalLen; + int ret; + + if ((key == NULL) || (buffer == NULL) || (out_size == NULL) || + (key->params == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wc_LmsKey_GetPubLen(key, &pubLen32); + if (ret != 0) { + return WH_ERROR_BADARGS; + } + pubLen = (uint16_t)pubLen32; + privLen = (uint16_t)HSS_PRIVATE_KEY_LEN(key->params->hash_len); + + totalLen = (uint32_t)WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + paramLen + pubLen + + privLen; + if (totalLen > max_size) { + return WH_ERROR_BUFFER_SIZE; + } + + (void)_StatefulSigEncodeHeader(buffer, + WH_CRYPTO_STATEFUL_SIG_BLOB_MAGIC_LMS, + pubLen, privLen, paramLen); + + /* paramDescriptor: levels, height, winternitz */ + buffer[WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + 0] = key->params->levels; + buffer[WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + 1] = key->params->height; + buffer[WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + 2] = key->params->width; + + memcpy(buffer + WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + paramLen, + key->pub, pubLen); + memcpy(buffer + WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + paramLen + pubLen, + key->priv_raw, privLen); + + *out_size = (uint16_t)totalLen; + return WH_ERROR_OK; +} + +int wh_Crypto_LmsDeserializeKey(const uint8_t* buffer, uint16_t size, + LmsKey* key) +{ + uint16_t pubLen; + uint16_t privLen; + uint16_t paramLen; + word32 expectPubLen = 0; + int ret; + int levels; + int height; + int winternitz; + const uint8_t* p; + + if ((buffer == NULL) || (key == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = _StatefulSigDecodeHeader(buffer, size, + WH_CRYPTO_STATEFUL_SIG_BLOB_MAGIC_LMS, + &pubLen, &privLen, ¶mLen); + if (ret != WH_ERROR_OK) { + return ret; + } + if (paramLen != 3) { + return WH_ERROR_BADARGS; + } + + p = buffer + WH_CRYPTO_STATEFUL_SIG_HEADER_SZ; + levels = (int)p[0]; + height = (int)p[1]; + winternitz = (int)p[2]; + + ret = wc_LmsKey_SetParameters(key, levels, height, winternitz); + if (ret != 0) { + return ret; + } + + /* Sanity-check pub size against the bound parameter set */ + ret = wc_LmsKey_GetPubLen(key, &expectPubLen); + if ((ret != 0) || (expectPubLen != pubLen)) { + return WH_ERROR_BADARGS; + } + if (privLen != (uint16_t)HSS_PRIVATE_KEY_LEN(key->params->hash_len)) { + return WH_ERROR_BADARGS; + } + + p = buffer + WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + paramLen; + memcpy(key->pub, p, pubLen); + p += pubLen; + memcpy(key->priv_raw, p, privLen); + + return WH_ERROR_OK; +} +#endif /* WOLFSSL_HAVE_LMS */ + +#ifdef WOLFSSL_HAVE_XMSS +int wh_Crypto_XmssSerializeKey(XmssKey* key, const char* paramStr, + uint16_t max_size, uint8_t* buffer, + uint16_t* out_size) +{ + word32 pubLen32 = 0; + word32 privLen32 = 0; + uint16_t pubLen; + uint16_t privLen; + uint16_t paramLen; + uint32_t totalLen; + size_t strLen; + int ret; + + if ((key == NULL) || (paramStr == NULL) || (buffer == NULL) || + (out_size == NULL) || (key->params == NULL) || (key->sk == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wc_XmssKey_GetPubLen(key, &pubLen32); + if (ret != 0) { + return WH_ERROR_BADARGS; + } + ret = wc_XmssKey_GetPrivLen(key, &privLen32); + if (ret != 0) { + return WH_ERROR_BADARGS; + } + pubLen = (uint16_t)pubLen32; + privLen = (uint16_t)privLen32; + + strLen = strlen(paramStr); + if (strLen >= 0xFFFFu) { + return WH_ERROR_BADARGS; + } + paramLen = (uint16_t)(strLen + 1); /* include NUL */ + + totalLen = (uint32_t)WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + paramLen + pubLen + + privLen; + if (totalLen > max_size) { + return WH_ERROR_BUFFER_SIZE; + } + + (void)_StatefulSigEncodeHeader(buffer, + WH_CRYPTO_STATEFUL_SIG_BLOB_MAGIC_XMSS, + pubLen, privLen, paramLen); + + memcpy(buffer + WH_CRYPTO_STATEFUL_SIG_HEADER_SZ, paramStr, paramLen); + memcpy(buffer + WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + paramLen, + key->pk, pubLen); + memcpy(buffer + WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + paramLen + pubLen, + key->sk, privLen); + + *out_size = (uint16_t)totalLen; + return WH_ERROR_OK; +} + +int wh_Crypto_XmssDeserializeKey(const uint8_t* buffer, uint16_t size, + XmssKey* key) +{ + uint16_t pubLen; + uint16_t privLen; + uint16_t paramLen; + word32 expectPubLen = 0; + word32 expectPrivLen = 0; + int ret; + const char* paramStr; + const uint8_t* p; + + if ((buffer == NULL) || (key == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = _StatefulSigDecodeHeader(buffer, size, + WH_CRYPTO_STATEFUL_SIG_BLOB_MAGIC_XMSS, + &pubLen, &privLen, ¶mLen); + if (ret != WH_ERROR_OK) { + return ret; + } + if (paramLen == 0) { + return WH_ERROR_BADARGS; + } + + /* paramDescriptor must be NUL-terminated and within paramLen */ + paramStr = (const char*)(buffer + WH_CRYPTO_STATEFUL_SIG_HEADER_SZ); + if (paramStr[paramLen - 1] != '\0') { + return WH_ERROR_BADARGS; + } + + /* SetParamStr binds key->params; sk is allocated later by Reload via + * the read callback path (or directly if the caller wants to pre-load + * it). */ + ret = wc_XmssKey_SetParamStr(key, paramStr); + if (ret != 0) { + return ret; + } + + ret = wc_XmssKey_GetPubLen(key, &expectPubLen); + if ((ret != 0) || (expectPubLen != pubLen)) { + return WH_ERROR_BADARGS; + } + ret = wc_XmssKey_GetPrivLen(key, &expectPrivLen); + if ((ret != 0) || (expectPrivLen != privLen)) { + return WH_ERROR_BADARGS; + } + + p = buffer + WH_CRYPTO_STATEFUL_SIG_HEADER_SZ + paramLen; + memcpy(key->pk, p, pubLen); + /* The private key is left in the slot blob; downstream paths read it + * via the bridge ReadCb against the cached slot (sk is allocated by + * Reload, not by deserialize). */ + (void)privLen; + + return WH_ERROR_OK; +} +#endif /* WOLFSSL_HAVE_XMSS */ + #ifdef WOLFSSL_CMAC void wh_Crypto_CmacAesSaveStateToMsg(whMessageCrypto_CmacAesState* state, const Cmac* cmac) diff --git a/src/wh_message_crypto.c b/src/wh_message_crypto.c index 560aa8e90..87c55f521 100644 --- a/src/wh_message_crypto.c +++ b/src/wh_message_crypto.c @@ -1365,6 +1365,185 @@ int wh_MessageCrypto_TranslateMlKemDecapsDmaResponse( return 0; } +/* Stateful sig DMA Key Generation Request translation */ +int wh_MessageCrypto_TranslatePqcStatefulSigKeyGenDmaRequest( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigKeyGenDmaRequest* src, + whMessageCrypto_PqcStatefulSigKeyGenDmaRequest* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaBuffer(magic, &src->pub, &dest->pub); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, flags); + WH_T32(magic, dest, src, keyId); + WH_T32(magic, dest, src, access); + WH_T32(magic, dest, src, labelSize); + WH_T32(magic, dest, src, lmsLevels); + WH_T32(magic, dest, src, lmsHeight); + WH_T32(magic, dest, src, lmsWinternitz); + if (src != dest) { + memcpy(dest->label, src->label, sizeof(src->label)); + memcpy(dest->xmssParamStr, src->xmssParamStr, + sizeof(src->xmssParamStr)); + } + return 0; +} + +/* Stateful sig DMA Key Generation Response translation */ +int wh_MessageCrypto_TranslatePqcStatefulSigKeyGenDmaResponse( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigKeyGenDmaResponse* src, + whMessageCrypto_PqcStatefulSigKeyGenDmaResponse* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaAddrStatus(magic, &src->dmaAddrStatus, + &dest->dmaAddrStatus); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, keyId); + WH_T32(magic, dest, src, pubSize); + return 0; +} + +/* Stateful sig DMA Sign Request translation */ +int wh_MessageCrypto_TranslatePqcStatefulSigSignDmaRequest( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigSignDmaRequest* src, + whMessageCrypto_PqcStatefulSigSignDmaRequest* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaBuffer(magic, &src->msg, &dest->msg); + if (ret != 0) { + return ret; + } + ret = wh_MessageCrypto_TranslateDmaBuffer(magic, &src->sig, &dest->sig); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, options); + WH_T32(magic, dest, src, keyId); + return 0; +} + +/* Stateful sig DMA Sign Response translation */ +int wh_MessageCrypto_TranslatePqcStatefulSigSignDmaResponse( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigSignDmaResponse* src, + whMessageCrypto_PqcStatefulSigSignDmaResponse* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaAddrStatus(magic, &src->dmaAddrStatus, + &dest->dmaAddrStatus); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, sigLen); + return 0; +} + +/* Stateful sig DMA Verify Request translation */ +int wh_MessageCrypto_TranslatePqcStatefulSigVerifyDmaRequest( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigVerifyDmaRequest* src, + whMessageCrypto_PqcStatefulSigVerifyDmaRequest* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaBuffer(magic, &src->sig, &dest->sig); + if (ret != 0) { + return ret; + } + ret = wh_MessageCrypto_TranslateDmaBuffer(magic, &src->msg, &dest->msg); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, options); + WH_T32(magic, dest, src, keyId); + return 0; +} + +/* Stateful sig DMA Verify Response translation */ +int wh_MessageCrypto_TranslatePqcStatefulSigVerifyDmaResponse( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigVerifyDmaResponse* src, + whMessageCrypto_PqcStatefulSigVerifyDmaResponse* dest) +{ + int ret; + + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + ret = wh_MessageCrypto_TranslateDmaAddrStatus(magic, &src->dmaAddrStatus, + &dest->dmaAddrStatus); + if (ret != 0) { + return ret; + } + + WH_T32(magic, dest, src, res); + return 0; +} + +/* Stateful sig DMA Signatures-Left Request translation */ +int wh_MessageCrypto_TranslatePqcStatefulSigSigsLeftDmaRequest( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest* src, + whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest* dest) +{ + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + WH_T32(magic, dest, src, keyId); + return 0; +} + +/* Stateful sig DMA Signatures-Left Response translation */ +int wh_MessageCrypto_TranslatePqcStatefulSigSigsLeftDmaResponse( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse* src, + whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse* dest) +{ + if ((src == NULL) || (dest == NULL)) { + return WH_ERROR_BADARGS; + } + + WH_T32(magic, dest, src, sigsLeft); + return 0; +} + /* Ed25519 DMA Sign Request translation */ int wh_MessageCrypto_TranslateEd25519SignDmaRequest( uint16_t magic, const whMessageCrypto_Ed25519SignDmaRequest* src, diff --git a/src/wh_server_crypto.c b/src/wh_server_crypto.c index c0358fcc2..2d3c261cd 100644 --- a/src/wh_server_crypto.c +++ b/src/wh_server_crypto.c @@ -45,6 +45,12 @@ #include "wolfssl/wolfcrypt/cmac.h" #include "wolfssl/wolfcrypt/dilithium.h" #include "wolfssl/wolfcrypt/wc_mlkem.h" +#if defined(WOLFSSL_HAVE_LMS) +#include "wolfssl/wolfcrypt/wc_lms.h" +#endif +#if defined(WOLFSSL_HAVE_XMSS) +#include "wolfssl/wolfcrypt/wc_xmss.h" +#endif #include "wolfssl/wolfcrypt/hmac.h" #include "wolfssl/wolfcrypt/kdf.h" @@ -879,6 +885,255 @@ int wh_Server_MlKemKeyCacheExport(whServerContext* ctx, whKeyId keyId, } #endif /* WOLFSSL_HAVE_MLKEM */ +#if (defined(WOLFSSL_HAVE_LMS) || defined(WOLFSSL_HAVE_XMSS)) && \ + defined(WOLFHSM_CFG_DMA) +/* Stateful-key persistence bridge. + * + * wolfCrypt's wc_LmsKey_Sign and wc_XmssKey_Sign require write/read callbacks + * for the software path. We wire write_private_key directly to atomic NVM + * commit (wh_Nvm_AddObjectWithReclaim): wolfCrypt's contract is to advance + * the index, call write_cb, and only emit the signature if write_cb returned + * success. That gives us pre-commit-then-emit ordering for free — see + * doc/LMS_XMSS_CryptoCb.md and the plan for the crash-safety analysis. + * + * The bridge keeps a pointer into the server's cache slot blob (laid out by + * wh_Crypto_{Lms,Xmss}SerializeKey). Each write_cb invocation overwrites the + * priv region of the slot in place and re-commits the entire slot. */ +typedef struct whServerStatefulSigBridge { + whServerContext* server; + whKeyId keyId; + whNvmMetadata* meta; /* points at the cache slot's metadata */ + uint8_t* slotBuf; /* points at the cache slot's data buffer */ + uint16_t hdrSz; /* offset to priv region inside slotBuf */ + uint16_t pubLen; /* offset of priv = hdrSz + paramLen + pubLen */ + uint16_t paramLen; + uint16_t slotCapacity; +} whServerStatefulSigBridge; + +/* Compute the priv-region offset inside the slot blob from a bridge. */ +static uint16_t _StatefulBridgePrivOffset(const whServerStatefulSigBridge* b) +{ + return (uint16_t)(b->hdrSz + b->paramLen + b->pubLen); +} + +/* Update the slot blob's privLen field (header + 6 -> priv length). */ +static void _StatefulBridgeWritePrivLen(uint8_t* slotBuf, uint16_t privLen) +{ + /* See wh_crypto.c for layout: privLen is at offset +6. */ + memcpy(slotBuf + 6, &privLen, sizeof(privLen)); +} + +#if defined(WOLFSSL_HAVE_LMS) && defined(WOLFHSM_CFG_DMA) +static int _LmsBridgeWriteCb(const byte* priv, word32 privSz, void* context) +{ + whServerStatefulSigBridge* b = (whServerStatefulSigBridge*)context; + uint16_t privOff; + uint32_t newLen; + int rc; + + if ((b == NULL) || (priv == NULL) || (b->slotBuf == NULL) || + (b->meta == NULL)) { + return WC_LMS_RC_BAD_ARG; + } + + privOff = _StatefulBridgePrivOffset(b); + newLen = (uint32_t)privOff + privSz; + if (newLen > b->slotCapacity) { + return WC_LMS_RC_WRITE_FAIL; + } + + memcpy(b->slotBuf + privOff, priv, privSz); + _StatefulBridgeWritePrivLen(b->slotBuf, (uint16_t)privSz); + b->meta->len = (whNvmSize)newLen; + + /* Atomic dual-partition commit. Wolfcrypt aborts the sign if this + * returns anything other than _SAVED_TO_NV_MEMORY, so the signature + * never escapes for an un-persisted index. */ + rc = wh_Nvm_AddObjectWithReclaim(b->server->nvm, b->meta, b->meta->len, + b->slotBuf); + return (rc == WH_ERROR_OK) ? WC_LMS_RC_SAVED_TO_NV_MEMORY + : WC_LMS_RC_WRITE_FAIL; +} + +static int _LmsBridgeReadCb(byte* priv, word32 privSz, void* context) +{ + whServerStatefulSigBridge* b = (whServerStatefulSigBridge*)context; + uint16_t privOff; + + if ((b == NULL) || (priv == NULL) || (b->slotBuf == NULL)) { + return WC_LMS_RC_BAD_ARG; + } + + privOff = _StatefulBridgePrivOffset(b); + if ((uint32_t)privOff + privSz > b->meta->len) { + return WC_LMS_RC_READ_FAIL; + } + + memcpy(priv, b->slotBuf + privOff, privSz); + return WC_LMS_RC_READ_TO_MEMORY; +} +#endif /* WOLFSSL_HAVE_LMS && WOLFHSM_CFG_DMA */ + +#if defined(WOLFSSL_HAVE_XMSS) && defined(WOLFHSM_CFG_DMA) +static enum wc_XmssRc _XmssBridgeWriteCb(const byte* priv, word32 privSz, + void* context) +{ + whServerStatefulSigBridge* b = (whServerStatefulSigBridge*)context; + uint16_t privOff; + uint32_t newLen; + int rc; + + if ((b == NULL) || (priv == NULL) || (b->slotBuf == NULL) || + (b->meta == NULL)) { + return WC_XMSS_RC_BAD_ARG; + } + + privOff = _StatefulBridgePrivOffset(b); + newLen = (uint32_t)privOff + privSz; + if (newLen > b->slotCapacity) { + return WC_XMSS_RC_WRITE_FAIL; + } + + memcpy(b->slotBuf + privOff, priv, privSz); + _StatefulBridgeWritePrivLen(b->slotBuf, (uint16_t)privSz); + b->meta->len = (whNvmSize)newLen; + + rc = wh_Nvm_AddObjectWithReclaim(b->server->nvm, b->meta, b->meta->len, + b->slotBuf); + return (rc == WH_ERROR_OK) ? WC_XMSS_RC_SAVED_TO_NV_MEMORY + : WC_XMSS_RC_WRITE_FAIL; +} + +static enum wc_XmssRc _XmssBridgeReadCb(byte* priv, word32 privSz, + void* context) +{ + whServerStatefulSigBridge* b = (whServerStatefulSigBridge*)context; + uint16_t privOff; + + if ((b == NULL) || (priv == NULL) || (b->slotBuf == NULL)) { + return WC_XMSS_RC_BAD_ARG; + } + + privOff = _StatefulBridgePrivOffset(b); + if ((uint32_t)privOff + privSz > b->meta->len) { + return WC_XMSS_RC_READ_FAIL; + } + + memcpy(priv, b->slotBuf + privOff, privSz); + return WC_XMSS_RC_READ_TO_MEMORY; +} +#endif /* WOLFSSL_HAVE_XMSS && WOLFHSM_CFG_DMA */ +#endif /* (WOLFSSL_HAVE_LMS || WOLFSSL_HAVE_XMSS) && WOLFHSM_CFG_DMA */ + +#ifdef WOLFSSL_HAVE_LMS +int wh_Server_LmsKeyCacheImport(whServerContext* ctx, LmsKey* key, + whKeyId keyId, whNvmFlags flags, + uint16_t label_len, uint8_t* label) +{ + int ret = WH_ERROR_OK; + uint8_t* cacheBuf; + whNvmMetadata* cacheMeta; + uint16_t slotCapacity = WOLFHSM_CFG_SERVER_KEYCACHE_BIG_BUFSIZE; + uint16_t blobSize; + + if ((ctx == NULL) || (key == NULL) || (WH_KEYID_ISERASED(keyId)) || + ((label != NULL) && (label_len > sizeof(cacheMeta->label)))) { + return WH_ERROR_BADARGS; + } + + ret = wh_Server_KeystoreGetCacheSlotChecked(ctx, keyId, slotCapacity, + &cacheBuf, &cacheMeta); + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_LmsSerializeKey(key, slotCapacity, cacheBuf, &blobSize); + } + if (ret == WH_ERROR_OK) { + cacheMeta->id = keyId; + cacheMeta->len = blobSize; + cacheMeta->flags = flags; + cacheMeta->access = WH_NVM_ACCESS_ANY; + if ((label != NULL) && (label_len > 0)) { + memcpy(cacheMeta->label, label, label_len); + } + } + return ret; +} + +int wh_Server_LmsKeyCacheExport(whServerContext* ctx, whKeyId keyId, + LmsKey* key) +{ + uint8_t* cacheBuf; + whNvmMetadata* cacheMeta; + int ret; + + if ((ctx == NULL) || (key == NULL) || (WH_KEYID_ISERASED(keyId))) { + return WH_ERROR_BADARGS; + } + + ret = wh_Server_KeystoreFreshenKey(ctx, keyId, &cacheBuf, &cacheMeta); + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_LmsDeserializeKey(cacheBuf, (uint16_t)cacheMeta->len, + key); + } + return ret; +} +#endif /* WOLFSSL_HAVE_LMS */ + +#ifdef WOLFSSL_HAVE_XMSS +int wh_Server_XmssKeyCacheImport(whServerContext* ctx, XmssKey* key, + const char* paramStr, whKeyId keyId, + whNvmFlags flags, uint16_t label_len, + uint8_t* label) +{ + int ret = WH_ERROR_OK; + uint8_t* cacheBuf; + whNvmMetadata* cacheMeta; + uint16_t slotCapacity = WOLFHSM_CFG_SERVER_KEYCACHE_BIG_BUFSIZE; + uint16_t blobSize; + + if ((ctx == NULL) || (key == NULL) || (paramStr == NULL) || + (WH_KEYID_ISERASED(keyId)) || + ((label != NULL) && (label_len > sizeof(cacheMeta->label)))) { + return WH_ERROR_BADARGS; + } + + ret = wh_Server_KeystoreGetCacheSlotChecked(ctx, keyId, slotCapacity, + &cacheBuf, &cacheMeta); + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_XmssSerializeKey(key, paramStr, slotCapacity, cacheBuf, + &blobSize); + } + if (ret == WH_ERROR_OK) { + cacheMeta->id = keyId; + cacheMeta->len = blobSize; + cacheMeta->flags = flags; + cacheMeta->access = WH_NVM_ACCESS_ANY; + if ((label != NULL) && (label_len > 0)) { + memcpy(cacheMeta->label, label, label_len); + } + } + return ret; +} + +int wh_Server_XmssKeyCacheExport(whServerContext* ctx, whKeyId keyId, + XmssKey* key) +{ + uint8_t* cacheBuf; + whNvmMetadata* cacheMeta; + int ret; + + if ((ctx == NULL) || (key == NULL) || (WH_KEYID_ISERASED(keyId))) { + return WH_ERROR_BADARGS; + } + + ret = wh_Server_KeystoreFreshenKey(ctx, keyId, &cacheBuf, &cacheMeta); + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_XmssDeserializeKey(cacheBuf, (uint16_t)cacheMeta->len, + key); + } + return ret; +} +#endif /* WOLFSSL_HAVE_XMSS */ + /** Request/Response Handling functions */ @@ -6528,123 +6783,1093 @@ static int _HandleMlKemDecapsDma(whServerContext* ctx, uint16_t magic, #endif } -static int _HandlePqcKemAlgorithmDma(whServerContext* ctx, uint16_t magic, - int devId, const void* cryptoDataIn, - uint16_t cryptoInSize, void* cryptoDataOut, - uint16_t* cryptoOutSize, - uint32_t pkAlgoType, uint32_t pqAlgoType) +#if defined(WOLFSSL_HAVE_LMS) || defined(WOLFSSL_HAVE_XMSS) +/* Decode the slot blob's header lengths into the bridge struct. The blob + * format (see wh_crypto.c) places pubLen at +4, privLen at +6, paramLen at + * +8. */ +static int _StatefulBridgeFromSlot(whServerStatefulSigBridge* b, + whServerContext* server, + whKeyId keyId, + uint8_t* slotBuf, whNvmMetadata* meta, + uint16_t slotCapacity) { - int ret = WH_ERROR_NOHANDLER; + uint16_t pubLen, paramLen; - switch (pqAlgoType) { - case WC_PQC_KEM_TYPE_KYBER: { - switch (pkAlgoType) { - case WC_PK_TYPE_PQC_KEM_KEYGEN: - ret = _HandleMlKemKeyGenDma(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - cryptoOutSize); - break; - case WC_PK_TYPE_PQC_KEM_ENCAPS: - ret = _HandleMlKemEncapsDma(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - cryptoOutSize); - break; - case WC_PK_TYPE_PQC_KEM_DECAPS: - ret = _HandleMlKemDecapsDma(ctx, magic, devId, cryptoDataIn, - cryptoInSize, cryptoDataOut, - cryptoOutSize); - break; - default: - ret = WH_ERROR_NOHANDLER; - break; - } - } break; - default: - ret = WH_ERROR_NOHANDLER; - break; + if ((b == NULL) || (server == NULL) || (slotBuf == NULL) || (meta == NULL)) { + return WH_ERROR_BADARGS; } + memcpy(&pubLen, slotBuf + 4, sizeof(pubLen)); + memcpy(¶mLen, slotBuf + 8, sizeof(paramLen)); - return ret; + b->server = server; + b->keyId = keyId; + b->meta = meta; + b->slotBuf = slotBuf; + b->hdrSz = 12; /* WH_CRYPTO_STATEFUL_SIG_HEADER_SZ */ + b->paramLen = paramLen; + b->pubLen = pubLen; + b->slotCapacity = slotCapacity; + return WH_ERROR_OK; +} +#endif /* WOLFSSL_HAVE_LMS || WOLFSSL_HAVE_XMSS */ + +#ifdef WOLFSSL_HAVE_LMS +/* Dummy cbs used during keygen. wc_LmsKey_MakeKey requires both cbs to be + * set; we don't actually persist via cb during keygen because the slot blob + * (including pub) is assembled after MakeKey populates key->pub and + * key->priv_raw. See _HandleLmsKeyGenDma for the full sequence. */ +static int _LmsDummyWriteCb(const byte* priv, word32 privSz, void* context) +{ + (void)priv; (void)privSz; (void)context; + return WC_LMS_RC_SAVED_TO_NV_MEMORY; +} +static int _LmsDummyReadCb(byte* priv, word32 privSz, void* context) +{ + (void)priv; (void)privSz; (void)context; + return WC_LMS_RC_READ_TO_MEMORY; } -#endif /* WOLFSSL_HAVE_MLKEM */ -#if defined(WOLFSSL_CMAC) && !defined(NO_AES) && defined(WOLFSSL_AES_DIRECT) -static int _HandleCmacDma(whServerContext* ctx, uint16_t magic, int devId, - uint16_t seq, const void* cryptoDataIn, - uint16_t inSize, void* cryptoDataOut, - uint16_t* outSize) +static int _HandleLmsKeyGenDma(whServerContext* ctx, uint16_t magic, int devId, + const void* cryptoDataIn, uint16_t inSize, + void* cryptoDataOut, uint16_t* outSize) { - (void)seq; +#ifdef WOLFSSL_LMS_VERIFY_ONLY + (void)ctx; (void)magic; (void)devId; (void)cryptoDataIn; (void)inSize; + (void)cryptoDataOut; (void)outSize; + return WH_ERROR_NOHANDLER; +#else + int ret; + LmsKey key[1]; + void* clientPubAddr = NULL; + word32 pubLen32 = 0; + whKeyId keyId; + whMessageCrypto_PqcStatefulSigKeyGenDmaRequest req; + whMessageCrypto_PqcStatefulSigKeyGenDmaResponse res; - int ret = 0; - whMessageCrypto_CmacAesDmaRequest req; - whMessageCrypto_CmacAesDmaResponse res; + memset(&res, 0, sizeof(res)); - if (inSize < sizeof(whMessageCrypto_CmacAesDmaRequest)) { + if (inSize < sizeof(req)) { return WH_ERROR_BADARGS; } - /* Translate request */ - ret = wh_MessageCrypto_TranslateCmacAesDmaRequest( - magic, (whMessageCrypto_CmacAesDmaRequest*)cryptoDataIn, &req); + ret = wh_MessageCrypto_TranslatePqcStatefulSigKeyGenDmaRequest( + magic, (whMessageCrypto_PqcStatefulSigKeyGenDmaRequest*)cryptoDataIn, + &req); if (ret != WH_ERROR_OK) { return ret; } - /* Validate variable-length fields fit within inSize */ - uint32_t available = inSize - sizeof(whMessageCrypto_CmacAesDmaRequest); - if (req.keySz > available) { - return WH_ERROR_BADARGS; - } - if (req.keySz > AES_256_KEY_SIZE) { - return WH_ERROR_BADARGS; + ret = wc_LmsKey_Init(key, NULL, devId); + if (ret != 0) { + return ret; } - word32 len; - - /* Pointers to inline trailing data */ - uint8_t* key = - (uint8_t*)(cryptoDataIn) + sizeof(whMessageCrypto_CmacAesDmaRequest); - uint8_t* out = - (uint8_t*)(cryptoDataOut) + sizeof(whMessageCrypto_CmacAesDmaResponse); - - memset(&res, 0, sizeof(res)); + ret = wc_LmsKey_SetParameters(key, (int)req.lmsLevels, (int)req.lmsHeight, + (int)req.lmsWinternitz); + if (ret == 0) { + ret = wc_LmsKey_SetWriteCb(key, _LmsDummyWriteCb); + } + if (ret == 0) { + ret = wc_LmsKey_SetReadCb(key, _LmsDummyReadCb); + } + if (ret == 0) { + ret = wc_LmsKey_SetContext(key, NULL); + } + if (ret == 0) { + ret = wc_LmsKey_MakeKey(key, ctx->crypto->rng); + } - /* DMA translated address for input */ - void* inAddr = NULL; + if (ret == 0) { + keyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + ctx->comm->client_id, req.keyId); + if (WH_KEYID_ISERASED(keyId)) { + ret = wh_Server_KeystoreGetUniqueId(ctx, &keyId); + } + } - uint8_t tmpKey[AES_256_KEY_SIZE]; - uint32_t tmpKeyLen = sizeof(tmpKey); - Cmac cmac[1]; + if (ret == 0) { + ret = wh_Server_LmsKeyCacheImport(ctx, key, keyId, req.flags, + (uint16_t)req.labelSize, req.label); + } - /* Attempt oneshot if input and output are both present */ - if (req.input.sz != 0 && req.outSz != 0) { - len = req.outSz; + /* For non-ephemeral keys, commit to NVM so the key survives a server + * restart. Ephemeral keys are cache-only. */ + if ((ret == 0) && ((req.flags & WH_NVM_FLAGS_EPHEMERAL) == 0)) { + ret = wh_Server_KeystoreCommitKey(ctx, keyId); + } - /* Translate DMA address for input */ + /* Stream the public key out via the client-supplied DMA buffer. */ + if (ret == 0) { + ret = wc_LmsKey_GetPubLen(key, &pubLen32); + } + if (ret == 0 && req.pub.sz < pubLen32) { + ret = WH_ERROR_BUFFER_SIZE; + } + if (ret == 0) { ret = wh_Server_DmaProcessClientAddress( - ctx, req.input.addr, &inAddr, req.input.sz, - WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); - if (ret == WH_ERROR_ACCESS) { - res.dmaAddrStatus.badAddr = req.input; + ctx, (uintptr_t)req.pub.addr, &clientPubAddr, pubLen32, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whServerDmaFlags){0}); + if (ret != 0) { + res.dmaAddrStatus.badAddr = req.pub; } + } + if (ret == 0) { + memcpy(clientPubAddr, key->pub, pubLen32); + res.keyId = wh_KeyId_TranslateToClient(keyId); + res.pubSize = pubLen32; + } + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.pub.addr, &clientPubAddr, pubLen32, + WH_DMA_OPER_CLIENT_WRITE_POST, (whServerDmaFlags){0}); - /* Resolve key */ - if (ret == WH_ERROR_OK) { - ret = _CmacResolveKey(ctx, key, req.keySz, req.keyId, tmpKey, - &tmpKeyLen); - } + wc_LmsKey_Free(key); - if (ret == WH_ERROR_OK && req.keySz != 0) { - /* Client-supplied key - direct one-shot */ - WH_DEBUG_SERVER_VERBOSE("dma cmac generate oneshot\n"); + (void)wh_MessageCrypto_TranslatePqcStatefulSigKeyGenDmaResponse( + magic, &res, + (whMessageCrypto_PqcStatefulSigKeyGenDmaResponse*)cryptoDataOut); + *outSize = sizeof(res); + return ret; +#endif /* WOLFSSL_LMS_VERIFY_ONLY */ +} - ret = wc_AesCmacGenerate_ex(cmac, out, &len, inAddr, req.input.sz, - tmpKey, (word32)tmpKeyLen, NULL, devId); - } - else if (ret == WH_ERROR_OK) { - /* HSM-local key via keyId - init then generate */ - WH_DEBUG_SERVER_VERBOSE("dma cmac generate oneshot with keyId:%x\n", +static int _HandleLmsSignDma(whServerContext* ctx, uint16_t magic, int devId, + const void* cryptoDataIn, uint16_t inSize, + void* cryptoDataOut, uint16_t* outSize) +{ +#ifdef WOLFSSL_LMS_VERIFY_ONLY + (void)ctx; (void)magic; (void)devId; (void)cryptoDataIn; (void)inSize; + (void)cryptoDataOut; (void)outSize; + return WH_ERROR_NOHANDLER; +#else + int ret; + LmsKey key[1]; + int keyInited = 0; + void* msgAddr = NULL; + void* sigAddr = NULL; + word32 sigLen; + whKeyId keyId; + uint8_t* cacheBuf; + whNvmMetadata* cacheMeta; + whServerStatefulSigBridge bridge; + whMessageCrypto_PqcStatefulSigSignDmaRequest req; + whMessageCrypto_PqcStatefulSigSignDmaResponse res; + + memset(&res, 0, sizeof(res)); + + if (inSize < sizeof(req)) { + return WH_ERROR_BADARGS; + } + ret = wh_MessageCrypto_TranslatePqcStatefulSigSignDmaRequest( + magic, (whMessageCrypto_PqcStatefulSigSignDmaRequest*)cryptoDataIn, + &req); + if (ret != WH_ERROR_OK) { + return ret; + } + + keyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + ctx->comm->client_id, req.keyId); + if (WH_KEYID_ISERASED(keyId)) { + return WH_ERROR_BADARGS; + } + + sigLen = (word32)req.sig.sz; + + /* Hold the NVM lock for the entire load -> sign -> commit sequence so + * concurrent sign requests on the same keyId can't race past each other. + * Pattern from wh_server_counter.c. */ + ret = WH_SERVER_NVM_LOCK(ctx); + if (ret != WH_ERROR_OK) { + return ret; + } + + ret = wh_Server_KeystoreFreshenKey(ctx, keyId, &cacheBuf, &cacheMeta); + if (ret == WH_ERROR_OK) { + ret = wc_LmsKey_Init(key, NULL, devId); + if (ret == 0) { + keyInited = 1; + } + } + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_LmsDeserializeKey(cacheBuf, (uint16_t)cacheMeta->len, + key); + } + if (ret == WH_ERROR_OK) { + ret = _StatefulBridgeFromSlot( + &bridge, ctx, keyId, cacheBuf, cacheMeta, + WOLFHSM_CFG_SERVER_KEYCACHE_BIG_BUFSIZE); + } + if (ret == WH_ERROR_OK) { + (void)wc_LmsKey_SetWriteCb(key, _LmsBridgeWriteCb); + (void)wc_LmsKey_SetReadCb(key, _LmsBridgeReadCb); + (void)wc_LmsKey_SetContext(key, &bridge); + ret = wc_LmsKey_Reload(key); + } + if (ret == WH_ERROR_OK) { + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.msg.addr, &msgAddr, req.msg.sz, + WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + if (ret != WH_ERROR_OK) { + res.dmaAddrStatus.badAddr = req.msg; + } + } + if (ret == WH_ERROR_OK) { + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.sig.addr, &sigAddr, req.sig.sz, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whServerDmaFlags){0}); + if (ret != WH_ERROR_OK) { + res.dmaAddrStatus.badAddr = req.sig; + } + } + if (ret == WH_ERROR_OK) { + /* wolfCrypt's flow (verified against wc_lms.c:1439-1474 post-patch): + * 1. wc_hss_sign computes the signature into sig and advances + * key->priv_raw in memory. + * 2. write_private_key (our bridge) is called with the new + * priv_raw and atomically commits it to NVM. + * 3. If the bridge returns anything other than + * WC_LMS_RC_SAVED_TO_NV_MEMORY, wolfCrypt does ForceZero(sig) + * and returns IO_FAILED_E. + * Net effect: a signature is exposed to the caller only if the NVM + * commit succeeded. A process crash anywhere in the sequence either + * (a) leaves the old state in NVM with no signature exposed, or + * (b) commits the new state with the signature lost in transit - + * one wasted index but never an index reused with a fresh sig. */ + ret = wc_LmsKey_Sign(key, sigAddr, &sigLen, msgAddr, (int)req.msg.sz); + if (ret == 0) { + res.sigLen = sigLen; + } + } + + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.sig.addr, &sigAddr, sigLen, + WH_DMA_OPER_CLIENT_WRITE_POST, (whServerDmaFlags){0}); + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.msg.addr, &msgAddr, req.msg.sz, + WH_DMA_OPER_CLIENT_READ_POST, (whServerDmaFlags){0}); + + if (keyInited) { + wc_LmsKey_Free(key); + } + + if ((req.options & WH_MESSAGE_CRYPTO_STATEFUL_SIG_OPTIONS_EVICT) != 0) { + (void)wh_Server_KeystoreEvictKey(ctx, keyId); + } + + (void)WH_SERVER_NVM_UNLOCK(ctx); + + (void)wh_MessageCrypto_TranslatePqcStatefulSigSignDmaResponse( + magic, &res, + (whMessageCrypto_PqcStatefulSigSignDmaResponse*)cryptoDataOut); + *outSize = sizeof(res); + return ret; +#endif /* WOLFSSL_LMS_VERIFY_ONLY */ +} + +static int _HandleLmsVerifyDma(whServerContext* ctx, uint16_t magic, int devId, + const void* cryptoDataIn, uint16_t inSize, + void* cryptoDataOut, uint16_t* outSize) +{ + int ret; + LmsKey key[1]; + int keyInited = 0; + void* sigAddr = NULL; + void* msgAddr = NULL; + whKeyId keyId; + whMessageCrypto_PqcStatefulSigVerifyDmaRequest req; + whMessageCrypto_PqcStatefulSigVerifyDmaResponse res; + + memset(&res, 0, sizeof(res)); + + if (inSize < sizeof(req)) { + return WH_ERROR_BADARGS; + } + ret = wh_MessageCrypto_TranslatePqcStatefulSigVerifyDmaRequest( + magic, (whMessageCrypto_PqcStatefulSigVerifyDmaRequest*)cryptoDataIn, + &req); + if (ret != WH_ERROR_OK) { + return ret; + } + + keyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + ctx->comm->client_id, req.keyId); + if (WH_KEYID_ISERASED(keyId)) { + return WH_ERROR_BADARGS; + } + + ret = wc_LmsKey_Init(key, NULL, devId); + if (ret == 0) { + keyInited = 1; + ret = wh_Server_LmsKeyCacheExport(ctx, keyId, key); + } + if (ret == WH_ERROR_OK) { + /* Deserialize leaves the key in PARMSET; wc_LmsKey_Verify needs + * OK or VERIFYONLY. Pub is populated and that's all verify uses. */ + key->state = WC_LMS_STATE_VERIFYONLY; + } + + if (ret == WH_ERROR_OK) { + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.sig.addr, &sigAddr, req.sig.sz, + WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + if (ret != WH_ERROR_OK) { + res.dmaAddrStatus.badAddr = req.sig; + } + } + if (ret == WH_ERROR_OK) { + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.msg.addr, &msgAddr, req.msg.sz, + WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + if (ret != WH_ERROR_OK) { + res.dmaAddrStatus.badAddr = req.msg; + } + } + if (ret == WH_ERROR_OK) { + int verifyRet = wc_LmsKey_Verify(key, sigAddr, (word32)req.sig.sz, + msgAddr, (int)req.msg.sz); + if (verifyRet == 0) { + res.res = 1; + } + else if (verifyRet == WC_NO_ERR_TRACE(SIG_VERIFY_E)) { + res.res = 0; + } + else { + ret = verifyRet; + } + } + + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.sig.addr, &sigAddr, req.sig.sz, + WH_DMA_OPER_CLIENT_READ_POST, (whServerDmaFlags){0}); + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.msg.addr, &msgAddr, req.msg.sz, + WH_DMA_OPER_CLIENT_READ_POST, (whServerDmaFlags){0}); + + if (keyInited) { + wc_LmsKey_Free(key); + } + + if ((req.options & WH_MESSAGE_CRYPTO_STATEFUL_SIG_OPTIONS_EVICT) != 0) { + (void)wh_Server_KeystoreEvictKey(ctx, keyId); + } + + (void)wh_MessageCrypto_TranslatePqcStatefulSigVerifyDmaResponse( + magic, &res, + (whMessageCrypto_PqcStatefulSigVerifyDmaResponse*)cryptoDataOut); + *outSize = sizeof(res); + return ret; +} + +static int _HandleLmsSigsLeftDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize) +{ +#ifdef WOLFSSL_LMS_VERIFY_ONLY + (void)ctx; (void)magic; (void)devId; (void)cryptoDataIn; (void)inSize; + (void)cryptoDataOut; (void)outSize; + return WH_ERROR_NOHANDLER; +#else + int ret; + LmsKey key[1]; + int keyInited = 0; + int sigsLeft; + whKeyId keyId; + whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest req; + whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse res; + + memset(&res, 0, sizeof(res)); + + if (inSize < sizeof(req)) { + return WH_ERROR_BADARGS; + } + ret = wh_MessageCrypto_TranslatePqcStatefulSigSigsLeftDmaRequest( + magic, + (whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest*)cryptoDataIn, &req); + if (ret != WH_ERROR_OK) { + return ret; + } + + keyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + ctx->comm->client_id, req.keyId); + if (WH_KEYID_ISERASED(keyId)) { + return WH_ERROR_BADARGS; + } + + ret = wc_LmsKey_Init(key, NULL, devId); + if (ret == 0) { + keyInited = 1; + ret = wh_Server_LmsKeyCacheExport(ctx, keyId, key); + } + if (ret == WH_ERROR_OK) { + sigsLeft = wc_LmsKey_SigsLeft(key); + if (sigsLeft >= 0) { + res.sigsLeft = (uint32_t)sigsLeft; + ret = WH_ERROR_OK; + } + } + + if (keyInited) { + wc_LmsKey_Free(key); + } + + (void)wh_MessageCrypto_TranslatePqcStatefulSigSigsLeftDmaResponse( + magic, &res, + (whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse*)cryptoDataOut); + *outSize = sizeof(res); + return ret; +#endif /* WOLFSSL_LMS_VERIFY_ONLY */ +} +#endif /* WOLFSSL_HAVE_LMS */ + +#ifdef WOLFSSL_HAVE_XMSS +/* wolfCrypt's wc_XmssKey_MakeKey calls write_private_key with the freshly + * generated sk and then ForceZero's key->sk (see wc_xmss.c). To get a usable + * sk back into key->sk for the subsequent serialize step, we capture it here + * via a context-pointed buffer. */ +typedef struct { + byte* buf; + word32 cap; + word32 len; +} _XmssSkCapture; + +static enum wc_XmssRc _XmssKeygenWriteCb(const byte* priv, word32 privSz, + void* context) +{ + _XmssSkCapture* cap = (_XmssSkCapture*)context; + if ((cap == NULL) || (priv == NULL) || (privSz > cap->cap)) { + return WC_XMSS_RC_WRITE_FAIL; + } + memcpy(cap->buf, priv, privSz); + cap->len = privSz; + return WC_XMSS_RC_SAVED_TO_NV_MEMORY; +} +static enum wc_XmssRc _XmssDummyReadCb(byte* priv, word32 privSz, + void* context) +{ + (void)priv; (void)privSz; (void)context; + return WC_XMSS_RC_READ_TO_MEMORY; +} + +static int _HandleXmssKeyGenDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize) +{ +#ifdef WOLFSSL_XMSS_VERIFY_ONLY + (void)ctx; (void)magic; (void)devId; (void)cryptoDataIn; (void)inSize; + (void)cryptoDataOut; (void)outSize; + return WH_ERROR_NOHANDLER; +#else + int ret; + XmssKey key[1]; + void* clientPubAddr = NULL; + word32 pubLen32 = 0; + word32 privLen32 = 0; + whKeyId keyId; + whMessageCrypto_PqcStatefulSigKeyGenDmaRequest req; + whMessageCrypto_PqcStatefulSigKeyGenDmaResponse res; + _XmssSkCapture sk_cap; + /* WC_XMSS_MAX_SK comes from the params table; sized for the largest + * supported XMSS variant. The variants enabled in user_settings.h all + * fit in 4 KiB, but use the wolfCrypt-reported priv length to be + * exact. */ + byte sk_buf[4096]; + + memset(&res, 0, sizeof(res)); + memset(&sk_cap, 0, sizeof(sk_cap)); + sk_cap.buf = sk_buf; + sk_cap.cap = (word32)sizeof(sk_buf); + + if (inSize < sizeof(req)) { + return WH_ERROR_BADARGS; + } + ret = wh_MessageCrypto_TranslatePqcStatefulSigKeyGenDmaRequest( + magic, (whMessageCrypto_PqcStatefulSigKeyGenDmaRequest*)cryptoDataIn, + &req); + if (ret != WH_ERROR_OK) { + return ret; + } + + /* xmssParamStr arrives via the request struct. The client + * (wh_Client_XmssMakeKeyDma) currently has a TODO for populating it; for + * v1 the server enforces NUL-termination. */ + req.xmssParamStr[sizeof(req.xmssParamStr) - 1] = '\0'; + + ret = wc_XmssKey_Init(key, NULL, devId); + if (ret != 0) { + return ret; + } + + ret = wc_XmssKey_SetParamStr(key, req.xmssParamStr); + if (ret == 0) { + /* Use a real capture cb: wolfCrypt ForceZero's key->sk after MakeKey + * (see wc_xmss.c), so we copy sk into sk_buf via the cb and restore + * it on key->sk before serializing into the cache slot. */ + ret = wc_XmssKey_SetWriteCb(key, _XmssKeygenWriteCb); + } + if (ret == 0) { + ret = wc_XmssKey_SetReadCb(key, _XmssDummyReadCb); + } + if (ret == 0) { + ret = wc_XmssKey_SetContext(key, &sk_cap); + } + if (ret == 0) { + ret = wc_XmssKey_MakeKey(key, ctx->crypto->rng); + } + + if (ret == 0) { + /* Sanity-check the captured sk size against what wolfCrypt expects. */ + ret = wc_XmssKey_GetPrivLen(key, &privLen32); + if ((ret == 0) && (sk_cap.len != privLen32)) { + ret = WH_ERROR_ABORTED; + } + } + if (ret == 0) { + /* Restore sk so SerializeKey captures real bytes, not the + * MakeKey-zeroed buffer. */ + memcpy(key->sk, sk_cap.buf, sk_cap.len); + } + + if (ret == 0) { + keyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + ctx->comm->client_id, req.keyId); + if (WH_KEYID_ISERASED(keyId)) { + ret = wh_Server_KeystoreGetUniqueId(ctx, &keyId); + } + } + + if (ret == 0) { + ret = wh_Server_XmssKeyCacheImport(ctx, key, req.xmssParamStr, keyId, + req.flags, (uint16_t)req.labelSize, + req.label); + } + + if ((ret == 0) && ((req.flags & WH_NVM_FLAGS_EPHEMERAL) == 0)) { + ret = wh_Server_KeystoreCommitKey(ctx, keyId); + } + + if (ret == 0) { + ret = wc_XmssKey_GetPubLen(key, &pubLen32); + } + if (ret == 0 && req.pub.sz < pubLen32) { + ret = WH_ERROR_BUFFER_SIZE; + } + if (ret == 0) { + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.pub.addr, &clientPubAddr, pubLen32, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whServerDmaFlags){0}); + if (ret != 0) { + res.dmaAddrStatus.badAddr = req.pub; + } + } + if (ret == 0) { + memcpy(clientPubAddr, key->pk, pubLen32); + res.keyId = wh_KeyId_TranslateToClient(keyId); + res.pubSize = pubLen32; + } + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.pub.addr, &clientPubAddr, pubLen32, + WH_DMA_OPER_CLIENT_WRITE_POST, (whServerDmaFlags){0}); + + wc_XmssKey_Free(key); + + (void)wh_MessageCrypto_TranslatePqcStatefulSigKeyGenDmaResponse( + magic, &res, + (whMessageCrypto_PqcStatefulSigKeyGenDmaResponse*)cryptoDataOut); + *outSize = sizeof(res); + return ret; +#endif /* WOLFSSL_XMSS_VERIFY_ONLY */ +} + +static int _HandleXmssSignDma(whServerContext* ctx, uint16_t magic, int devId, + const void* cryptoDataIn, uint16_t inSize, + void* cryptoDataOut, uint16_t* outSize) +{ +#ifdef WOLFSSL_XMSS_VERIFY_ONLY + (void)ctx; (void)magic; (void)devId; (void)cryptoDataIn; (void)inSize; + (void)cryptoDataOut; (void)outSize; + return WH_ERROR_NOHANDLER; +#else + int ret; + XmssKey key[1]; + int keyInited = 0; + void* msgAddr = NULL; + void* sigAddr = NULL; + word32 sigLen; + whKeyId keyId; + uint8_t* cacheBuf; + whNvmMetadata* cacheMeta; + whServerStatefulSigBridge bridge; + whMessageCrypto_PqcStatefulSigSignDmaRequest req; + whMessageCrypto_PqcStatefulSigSignDmaResponse res; + + memset(&res, 0, sizeof(res)); + + if (inSize < sizeof(req)) { + return WH_ERROR_BADARGS; + } + ret = wh_MessageCrypto_TranslatePqcStatefulSigSignDmaRequest( + magic, (whMessageCrypto_PqcStatefulSigSignDmaRequest*)cryptoDataIn, + &req); + if (ret != WH_ERROR_OK) { + return ret; + } + + keyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + ctx->comm->client_id, req.keyId); + if (WH_KEYID_ISERASED(keyId)) { + return WH_ERROR_BADARGS; + } + + sigLen = (word32)req.sig.sz; + + /* See _HandleLmsSignDma for the NVM-lock rationale. */ + ret = WH_SERVER_NVM_LOCK(ctx); + if (ret != WH_ERROR_OK) { + return ret; + } + + ret = wh_Server_KeystoreFreshenKey(ctx, keyId, &cacheBuf, &cacheMeta); + if (ret == WH_ERROR_OK) { + ret = wc_XmssKey_Init(key, NULL, devId); + if (ret == 0) { + keyInited = 1; + } + } + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_XmssDeserializeKey(cacheBuf, (uint16_t)cacheMeta->len, + key); + } + if (ret == WH_ERROR_OK) { + ret = _StatefulBridgeFromSlot( + &bridge, ctx, keyId, cacheBuf, cacheMeta, + WOLFHSM_CFG_SERVER_KEYCACHE_BIG_BUFSIZE); + } + if (ret == WH_ERROR_OK) { + (void)wc_XmssKey_SetWriteCb(key, _XmssBridgeWriteCb); + (void)wc_XmssKey_SetReadCb(key, _XmssBridgeReadCb); + (void)wc_XmssKey_SetContext(key, &bridge); + ret = wc_XmssKey_Reload(key); + } + if (ret == WH_ERROR_OK) { + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.msg.addr, &msgAddr, req.msg.sz, + WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + if (ret != WH_ERROR_OK) { + res.dmaAddrStatus.badAddr = req.msg; + } + } + if (ret == WH_ERROR_OK) { + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.sig.addr, &sigAddr, req.sig.sz, + WH_DMA_OPER_CLIENT_WRITE_PRE, (whServerDmaFlags){0}); + if (ret != WH_ERROR_OK) { + res.dmaAddrStatus.badAddr = req.sig; + } + } + if (ret == WH_ERROR_OK) { + ret = wc_XmssKey_Sign(key, sigAddr, &sigLen, msgAddr, (int)req.msg.sz); + if (ret == 0) { + res.sigLen = sigLen; + } + } + + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.sig.addr, &sigAddr, sigLen, + WH_DMA_OPER_CLIENT_WRITE_POST, (whServerDmaFlags){0}); + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.msg.addr, &msgAddr, req.msg.sz, + WH_DMA_OPER_CLIENT_READ_POST, (whServerDmaFlags){0}); + + if (keyInited) { + wc_XmssKey_Free(key); + } + + if ((req.options & WH_MESSAGE_CRYPTO_STATEFUL_SIG_OPTIONS_EVICT) != 0) { + (void)wh_Server_KeystoreEvictKey(ctx, keyId); + } + + (void)WH_SERVER_NVM_UNLOCK(ctx); + + (void)wh_MessageCrypto_TranslatePqcStatefulSigSignDmaResponse( + magic, &res, + (whMessageCrypto_PqcStatefulSigSignDmaResponse*)cryptoDataOut); + *outSize = sizeof(res); + return ret; +#endif /* WOLFSSL_XMSS_VERIFY_ONLY */ +} + +static int _HandleXmssVerifyDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize) +{ + int ret; + XmssKey key[1]; + int keyInited = 0; + void* sigAddr = NULL; + void* msgAddr = NULL; + whKeyId keyId; + whMessageCrypto_PqcStatefulSigVerifyDmaRequest req; + whMessageCrypto_PqcStatefulSigVerifyDmaResponse res; + + memset(&res, 0, sizeof(res)); + + if (inSize < sizeof(req)) { + return WH_ERROR_BADARGS; + } + ret = wh_MessageCrypto_TranslatePqcStatefulSigVerifyDmaRequest( + magic, (whMessageCrypto_PqcStatefulSigVerifyDmaRequest*)cryptoDataIn, + &req); + if (ret != WH_ERROR_OK) { + return ret; + } + + keyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + ctx->comm->client_id, req.keyId); + if (WH_KEYID_ISERASED(keyId)) { + return WH_ERROR_BADARGS; + } + + ret = wc_XmssKey_Init(key, NULL, devId); + if (ret == 0) { + keyInited = 1; + ret = wh_Server_XmssKeyCacheExport(ctx, keyId, key); + } + if (ret == WH_ERROR_OK) { + /* Deserialize leaves the key in PARMSET; wc_XmssKey_Verify needs + * OK or VERIFYONLY. Pub is populated and that's all verify uses. */ + key->state = WC_XMSS_STATE_VERIFYONLY; + } + + if (ret == WH_ERROR_OK) { + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.sig.addr, &sigAddr, req.sig.sz, + WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + if (ret != WH_ERROR_OK) { + res.dmaAddrStatus.badAddr = req.sig; + } + } + if (ret == WH_ERROR_OK) { + ret = wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.msg.addr, &msgAddr, req.msg.sz, + WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + if (ret != WH_ERROR_OK) { + res.dmaAddrStatus.badAddr = req.msg; + } + } + if (ret == WH_ERROR_OK) { + int verifyRet = wc_XmssKey_Verify(key, sigAddr, (word32)req.sig.sz, + msgAddr, (int)req.msg.sz); + if (verifyRet == 0) { + res.res = 1; + } + else if (verifyRet == WC_NO_ERR_TRACE(SIG_VERIFY_E)) { + res.res = 0; + } + else { + ret = verifyRet; + } + } + + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.sig.addr, &sigAddr, req.sig.sz, + WH_DMA_OPER_CLIENT_READ_POST, (whServerDmaFlags){0}); + (void)wh_Server_DmaProcessClientAddress( + ctx, (uintptr_t)req.msg.addr, &msgAddr, req.msg.sz, + WH_DMA_OPER_CLIENT_READ_POST, (whServerDmaFlags){0}); + + if (keyInited) { + wc_XmssKey_Free(key); + } + + if ((req.options & WH_MESSAGE_CRYPTO_STATEFUL_SIG_OPTIONS_EVICT) != 0) { + (void)wh_Server_KeystoreEvictKey(ctx, keyId); + } + + (void)wh_MessageCrypto_TranslatePqcStatefulSigVerifyDmaResponse( + magic, &res, + (whMessageCrypto_PqcStatefulSigVerifyDmaResponse*)cryptoDataOut); + *outSize = sizeof(res); + return ret; +} + +static int _HandleXmssSigsLeftDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize) +{ +#ifdef WOLFSSL_XMSS_VERIFY_ONLY + (void)ctx; (void)magic; (void)devId; (void)cryptoDataIn; (void)inSize; + (void)cryptoDataOut; (void)outSize; + return WH_ERROR_NOHANDLER; +#else + int ret; + XmssKey key[1]; + int keyInited = 0; + int sigsLeft; + whKeyId keyId; + uint8_t* cacheBuf; + whNvmMetadata* cacheMeta; + whServerStatefulSigBridge bridge; + whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest req; + whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse res; + + memset(&res, 0, sizeof(res)); + + if (inSize < sizeof(req)) { + return WH_ERROR_BADARGS; + } + ret = wh_MessageCrypto_TranslatePqcStatefulSigSigsLeftDmaRequest( + magic, + (whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest*)cryptoDataIn, &req); + if (ret != WH_ERROR_OK) { + return ret; + } + + keyId = wh_KeyId_TranslateFromClient(WH_KEYTYPE_CRYPTO, + ctx->comm->client_id, req.keyId); + if (WH_KEYID_ISERASED(keyId)) { + return WH_ERROR_BADARGS; + } + + ret = wh_Server_KeystoreFreshenKey(ctx, keyId, &cacheBuf, &cacheMeta); + if (ret == WH_ERROR_OK) { + ret = wc_XmssKey_Init(key, NULL, devId); + if (ret == 0) { + keyInited = 1; + } + } + if (ret == WH_ERROR_OK) { + ret = wh_Crypto_XmssDeserializeKey(cacheBuf, (uint16_t)cacheMeta->len, + key); + } + if (ret == WH_ERROR_OK) { + ret = _StatefulBridgeFromSlot( + &bridge, ctx, keyId, cacheBuf, cacheMeta, + WOLFHSM_CFG_SERVER_KEYCACHE_BIG_BUFSIZE); + } + if (ret == WH_ERROR_OK) { + /* Reload uses the bridge ReadCb to populate sk from the cached blob, + * then transitions state to OK so SigsLeft can run. */ + (void)wc_XmssKey_SetWriteCb(key, _XmssBridgeWriteCb); + (void)wc_XmssKey_SetReadCb(key, _XmssBridgeReadCb); + (void)wc_XmssKey_SetContext(key, &bridge); + ret = wc_XmssKey_Reload(key); + } + if (ret == WH_ERROR_OK) { + sigsLeft = wc_XmssKey_SigsLeft(key); + if (sigsLeft >= 0) { + res.sigsLeft = (uint32_t)sigsLeft; + } + } + + if (keyInited) { + wc_XmssKey_Free(key); + } + + (void)wh_MessageCrypto_TranslatePqcStatefulSigSigsLeftDmaResponse( + magic, &res, + (whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse*)cryptoDataOut); + *outSize = sizeof(res); + return ret; +#endif /* WOLFSSL_XMSS_VERIFY_ONLY */ +} +#endif /* WOLFSSL_HAVE_XMSS */ + +#if defined(WOLFSSL_HAVE_LMS) || defined(WOLFSSL_HAVE_XMSS) +static int _HandlePqcStatefulSigAlgorithmDma( + whServerContext* ctx, uint16_t magic, int devId, const void* cryptoDataIn, + uint16_t cryptoInSize, void* cryptoDataOut, uint16_t* cryptoOutSize, + uint32_t pkAlgoType, uint32_t pqAlgoType) +{ + int ret = WH_ERROR_NOHANDLER; + + switch (pqAlgoType) { +#ifdef WOLFSSL_HAVE_LMS + case WC_PQC_STATEFUL_SIG_TYPE_LMS: + switch (pkAlgoType) { + case WC_PK_TYPE_PQC_STATEFUL_SIG_KEYGEN: + ret = _HandleLmsKeyGenDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_STATEFUL_SIG_SIGN: + ret = _HandleLmsSignDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_STATEFUL_SIG_VERIFY: + ret = _HandleLmsVerifyDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_STATEFUL_SIG_SIGS_LEFT: + ret = _HandleLmsSigsLeftDma(ctx, magic, devId, + cryptoDataIn, cryptoInSize, + cryptoDataOut, cryptoOutSize); + break; + default: + ret = WH_ERROR_NOHANDLER; + break; + } + break; +#endif /* WOLFSSL_HAVE_LMS */ +#ifdef WOLFSSL_HAVE_XMSS + case WC_PQC_STATEFUL_SIG_TYPE_XMSS: + switch (pkAlgoType) { + case WC_PK_TYPE_PQC_STATEFUL_SIG_KEYGEN: + ret = _HandleXmssKeyGenDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_STATEFUL_SIG_SIGN: + ret = _HandleXmssSignDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_STATEFUL_SIG_VERIFY: + ret = _HandleXmssVerifyDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_STATEFUL_SIG_SIGS_LEFT: + ret = _HandleXmssSigsLeftDma(ctx, magic, devId, + cryptoDataIn, cryptoInSize, + cryptoDataOut, cryptoOutSize); + break; + default: + ret = WH_ERROR_NOHANDLER; + break; + } + break; +#endif /* WOLFSSL_HAVE_XMSS */ + default: + ret = WH_ERROR_NOHANDLER; + break; + } + + return ret; +} +#endif /* WOLFSSL_HAVE_LMS || WOLFSSL_HAVE_XMSS */ + +static int _HandlePqcKemAlgorithmDma(whServerContext* ctx, uint16_t magic, + int devId, const void* cryptoDataIn, + uint16_t cryptoInSize, void* cryptoDataOut, + uint16_t* cryptoOutSize, + uint32_t pkAlgoType, uint32_t pqAlgoType) +{ + int ret = WH_ERROR_NOHANDLER; + + switch (pqAlgoType) { + case WC_PQC_KEM_TYPE_KYBER: { + switch (pkAlgoType) { + case WC_PK_TYPE_PQC_KEM_KEYGEN: + ret = _HandleMlKemKeyGenDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_KEM_ENCAPS: + ret = _HandleMlKemEncapsDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + case WC_PK_TYPE_PQC_KEM_DECAPS: + ret = _HandleMlKemDecapsDma(ctx, magic, devId, cryptoDataIn, + cryptoInSize, cryptoDataOut, + cryptoOutSize); + break; + default: + ret = WH_ERROR_NOHANDLER; + break; + } + } break; + default: + ret = WH_ERROR_NOHANDLER; + break; + } + + return ret; +} +#endif /* WOLFSSL_HAVE_MLKEM */ + +#if defined(WOLFSSL_CMAC) && !defined(NO_AES) && defined(WOLFSSL_AES_DIRECT) +static int _HandleCmacDma(whServerContext* ctx, uint16_t magic, int devId, + uint16_t seq, const void* cryptoDataIn, + uint16_t inSize, void* cryptoDataOut, + uint16_t* outSize) +{ + (void)seq; + + int ret = 0; + whMessageCrypto_CmacAesDmaRequest req; + whMessageCrypto_CmacAesDmaResponse res; + + if (inSize < sizeof(whMessageCrypto_CmacAesDmaRequest)) { + return WH_ERROR_BADARGS; + } + + /* Translate request */ + ret = wh_MessageCrypto_TranslateCmacAesDmaRequest( + magic, (whMessageCrypto_CmacAesDmaRequest*)cryptoDataIn, &req); + if (ret != WH_ERROR_OK) { + return ret; + } + + /* Validate variable-length fields fit within inSize */ + uint32_t available = inSize - sizeof(whMessageCrypto_CmacAesDmaRequest); + if (req.keySz > available) { + return WH_ERROR_BADARGS; + } + if (req.keySz > AES_256_KEY_SIZE) { + return WH_ERROR_BADARGS; + } + + word32 len; + + /* Pointers to inline trailing data */ + uint8_t* key = + (uint8_t*)(cryptoDataIn) + sizeof(whMessageCrypto_CmacAesDmaRequest); + uint8_t* out = + (uint8_t*)(cryptoDataOut) + sizeof(whMessageCrypto_CmacAesDmaResponse); + + memset(&res, 0, sizeof(res)); + + /* DMA translated address for input */ + void* inAddr = NULL; + + uint8_t tmpKey[AES_256_KEY_SIZE]; + uint32_t tmpKeyLen = sizeof(tmpKey); + Cmac cmac[1]; + + /* Attempt oneshot if input and output are both present */ + if (req.input.sz != 0 && req.outSz != 0) { + len = req.outSz; + + /* Translate DMA address for input */ + ret = wh_Server_DmaProcessClientAddress( + ctx, req.input.addr, &inAddr, req.input.sz, + WH_DMA_OPER_CLIENT_READ_PRE, (whServerDmaFlags){0}); + if (ret == WH_ERROR_ACCESS) { + res.dmaAddrStatus.badAddr = req.input; + } + + /* Resolve key */ + if (ret == WH_ERROR_OK) { + ret = _CmacResolveKey(ctx, key, req.keySz, req.keyId, tmpKey, + &tmpKeyLen); + } + + if (ret == WH_ERROR_OK && req.keySz != 0) { + /* Client-supplied key - direct one-shot */ + WH_DEBUG_SERVER_VERBOSE("dma cmac generate oneshot\n"); + + ret = wc_AesCmacGenerate_ex(cmac, out, &len, inAddr, req.input.sz, + tmpKey, (word32)tmpKeyLen, NULL, devId); + } + else if (ret == WH_ERROR_OK) { + /* HSM-local key via keyId - init then generate */ + WH_DEBUG_SERVER_VERBOSE("dma cmac generate oneshot with keyId:%x\n", req.keyId); ret = wc_InitCmac_ex(cmac, tmpKey, (word32)tmpKeyLen, WC_CMAC_AES, @@ -6949,6 +8174,17 @@ int wh_Server_HandleCryptoDmaRequest(whServerContext* ctx, uint16_t magic, rqstHeader.algoSubType); break; #endif /* WOLFSSL_HAVE_MLKEM */ +#if defined(WOLFSSL_HAVE_LMS) || defined(WOLFSSL_HAVE_XMSS) + case WC_PK_TYPE_PQC_STATEFUL_SIG_KEYGEN: + case WC_PK_TYPE_PQC_STATEFUL_SIG_SIGN: + case WC_PK_TYPE_PQC_STATEFUL_SIG_VERIFY: + case WC_PK_TYPE_PQC_STATEFUL_SIG_SIGS_LEFT: + ret = _HandlePqcStatefulSigAlgorithmDma( + ctx, magic, devId, cryptoDataIn, cryptoInSize, + cryptoDataOut, &cryptoOutSize, rqstHeader.algoType, + rqstHeader.algoSubType); + break; +#endif /* WOLFSSL_HAVE_LMS || WOLFSSL_HAVE_XMSS */ #ifdef HAVE_ED25519 case WC_PK_TYPE_ED25519_SIGN: ret = _HandleEd25519SignDma(ctx, magic, devId, cryptoDataIn, diff --git a/test/config/user_settings.h b/test/config/user_settings.h index e6fdf0629..25f6a4d9d 100644 --- a/test/config/user_settings.h +++ b/test/config/user_settings.h @@ -143,6 +143,14 @@ #define WOLFSSL_HAVE_MLKEM #define WOLFSSL_WC_MLKEM +/* LMS / HSS Options (RFC 8554, NIST SP 800-208) */ +#define WOLFSSL_HAVE_LMS +#define WOLFSSL_WC_LMS + +/* XMSS / XMSS^MT Options (RFC 8391, NIST SP 800-208) */ +#define WOLFSSL_HAVE_XMSS +#define WOLFSSL_WC_XMSS + /* Ed25519 Options */ #define HAVE_ED25519 diff --git a/test/wh_test_check_struct_padding.c b/test/wh_test_check_struct_padding.c index a0e549132..3fb7c7c70 100644 --- a/test/wh_test_check_struct_padding.c +++ b/test/wh_test_check_struct_padding.c @@ -152,6 +152,14 @@ whMessageCrypto_MlKemEncapsDmaRequest pkMlkemEncapsDmaReq; whMessageCrypto_MlKemEncapsDmaResponse pkMlkemEncapsDmaRes; whMessageCrypto_MlKemDecapsDmaRequest pkMlkemDecapsDmaReq; whMessageCrypto_MlKemDecapsDmaResponse pkMlkemDecapsDmaRes; +whMessageCrypto_PqcStatefulSigKeyGenDmaRequest pqStatefulSigKeygenDmaReq; +whMessageCrypto_PqcStatefulSigKeyGenDmaResponse pqStatefulSigKeygenDmaRes; +whMessageCrypto_PqcStatefulSigSignDmaRequest pqStatefulSigSignDmaReq; +whMessageCrypto_PqcStatefulSigSignDmaResponse pqStatefulSigSignDmaRes; +whMessageCrypto_PqcStatefulSigVerifyDmaRequest pqStatefulSigVerifyDmaReq; +whMessageCrypto_PqcStatefulSigVerifyDmaResponse pqStatefulSigVerifyDmaRes; +whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest pqStatefulSigSigsLeftDmaReq; +whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse pqStatefulSigSigsLeftDmaRes; #endif /* WOLFHSM_CFG_DMA */ #endif /* !WOLFHSM_CFG_NO_CRYPTO */ diff --git a/test/wh_test_crypto.c b/test/wh_test_crypto.c index 189f549de..a4bd7a301 100644 --- a/test/wh_test_crypto.c +++ b/test/wh_test_crypto.c @@ -33,6 +33,12 @@ #include "wolfssl/wolfcrypt/kdf.h" #include "wolfssl/wolfcrypt/ed25519.h" #include "wolfssl/wolfcrypt/wc_mlkem.h" +#if defined(WOLFSSL_HAVE_LMS) +#include "wolfssl/wolfcrypt/wc_lms.h" +#endif +#if defined(WOLFSSL_HAVE_XMSS) +#include "wolfssl/wolfcrypt/wc_xmss.h" +#endif #include "wolfhsm/wh_error.h" @@ -8557,6 +8563,304 @@ static int whTestCrypto_MlKemDmaClient(whClientContext* ctx, int devId, !defined(WOLFSSL_MLKEM_NO_DECAPSULATE) */ #endif /* WOLFSSL_HAVE_MLKEM */ +#if defined(WOLFHSM_CFG_DMA) && \ + defined(WOLFSSL_HAVE_LMS) && !defined(WOLFSSL_LMS_VERIFY_ONLY) +/* L=1, H=5, W=8 keeps the signature ~1.3 KB and gives 2^5 = 32 signatures. */ +#define WH_TEST_LMS_LEVELS (1) +#define WH_TEST_LMS_HEIGHT (5) +#define WH_TEST_LMS_WINTERNITZ (8) +/* Generous buffer that fits L1_H5_W8 (~1328) and any W<8 variant of the same + * height (W=1 ~8688). Keeps off the stack so ASAN builds stay happy. */ +static byte whTest_LmsSigBuf[8800]; + +static int whTestCrypto_LmsCryptoCb(whClientContext* ctx, int devId, + WC_RNG* rng) +{ + int ret = 0; + LmsKey key[1]; + int keyInited = 0; + word32 sigLen = 0; + word32 sigCap = 0; + const byte msg[] = "wolfHSM LMS cryptocb test"; + word32 msgSz = (word32)sizeof(msg) - 1; + + (void)rng; + + memset(whTest_LmsSigBuf, 0, sizeof(whTest_LmsSigBuf)); + + ret = wc_LmsKey_Init(key, NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT("Failed wc_LmsKey_Init devId=0x%X ret=%d\n", devId, ret); + return ret; + } + keyInited = 1; + + if (ret == 0) { + ret = wc_LmsKey_SetParameters(key, WH_TEST_LMS_LEVELS, + WH_TEST_LMS_HEIGHT, + WH_TEST_LMS_WINTERNITZ); + if (ret != 0) { + WH_ERROR_PRINT("Failed LMS SetParameters ret=%d\n", ret); + } + } + + if (ret == 0) { + ret = wc_LmsKey_GetSigLen(key, &sigCap); + if (ret != 0) { + WH_ERROR_PRINT("Failed LMS GetSigLen ret=%d\n", ret); + } + else if (sigCap > sizeof(whTest_LmsSigBuf)) { + WH_ERROR_PRINT("LMS sig buffer too small: need=%u have=%u\n", + (unsigned)sigCap, + (unsigned)sizeof(whTest_LmsSigBuf)); + ret = BUFFER_E; + } + } + + /* MakeKey via cryptocb: server caches private key (ephemeral) and + * returns the public key over DMA. */ + if (ret == 0) { + ret = wc_LmsKey_MakeKey(key, rng); + if (ret != 0) { + WH_ERROR_PRINT("Failed LMS MakeKey ret=%d\n", ret); + } + } + + /* wc_LmsKey_SigsLeft returns a boolean: nonzero = signatures available, + * 0 = exhausted. Fresh key should report nonzero. */ + if (ret == 0) { + if (wc_LmsKey_SigsLeft(key) == 0) { + WH_ERROR_PRINT("LMS reported exhausted on fresh key\n"); + ret = -1; + } + } + + /* Sign via cryptocb. */ + if (ret == 0) { + sigLen = sigCap; + ret = wc_LmsKey_Sign(key, whTest_LmsSigBuf, &sigLen, msg, (int)msgSz); + if (ret != 0) { + WH_ERROR_PRINT("Failed LMS Sign ret=%d\n", ret); + } + else if (sigLen != sigCap) { + WH_ERROR_PRINT("LMS Sign produced unexpected length=%u expected=%u\n", + (unsigned)sigLen, (unsigned)sigCap); + ret = -1; + } + } + + /* Verify the signature via cryptocb. */ + if (ret == 0) { + ret = wc_LmsKey_Verify(key, whTest_LmsSigBuf, sigLen, msg, (int)msgSz); + if (ret != 0) { + WH_ERROR_PRINT("Failed LMS Verify ret=%d\n", ret); + } + } + + /* Tampered signature must fail to verify. */ + if (ret == 0) { + whTest_LmsSigBuf[0] ^= 0xFF; + ret = wc_LmsKey_Verify(key, whTest_LmsSigBuf, sigLen, msg, (int)msgSz); + whTest_LmsSigBuf[0] ^= 0xFF; + if (ret == 0) { + WH_ERROR_PRINT("LMS Verify unexpectedly accepted tampered sig\n"); + ret = -1; + } + else { + ret = 0; + } + } + + /* Wrong message must also fail to verify. */ + if (ret == 0) { + const byte wrongMsg[] = "wolfHSM LMS cryptocb wrong"; + ret = wc_LmsKey_Verify(key, whTest_LmsSigBuf, sigLen, wrongMsg, + (int)(sizeof(wrongMsg) - 1)); + if (ret == 0) { + WH_ERROR_PRINT("LMS Verify unexpectedly accepted wrong message\n"); + ret = -1; + } + else { + ret = 0; + } + } + + /* H=5 means 32 sigs total; after one sign, the key is still not + * exhausted. */ + if (ret == 0) { + if (wc_LmsKey_SigsLeft(key) == 0) { + WH_ERROR_PRINT("LMS reported exhausted after one sign\n"); + ret = -1; + } + } + + if (keyInited) { + whKeyId evictId = WH_KEYID_ERASED; + if ((wh_Client_LmsGetKeyId(key, &evictId) == 0) && + !WH_KEYID_ISERASED(evictId)) { + int evictRet = wh_Client_KeyEvict(ctx, evictId); + if ((evictRet != 0) && (ret == 0)) { + WH_ERROR_PRINT("Failed LMS evict keyId=0x%X ret=%d\n", + (unsigned)evictId, evictRet); + ret = evictRet; + } + } + wc_LmsKey_Free(key); + } + + if (ret == 0) { + WH_TEST_PRINT("LMS CryptoCb DEVID=0x%X SUCCESS\n", devId); + } + + return ret; +} +#endif /* WOLFHSM_CFG_DMA && WOLFSSL_HAVE_LMS && !WOLFSSL_LMS_VERIFY_ONLY */ + +#if defined(WOLFHSM_CFG_DMA) && \ + defined(WOLFSSL_HAVE_XMSS) && !defined(WOLFSSL_XMSS_VERIFY_ONLY) +/* "XMSS-SHA2_10_256" is the smallest standardized XMSS parameter set + * (height 10, 1024 signatures). pubLen=68, sigLen=2500. */ +#define WH_TEST_XMSS_PARAM_STR "XMSS-SHA2_10_256" +static byte whTest_XmssSigBuf[2500]; + +static int whTestCrypto_XmssCryptoCb(whClientContext* ctx, int devId, + WC_RNG* rng) +{ + int ret = 0; + XmssKey key[1]; + int keyInited = 0; + word32 sigLen = 0; + word32 sigCap = 0; + const byte msg[] = "wolfHSM XMSS cryptocb test"; + word32 msgSz = (word32)sizeof(msg) - 1; + + (void)rng; + + memset(whTest_XmssSigBuf, 0, sizeof(whTest_XmssSigBuf)); + + ret = wc_XmssKey_Init(key, NULL, devId); + if (ret != 0) { + WH_ERROR_PRINT("Failed wc_XmssKey_Init devId=0x%X ret=%d\n", devId, ret); + return ret; + } + keyInited = 1; + + if (ret == 0) { + ret = wc_XmssKey_SetParamStr(key, WH_TEST_XMSS_PARAM_STR); + if (ret != 0) { + WH_ERROR_PRINT("Failed XMSS SetParamStr=\"%s\" ret=%d\n", + WH_TEST_XMSS_PARAM_STR, ret); + } + } + + if (ret == 0) { + ret = wc_XmssKey_GetSigLen(key, &sigCap); + if (ret != 0) { + WH_ERROR_PRINT("Failed XMSS GetSigLen ret=%d\n", ret); + } + else if (sigCap > sizeof(whTest_XmssSigBuf)) { + WH_ERROR_PRINT("XMSS sig buffer too small: need=%u have=%u\n", + (unsigned)sigCap, + (unsigned)sizeof(whTest_XmssSigBuf)); + ret = BUFFER_E; + } + } + + /* MakeKey via cryptocb: server caches private key (ephemeral) and + * returns the public key over DMA. */ + if (ret == 0) { + ret = wc_XmssKey_MakeKey(key, rng); + if (ret != 0) { + WH_ERROR_PRINT("Failed XMSS MakeKey ret=%d\n", ret); + } + } + + /* wc_XmssKey_SigsLeft returns a boolean: nonzero = signatures available, + * 0 = exhausted. */ + if (ret == 0) { + if (wc_XmssKey_SigsLeft(key) == 0) { + WH_ERROR_PRINT("XMSS reported exhausted on fresh key\n"); + ret = -1; + } + } + + if (ret == 0) { + sigLen = sigCap; + ret = wc_XmssKey_Sign(key, whTest_XmssSigBuf, &sigLen, msg, (int)msgSz); + if (ret != 0) { + WH_ERROR_PRINT("Failed XMSS Sign ret=%d\n", ret); + } + else if (sigLen != sigCap) { + WH_ERROR_PRINT("XMSS Sign produced unexpected length=%u expected=%u\n", + (unsigned)sigLen, (unsigned)sigCap); + ret = -1; + } + } + + if (ret == 0) { + ret = wc_XmssKey_Verify(key, whTest_XmssSigBuf, sigLen, msg, (int)msgSz); + if (ret != 0) { + WH_ERROR_PRINT("Failed XMSS Verify ret=%d\n", ret); + } + } + + if (ret == 0) { + whTest_XmssSigBuf[0] ^= 0xFF; + ret = wc_XmssKey_Verify(key, whTest_XmssSigBuf, sigLen, msg, (int)msgSz); + whTest_XmssSigBuf[0] ^= 0xFF; + if (ret == 0) { + WH_ERROR_PRINT("XMSS Verify unexpectedly accepted tampered sig\n"); + ret = -1; + } + else { + ret = 0; + } + } + + if (ret == 0) { + const byte wrongMsg[] = "wolfHSM XMSS cryptocb wrong"; + ret = wc_XmssKey_Verify(key, whTest_XmssSigBuf, sigLen, wrongMsg, + (int)(sizeof(wrongMsg) - 1)); + if (ret == 0) { + WH_ERROR_PRINT("XMSS Verify unexpectedly accepted wrong message\n"); + ret = -1; + } + else { + ret = 0; + } + } + + /* H=10 means 1024 sigs total; after one sign, the key is still not + * exhausted. */ + if (ret == 0) { + if (wc_XmssKey_SigsLeft(key) == 0) { + WH_ERROR_PRINT("XMSS reported exhausted after one sign\n"); + ret = -1; + } + } + + if (keyInited) { + whKeyId evictId = WH_KEYID_ERASED; + if ((wh_Client_XmssGetKeyId(key, &evictId) == 0) && + !WH_KEYID_ISERASED(evictId)) { + int evictRet = wh_Client_KeyEvict(ctx, evictId); + if ((evictRet != 0) && (ret == 0)) { + WH_ERROR_PRINT("Failed XMSS evict keyId=0x%X ret=%d\n", + (unsigned)evictId, evictRet); + ret = evictRet; + } + } + wc_XmssKey_Free(key); + } + + if (ret == 0) { + WH_TEST_PRINT("XMSS CryptoCb DEVID=0x%X SUCCESS\n", devId); + } + + return ret; +} +#endif /* WOLFHSM_CFG_DMA && WOLFSSL_HAVE_XMSS && !WOLFSSL_XMSS_VERIFY_ONLY */ + /* Test key usage policy enforcement for various crypto operations */ int whTest_CryptoKeyUsagePolicies(whClientContext* client, WC_RNG* rng) { @@ -9565,6 +9869,20 @@ int whTest_CryptoClientConfig(whClientConfig* config) !defined(WOLFSSL_MLKEM_NO_DECAPSULATE) */ #endif /* WOLFSSL_HAVE_MLKEM */ +#if defined(WOLFHSM_CFG_DMA) && \ + defined(WOLFSSL_HAVE_LMS) && !defined(WOLFSSL_LMS_VERIFY_ONLY) + if (ret == 0) { + ret = whTestCrypto_LmsCryptoCb(client, WH_DEV_ID_DMA, rng); + } +#endif + +#if defined(WOLFHSM_CFG_DMA) && \ + defined(WOLFSSL_HAVE_XMSS) && !defined(WOLFSSL_XMSS_VERIFY_ONLY) + if (ret == 0) { + ret = whTestCrypto_XmssCryptoCb(client, WH_DEV_ID_DMA, rng); + } +#endif + #ifdef WOLFHSM_CFG_DEBUG_VERBOSE if (ret == 0) { (void)whTest_ShowNvmAvailable(client); diff --git a/wolfhsm/wh_client_crypto.h b/wolfhsm/wh_client_crypto.h index 4b069a694..c9aa1951a 100644 --- a/wolfhsm/wh_client_crypto.h +++ b/wolfhsm/wh_client_crypto.h @@ -2240,5 +2240,76 @@ int wh_Client_MlKemDecapsulateDma(whClientContext* ctx, MlKemKey* key, #endif /* WOLFSSL_HAVE_MLKEM */ +#if defined(WOLFSSL_HAVE_LMS) || defined(WOLFSSL_HAVE_XMSS) +#ifdef WOLFHSM_CFG_DMA + +#ifdef WOLFSSL_HAVE_LMS + +/* Bind / read the wolfHSM key id stored in key->devCtx. */ +int wh_Client_LmsSetKeyId(LmsKey* key, whKeyId keyId); +int wh_Client_LmsGetKeyId(LmsKey* key, whKeyId* outId); + +/* Generate an LMS key on the server. The key's parameter set + * (levels/height/winternitz) must be bound on the in-memory key before this + * call (e.g. via wc_LmsKey_SetParameters). On success the key's devCtx + * carries the server-side keyId. + * + * If flags include WH_NVM_FLAGS_EPHEMERAL, the server returns the public key + * via DMA and the caller can sign with it as long as it remains cached on + * the server. Otherwise the key is committed to the keystore. */ +int wh_Client_LmsMakeKeyDma(whClientContext* ctx, LmsKey* key, + whKeyId* inout_key_id, whNvmFlags flags, + uint16_t label_len, uint8_t* label); + +/* Convenience wrapper: WH_NVM_FLAGS_EPHEMERAL keygen, returns pub via DMA. */ +int wh_Client_LmsMakeExportKeyDma(whClientContext* ctx, LmsKey* key); + +/* Sign msg with an HSM-resident LMS key (key->devCtx carries the keyId). + * The new private state is committed atomically to NVM by the server before + * the signature is returned. */ +int wh_Client_LmsSignDma(whClientContext* ctx, const byte* msg, word32 msgSz, + byte* sig, word32* sigSz, LmsKey* key); + +/* Verify sig against msg using an HSM-resident LMS key. *res is set to 1 on + * success, 0 on signature mismatch. */ +int wh_Client_LmsVerifyDma(whClientContext* ctx, const byte* sig, word32 sigSz, + const byte* msg, word32 msgSz, int* res, + LmsKey* key); + +/* Query remaining signatures on an HSM-resident LMS key. */ +int wh_Client_LmsSigsLeftDma(whClientContext* ctx, LmsKey* key, + word32* sigsLeft); + +#endif /* WOLFSSL_HAVE_LMS */ + +#ifdef WOLFSSL_HAVE_XMSS + +int wh_Client_XmssSetKeyId(XmssKey* key, whKeyId keyId); +int wh_Client_XmssGetKeyId(XmssKey* key, whKeyId* outId); + +/* Generate an XMSS / XMSS^MT key on the server. The parameter string must be + * bound on the in-memory key (via wc_XmssKey_SetParamStr) before this call. + */ +int wh_Client_XmssMakeKeyDma(whClientContext* ctx, XmssKey* key, + whKeyId* inout_key_id, whNvmFlags flags, + uint16_t label_len, uint8_t* label); + +int wh_Client_XmssMakeExportKeyDma(whClientContext* ctx, XmssKey* key); + +int wh_Client_XmssSignDma(whClientContext* ctx, const byte* msg, word32 msgSz, + byte* sig, word32* sigSz, XmssKey* key); + +int wh_Client_XmssVerifyDma(whClientContext* ctx, const byte* sig, + word32 sigSz, const byte* msg, word32 msgSz, + int* res, XmssKey* key); + +int wh_Client_XmssSigsLeftDma(whClientContext* ctx, XmssKey* key, + word32* sigsLeft); + +#endif /* WOLFSSL_HAVE_XMSS */ + +#endif /* WOLFHSM_CFG_DMA */ +#endif /* WOLFSSL_HAVE_LMS || WOLFSSL_HAVE_XMSS */ + #endif /* !WOLFHSM_CFG_NO_CRYPTO */ #endif /* !WOLFHSM_WH_CLIENT_CRYPTO_H_ */ diff --git a/wolfhsm/wh_crypto.h b/wolfhsm/wh_crypto.h index 6f5cfc9c3..a11825ab4 100644 --- a/wolfhsm/wh_crypto.h +++ b/wolfhsm/wh_crypto.h @@ -129,6 +129,67 @@ int wh_Crypto_MlKemDeserializeKey(const uint8_t* buffer, uint16_t size, MlKemKey* key); #endif /* WOLFSSL_HAVE_MLKEM */ +/* Stateful hash-based signature key serialization (LMS / XMSS). + * + * The slot blob layout is: + * uint32_t magic; + * uint16_t pubLen; + * uint16_t privLen; + * uint16_t paramLen; + * uint16_t reserved; (must be 0) + * uint8_t paramDescriptor[paramLen]; + * uint8_t pub[pubLen]; + * uint8_t priv[privLen]; + * + * paramDescriptor encodes the parameter set: + * LMS : 3 bytes (levels, height, winternitz) - paramLen == 3 + * XMSS : NUL-terminated parameter string, paramLen == strlen+1 + * + * The blob is server-internal (NVM-stored) and uses native byte order. */ +#define WH_CRYPTO_STATEFUL_SIG_BLOB_MAGIC_LMS 0x4C4D5301u /* 'LMS\1' */ +#define WH_CRYPTO_STATEFUL_SIG_BLOB_MAGIC_XMSS 0x584D5301u /* 'XMS\1' */ + +#ifdef WOLFSSL_HAVE_LMS +/* Store an LmsKey (parameter set + public key + priv_raw) into a byte + * sequence. The key must have a parameter set bound (params != NULL) and pub + * populated. priv_raw is read directly from the key. + * + * @param [in] key LmsKey to serialize. + * @param [in] max_size Capacity of buffer in bytes. + * @param [out] buffer Destination buffer. + * @param [in,out] out_size On success, total blob size. + * @return WH_ERROR_OK on success, WH_ERROR_BUFFER_SIZE if max_size is too + * small, WH_ERROR_BADARGS otherwise. */ +int wh_Crypto_LmsSerializeKey(LmsKey* key, uint16_t max_size, uint8_t* buffer, + uint16_t* out_size); + +/* Restore an LmsKey from a byte sequence. The caller must pass a key that + * has been wc_LmsKey_Init'd. After this call returns, the key has its params + * set, key->pub populated, and key->priv_raw populated. The caller must still + * install read/write callbacks and call wc_LmsKey_Reload before signing. + * + * @param [in] buffer Source blob. + * @param [in] size Blob size in bytes. + * @param [in,out] key Initialized LmsKey to populate. + * @return WH_ERROR_OK on success, WH_ERROR_BADARGS on malformed blob. */ +int wh_Crypto_LmsDeserializeKey(const uint8_t* buffer, uint16_t size, + LmsKey* key); +#endif /* WOLFSSL_HAVE_LMS */ + +#ifdef WOLFSSL_HAVE_XMSS +/* Store an XmssKey (param string + public key + secret state) into a byte + * sequence. */ +int wh_Crypto_XmssSerializeKey(XmssKey* key, const char* paramStr, + uint16_t max_size, uint8_t* buffer, + uint16_t* out_size); + +/* Restore an XmssKey from a byte sequence. The caller must pass a key that + * has been wc_XmssKey_Init'd. The function calls wc_XmssKey_SetParamStr + * (which allocates key->sk) and copies pub and sk from the blob. */ +int wh_Crypto_XmssDeserializeKey(const uint8_t* buffer, uint16_t size, + XmssKey* key); +#endif /* WOLFSSL_HAVE_XMSS */ + #endif /* !WOLFHSM_CFG_NO_CRYPTO */ #endif /* WOLFHSM_WH_CRYPTO_H_ */ diff --git a/wolfhsm/wh_message_crypto.h b/wolfhsm/wh_message_crypto.h index 3305a5fd5..0b89e726d 100644 --- a/wolfhsm/wh_message_crypto.h +++ b/wolfhsm/wh_message_crypto.h @@ -1510,6 +1510,123 @@ int wh_MessageCrypto_TranslateMlKemDecapsDmaResponse( uint16_t magic, const whMessageCrypto_MlKemDecapsDmaResponse* src, whMessageCrypto_MlKemDecapsDmaResponse* dest); +/* Stateful hash-based signature (LMS / XMSS) DMA messages. + * + * The discriminator (LMS vs XMSS) rides on the generic request header's + * algoSubType field, set to WC_PQC_STATEFUL_SIG_TYPE_LMS or _XMSS by the + * client. Parameter selection on keygen uses lmsLevels/lmsHeight/lmsWinternitz + * when algoSubType == LMS, or xmssParamStr when algoSubType == XMSS. + * xmssParamStr is sized to fit the longest XMSS^MT name (e.g. + * "XMSSMT-SHAKE256_60/12_256") plus NUL. + */ + +/* Stateful sig DMA Key Generation Request */ +typedef struct { + whMessageCrypto_DmaBuffer pub; /* Server writes pub key here */ + uint32_t flags; + uint32_t keyId; + uint32_t access; + uint32_t labelSize; + uint32_t lmsLevels; + uint32_t lmsHeight; + uint32_t lmsWinternitz; + uint8_t label[WH_NVM_LABEL_LEN]; + char xmssParamStr[32]; + uint8_t WH_PAD[4]; /* Pad to 8-byte alignment */ +} whMessageCrypto_PqcStatefulSigKeyGenDmaRequest; + +/* Stateful sig DMA Key Generation Response */ +typedef struct { + whMessageCrypto_DmaAddrStatus dmaAddrStatus; + uint32_t keyId; + uint32_t pubSize; +} whMessageCrypto_PqcStatefulSigKeyGenDmaResponse; + +/* Stateful sig DMA Sign Request */ +typedef struct { + whMessageCrypto_DmaBuffer msg; /* Message to sign */ + whMessageCrypto_DmaBuffer sig; /* Server writes signature here */ + uint32_t options; +#define WH_MESSAGE_CRYPTO_STATEFUL_SIG_OPTIONS_EVICT (1 << 0) + uint32_t keyId; +} whMessageCrypto_PqcStatefulSigSignDmaRequest; + +/* Stateful sig DMA Sign Response */ +typedef struct { + whMessageCrypto_DmaAddrStatus dmaAddrStatus; + uint32_t sigLen; + uint8_t WH_PAD[4]; +} whMessageCrypto_PqcStatefulSigSignDmaResponse; + +/* Stateful sig DMA Verify Request */ +typedef struct { + whMessageCrypto_DmaBuffer sig; /* Signature to verify */ + whMessageCrypto_DmaBuffer msg; /* Message that was signed */ + uint32_t options; + uint32_t keyId; +} whMessageCrypto_PqcStatefulSigVerifyDmaRequest; + +/* Stateful sig DMA Verify Response */ +typedef struct { + whMessageCrypto_DmaAddrStatus dmaAddrStatus; + uint32_t res; /* 1 if signature valid, 0 otherwise */ + uint8_t WH_PAD[4]; +} whMessageCrypto_PqcStatefulSigVerifyDmaResponse; + +/* Stateful sig DMA Signatures-Left Request. + * + * No DMA buffers are required for this query; the request is named with the + * Dma suffix purely for naming consistency with the rest of the family. */ +typedef struct { + uint32_t keyId; +} whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest; + +/* Stateful sig DMA Signatures-Left Response. */ +typedef struct { + uint32_t sigsLeft; +} whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse; + +/* Stateful sig DMA translation functions */ +int wh_MessageCrypto_TranslatePqcStatefulSigKeyGenDmaRequest( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigKeyGenDmaRequest* src, + whMessageCrypto_PqcStatefulSigKeyGenDmaRequest* dest); + +int wh_MessageCrypto_TranslatePqcStatefulSigKeyGenDmaResponse( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigKeyGenDmaResponse* src, + whMessageCrypto_PqcStatefulSigKeyGenDmaResponse* dest); + +int wh_MessageCrypto_TranslatePqcStatefulSigSignDmaRequest( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigSignDmaRequest* src, + whMessageCrypto_PqcStatefulSigSignDmaRequest* dest); + +int wh_MessageCrypto_TranslatePqcStatefulSigSignDmaResponse( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigSignDmaResponse* src, + whMessageCrypto_PqcStatefulSigSignDmaResponse* dest); + +int wh_MessageCrypto_TranslatePqcStatefulSigVerifyDmaRequest( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigVerifyDmaRequest* src, + whMessageCrypto_PqcStatefulSigVerifyDmaRequest* dest); + +int wh_MessageCrypto_TranslatePqcStatefulSigVerifyDmaResponse( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigVerifyDmaResponse* src, + whMessageCrypto_PqcStatefulSigVerifyDmaResponse* dest); + +int wh_MessageCrypto_TranslatePqcStatefulSigSigsLeftDmaRequest( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest* src, + whMessageCrypto_PqcStatefulSigSigsLeftDmaRequest* dest); + +int wh_MessageCrypto_TranslatePqcStatefulSigSigsLeftDmaResponse( + uint16_t magic, + const whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse* src, + whMessageCrypto_PqcStatefulSigSigsLeftDmaResponse* dest); + /* Ed25519 DMA Sign Request */ typedef struct { whMessageCrypto_DmaBuffer msg; /* Message buffer */ diff --git a/wolfhsm/wh_server_crypto.h b/wolfhsm/wh_server_crypto.h index af945ce92..d6efc6e6c 100644 --- a/wolfhsm/wh_server_crypto.h +++ b/wolfhsm/wh_server_crypto.h @@ -114,6 +114,29 @@ int wh_Server_MlKemKeyCacheExport(whServerContext* ctx, whKeyId keyId, MlKemKey* key); #endif /* WOLFSSL_HAVE_MLKEM */ +#ifdef WOLFSSL_HAVE_LMS +/* Persist an LmsKey (param descriptor + pub + priv_raw) into the server key + * cache. Subsequent sign operations reload state from this slot via + * wh_Server_LmsKeyCacheExport. */ +int wh_Server_LmsKeyCacheImport(whServerContext* ctx, LmsKey* key, + whKeyId keyId, whNvmFlags flags, + uint16_t label_len, uint8_t* label); +/* Restore an LmsKey from a server key cache slot. The key is left in a state + * suitable for installing read/write callbacks before invoking + * wc_LmsKey_Reload. */ +int wh_Server_LmsKeyCacheExport(whServerContext* ctx, whKeyId keyId, + LmsKey* key); +#endif /* WOLFSSL_HAVE_LMS */ + +#ifdef WOLFSSL_HAVE_XMSS +int wh_Server_XmssKeyCacheImport(whServerContext* ctx, XmssKey* key, + const char* paramStr, whKeyId keyId, + whNvmFlags flags, uint16_t label_len, + uint8_t* label); +int wh_Server_XmssKeyCacheExport(whServerContext* ctx, whKeyId keyId, + XmssKey* key); +#endif /* WOLFSSL_HAVE_XMSS */ + #ifdef HAVE_HKDF /* Store HKDF output into a server key cache with optional metadata */ int wh_Server_HkdfKeyCacheImport(whServerContext* ctx, const uint8_t* keyData,