|
| 1 | +package com.trilead.ssh2.crypto.dh; |
| 2 | + |
| 3 | +import com.google.crypto.tink.subtle.X25519; |
| 4 | +import java.io.IOException; |
| 5 | +import java.lang.reflect.Method; |
| 6 | +import java.math.BigInteger; |
| 7 | +import java.security.InvalidKeyException; |
| 8 | +import java.security.KeyFactory; |
| 9 | +import java.security.KeyPair; |
| 10 | +import java.security.KeyPairGenerator; |
| 11 | +import java.security.MessageDigest; |
| 12 | +import java.security.NoSuchAlgorithmException; |
| 13 | +import java.security.PrivateKey; |
| 14 | +import java.security.spec.PKCS8EncodedKeySpec; |
| 15 | + |
| 16 | +/** |
| 17 | + * ML-KEM-768 hybrid key exchange implementation (mlkem768x25519-sha256). |
| 18 | + * Combines post-quantum ML-KEM-768 with classical X25519 key exchange. |
| 19 | + * Uses reflection to access Java 23+ KEM APIs while maintaining compatibility with Java 11+. |
| 20 | + * Implements draft-ietf-sshm-mlkem-hybrid-kex-03 specification. |
| 21 | + */ |
| 22 | +public class MlKemHybridExchange extends GenericDhExchange { |
| 23 | + |
| 24 | + public static final String NAME = "mlkem768x25519-sha256"; |
| 25 | + private static final int MLKEM768_PUBLIC_KEY_SIZE = 1184; |
| 26 | + private static final int MLKEM768_CIPHERTEXT_SIZE = 1088; |
| 27 | + private static final int MLKEM768_SHARED_SECRET_SIZE = 32; |
| 28 | + private static final int X25519_KEY_SIZE = 32; |
| 29 | + |
| 30 | + private byte[] mlkemPublicKey; |
| 31 | + private byte[] mlkemPrivateKeyEncoded; |
| 32 | + private byte[] x25519PublicKey; |
| 33 | + private byte[] x25519PrivateKey; |
| 34 | + |
| 35 | + private byte[] mlkemSharedSecret; |
| 36 | + private byte[] x25519SharedSecret; |
| 37 | + private byte[] serverX25519PublicKey; |
| 38 | + private byte[] serverReply; |
| 39 | + private byte[] hybridSharedSecretK; |
| 40 | + |
| 41 | + private Object kemInstance; |
| 42 | + |
| 43 | + public MlKemHybridExchange() { |
| 44 | + super(); |
| 45 | + } |
| 46 | + |
| 47 | + @Override |
| 48 | + public void init(String name) throws IOException { |
| 49 | + if (!NAME.equals(name)) { |
| 50 | + throw new IOException("Invalid algorithm: " + name); |
| 51 | + } |
| 52 | + |
| 53 | + try { |
| 54 | + KeyPairGenerator mlkemKpg = KeyPairGenerator.getInstance("ML-KEM-768"); |
| 55 | + KeyPair mlkemKeyPair = mlkemKpg.generateKeyPair(); |
| 56 | + byte[] x509Encoded = mlkemKeyPair.getPublic().getEncoded(); |
| 57 | + mlkemPublicKey = extractRawMlKemPublicKey(x509Encoded); |
| 58 | + mlkemPrivateKeyEncoded = mlkemKeyPair.getPrivate().getEncoded(); |
| 59 | + |
| 60 | + if (mlkemPublicKey.length != MLKEM768_PUBLIC_KEY_SIZE) { |
| 61 | + throw new IOException( |
| 62 | + "Unexpected ML-KEM-768 public key size: " |
| 63 | + + mlkemPublicKey.length |
| 64 | + + " (expected " |
| 65 | + + MLKEM768_PUBLIC_KEY_SIZE |
| 66 | + + ")"); |
| 67 | + } |
| 68 | + |
| 69 | + x25519PrivateKey = X25519.generatePrivateKey(); |
| 70 | + x25519PublicKey = X25519.publicFromPrivate(x25519PrivateKey); |
| 71 | + |
| 72 | + if (x25519PublicKey.length != X25519_KEY_SIZE) { |
| 73 | + throw new IOException( |
| 74 | + "Unexpected X25519 public key size: " |
| 75 | + + x25519PublicKey.length |
| 76 | + + " (expected " |
| 77 | + + X25519_KEY_SIZE |
| 78 | + + ")"); |
| 79 | + } |
| 80 | + |
| 81 | + } catch (NoSuchAlgorithmException e) { |
| 82 | + throw new IOException("ML-KEM-768 or X25519 not available", e); |
| 83 | + } catch (InvalidKeyException e) { |
| 84 | + throw new IOException("Failed to generate key pair", e); |
| 85 | + } |
| 86 | + } |
| 87 | + |
| 88 | + @Override |
| 89 | + public byte[] getE() { |
| 90 | + byte[] init = new byte[mlkemPublicKey.length + x25519PublicKey.length]; |
| 91 | + System.arraycopy(mlkemPublicKey, 0, init, 0, mlkemPublicKey.length); |
| 92 | + System.arraycopy( |
| 93 | + x25519PublicKey, 0, init, mlkemPublicKey.length, x25519PublicKey.length); |
| 94 | + return init; |
| 95 | + } |
| 96 | + |
| 97 | + @Override |
| 98 | + protected byte[] getServerE() { |
| 99 | + return serverReply != null ? serverReply.clone() : new byte[0]; |
| 100 | + } |
| 101 | + |
| 102 | + @Override |
| 103 | + public void setF(byte[] f) throws IOException { |
| 104 | + if (f.length != MLKEM768_CIPHERTEXT_SIZE + X25519_KEY_SIZE) { |
| 105 | + throw new IOException( |
| 106 | + "Invalid S_REPLY length: " |
| 107 | + + f.length |
| 108 | + + " (expected " |
| 109 | + + (MLKEM768_CIPHERTEXT_SIZE + X25519_KEY_SIZE) |
| 110 | + + ")"); |
| 111 | + } |
| 112 | + |
| 113 | + serverReply = f.clone(); |
| 114 | + |
| 115 | + try { |
| 116 | + byte[] mlkemCiphertext = new byte[MLKEM768_CIPHERTEXT_SIZE]; |
| 117 | + System.arraycopy(f, 0, mlkemCiphertext, 0, MLKEM768_CIPHERTEXT_SIZE); |
| 118 | + |
| 119 | + serverX25519PublicKey = new byte[X25519_KEY_SIZE]; |
| 120 | + System.arraycopy(f, MLKEM768_CIPHERTEXT_SIZE, serverX25519PublicKey, 0, X25519_KEY_SIZE); |
| 121 | + |
| 122 | + mlkemSharedSecret = performMlKemDecapsulation(mlkemCiphertext); |
| 123 | + |
| 124 | + x25519SharedSecret = X25519.computeSharedSecret(x25519PrivateKey, serverX25519PublicKey); |
| 125 | + validateX25519SharedSecret(x25519SharedSecret); |
| 126 | + |
| 127 | + byte[] combined = new byte[MLKEM768_SHARED_SECRET_SIZE + X25519_KEY_SIZE]; |
| 128 | + System.arraycopy(mlkemSharedSecret, 0, combined, 0, MLKEM768_SHARED_SECRET_SIZE); |
| 129 | + System.arraycopy(x25519SharedSecret, 0, combined, MLKEM768_SHARED_SECRET_SIZE, X25519_KEY_SIZE); |
| 130 | + |
| 131 | + hybridSharedSecretK = computeHybridSharedSecret(combined); |
| 132 | + sharedSecret = new BigInteger(1, hybridSharedSecretK); |
| 133 | + |
| 134 | + } catch (InvalidKeyException e) { |
| 135 | + throw new IOException("X25519 key agreement failed", e); |
| 136 | + } catch (Exception e) { |
| 137 | + throw new IOException("ML-KEM decapsulation or key agreement failed", e); |
| 138 | + } |
| 139 | + } |
| 140 | + |
| 141 | + private byte[] performMlKemDecapsulation(byte[] ciphertext) throws IOException { |
| 142 | + try { |
| 143 | + if (kemInstance == null) { |
| 144 | + Class<?> kemClass = Class.forName("javax.crypto.KEM"); |
| 145 | + Method getInstance = kemClass.getMethod("getInstance", String.class); |
| 146 | + kemInstance = getInstance.invoke(null, "ML-KEM"); |
| 147 | + } |
| 148 | + |
| 149 | + KeyFactory kf = KeyFactory.getInstance("ML-KEM"); |
| 150 | + PKCS8EncodedKeySpec privateKeySpec = new PKCS8EncodedKeySpec(mlkemPrivateKeyEncoded); |
| 151 | + PrivateKey mlkemPrivateKey = kf.generatePrivate(privateKeySpec); |
| 152 | + |
| 153 | + Class<?> kemClass = Class.forName("javax.crypto.KEM"); |
| 154 | + Method newDecapsulator = kemClass.getMethod("newDecapsulator", PrivateKey.class); |
| 155 | + Object decapsulator = newDecapsulator.invoke(kemInstance, mlkemPrivateKey); |
| 156 | + |
| 157 | + Class<?> decapsulatorClass = Class.forName("javax.crypto.KEM$Decapsulator"); |
| 158 | + Method decapsulateMethod = decapsulatorClass.getMethod("decapsulate", byte[].class); |
| 159 | + Object secretKey = decapsulateMethod.invoke(decapsulator, ciphertext); |
| 160 | + |
| 161 | + javax.crypto.SecretKey sk = (javax.crypto.SecretKey) secretKey; |
| 162 | + return sk.getEncoded(); |
| 163 | + |
| 164 | + } catch (ClassNotFoundException e) { |
| 165 | + throw new IOException("ML-KEM not available (Java 23+ required)", e); |
| 166 | + } catch (NoSuchAlgorithmException e) { |
| 167 | + throw new IOException("ML-KEM not available", e); |
| 168 | + } catch (Exception e) { |
| 169 | + throw new IOException("ML-KEM decapsulation failed", e); |
| 170 | + } |
| 171 | + } |
| 172 | + |
| 173 | + private void validateX25519SharedSecret(byte[] sharedSecret) throws IOException { |
| 174 | + int allBytes = 0; |
| 175 | + for (int i = 0; i < sharedSecret.length; i++) { |
| 176 | + allBytes |= sharedSecret[i]; |
| 177 | + } |
| 178 | + if (allBytes == 0) { |
| 179 | + throw new IOException("Invalid X25519 shared secret; all zeroes"); |
| 180 | + } |
| 181 | + } |
| 182 | + |
| 183 | + private byte[] computeHybridSharedSecret(byte[] combined) throws IOException { |
| 184 | + try { |
| 185 | + MessageDigest md = MessageDigest.getInstance("SHA-256"); |
| 186 | + return md.digest(combined); |
| 187 | + } catch (NoSuchAlgorithmException e) { |
| 188 | + throw new IOException("SHA-256 not available", e); |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + @Override |
| 193 | + public String getHashAlgo() { |
| 194 | + return "SHA-256"; |
| 195 | + } |
| 196 | + |
| 197 | + @Override |
| 198 | + public byte[] getK() { |
| 199 | + if (hybridSharedSecretK == null) { |
| 200 | + throw new IllegalStateException("Shared secret not yet known, need f first!"); |
| 201 | + } |
| 202 | + return hybridSharedSecretK.clone(); |
| 203 | + } |
| 204 | + |
| 205 | + private static byte[] extractRawMlKemPublicKey(byte[] x509Encoded) throws IOException { |
| 206 | + if (x509Encoded.length < 22) { |
| 207 | + throw new IOException("X.509 encoded ML-KEM public key too short"); |
| 208 | + } |
| 209 | + |
| 210 | + if (x509Encoded[0] != 0x30) { |
| 211 | + throw new IOException("Invalid X.509 encoding: expected SEQUENCE tag"); |
| 212 | + } |
| 213 | + |
| 214 | + if (x509Encoded[17] != 0x03) { |
| 215 | + throw new IOException("Invalid X.509 encoding: BIT STRING not found at expected position"); |
| 216 | + } |
| 217 | + |
| 218 | + if (x509Encoded[21] != 0x00) { |
| 219 | + throw new IOException("Invalid X.509 encoding: unexpected unused bits in BIT STRING"); |
| 220 | + } |
| 221 | + |
| 222 | + byte[] rawKey = new byte[MLKEM768_PUBLIC_KEY_SIZE]; |
| 223 | + System.arraycopy(x509Encoded, 22, rawKey, 0, MLKEM768_PUBLIC_KEY_SIZE); |
| 224 | + return rawKey; |
| 225 | + } |
| 226 | +} |
0 commit comments