libcrux/hpke/
kem.rs

1#![doc = include_str!("KEM_Readme.md")]
2#![doc = include_str!("KEM_Security.md")]
3#![allow(non_camel_case_types, non_snake_case)]
4
5use crate::kem::*;
6
7use super::errors::*;
8use super::kdf::*;
9
10/// ## Key Encapsulation Mechanisms (KEMs)
11///
12/// | Value  | KEM                        | Nsecret  | Nenc | Npk | Nsk | Auth | Reference               |
13/// |:-------|:---------------------------|:---------|:-----|:----|:----|:-----|:------------------------|
14/// | 0x0000 | (reserved)                 | N/A      | N/A  | N/A | N/A | yes  | N/A                     |
15/// | 0x0010 | DHKEM(P-256, HKDF-SHA256)  | 32       | 65   | 65  | 32  | yes  | [NISTCurves], [RFC5869] |
16/// | 0x0011 | DHKEM(P-384, HKDF-SHA384)  | 48       | 97   | 97  | 48  | yes  | [NISTCurves], [RFC5869] |
17/// | 0x0012 | DHKEM(P-521, HKDF-SHA512)  | 64       | 133  | 133 | 66  | yes  | [NISTCurves], [RFC5869] |
18/// | 0x0020 | DHKEM(X25519, HKDF-SHA256) | 32       | 32   | 32  | 32  | yes  | [RFC7748], [RFC5869]    |
19/// | 0x0021 | DHKEM(X448, HKDF-SHA512)   | 64       | 56   | 56  | 56  | yes  | [RFC7748], [RFC5869]    |
20///
21/// The `Auth` column indicates if the KEM algorithm provides the [`AuthEncap()`]/[`AuthDecap()`]
22/// interface and is therefore suitable for the Auth and AuthPSK modes. The meaning of all
23/// other columns is explained below. All algorithms are suitable for the
24/// PSK mode.
25///
26/// ### KEM Identifiers
27///
28/// The "HPKE KEM Identifiers" registry lists identifiers for key encapsulation
29/// algorithms defined for use with HPKE. These identifiers are two-byte values,
30/// so the maximum possible value is 0xFFFF = 65535.
31///
32/// Template:
33///
34/// * Value: The two-byte identifier for the algorithm
35/// * KEM: The name of the algorithm
36/// * Nsecret: The length in bytes of a KEM shared secret produced by the algorithm
37/// * Nenc: The length in bytes of an encoded encapsulated key produced by the algorithm
38/// * Npk: The length in bytes of an encoded public key for the algorithm
39/// * Nsk: The length in bytes of an encoded private key for the algorithm
40/// * Auth: A boolean indicating if this algorithm provides the [`AuthEncap()`]/[`AuthDecap()`] interface
41/// * Reference: Where this algorithm is defined
42///
43/// [NISTCurves]: https://doi.org/10.6028/nist.fips.186-4
44/// [RFC7748]: https://www.rfc-editor.org/info/rfc7748
45/// [RFC5869]: https://www.rfc-editor.org/info/rfc5869
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum KEM {
    /// DHKEM(P-256, HKDF-SHA256), registry value `0x0010`.
    DHKEM_P256_HKDF_SHA256,
    /// DHKEM(P-384, HKDF-SHA384), registry value `0x0011`.
    DHKEM_P384_HKDF_SHA384,
    /// DHKEM(P-521, HKDF-SHA512), registry value `0x0012`.
    DHKEM_P521_HKDF_SHA512,
    /// DHKEM(X25519, HKDF-SHA256), registry value `0x0020`.
    DHKEM_X25519_HKDF_SHA256,
    /// DHKEM(X448, HKDF-SHA512), registry value `0x0021`.
    DHKEM_X448_HKDF_SHA512,
}
59
60/// [`u16`] value of the `kem_id`.
61///
62/// See [`KEM`] for details.
63pub fn kem_value(kem_id: KEM) -> u16 {
64    match kem_id {
65        KEM::DHKEM_P256_HKDF_SHA256 => 0x0010u16,
66        KEM::DHKEM_P384_HKDF_SHA384 => 0x0011u16,
67        KEM::DHKEM_P521_HKDF_SHA512 => 0x0012u16,
68        KEM::DHKEM_X25519_HKDF_SHA256 => 0x00020u16,
69        KEM::DHKEM_X448_HKDF_SHA512 => 0x0021u16,
70    }
71}
72
73/// Get the [`KDF`] algorithm for the given `kem_id`.
74///
75/// See [`KEM`] for details.
76fn kdf_for_kem(kem_id: KEM) -> KDF {
77    match kem_id {
78        KEM::DHKEM_P256_HKDF_SHA256 => KDF::HKDF_SHA256,
79        KEM::DHKEM_P384_HKDF_SHA384 => KDF::HKDF_SHA384,
80        KEM::DHKEM_P521_HKDF_SHA512 => KDF::HKDF_SHA512,
81        KEM::DHKEM_X25519_HKDF_SHA256 => KDF::HKDF_SHA256,
82        KEM::DHKEM_X448_HKDF_SHA512 => KDF::HKDF_SHA512,
83    }
84}
85
86/// Convert the KEM type to the KEM algorithm of libcrux.
87fn kem_to_named_group(alg: KEM) -> Algorithm {
88    match alg {
89        KEM::DHKEM_P256_HKDF_SHA256 => Algorithm::Secp256r1,
90        KEM::DHKEM_P384_HKDF_SHA384 => Algorithm::Secp384r1,
91        KEM::DHKEM_P521_HKDF_SHA512 => Algorithm::Secp521r1,
92        KEM::DHKEM_X25519_HKDF_SHA256 => Algorithm::X25519,
93        KEM::DHKEM_X448_HKDF_SHA512 => Algorithm::X448,
94    }
95}
96
97/// Get the length of the shared secret.
98///
99/// See [`KEM`] for details.
100pub fn Nsecret(kem_id: KEM) -> usize {
101    match kem_id {
102        KEM::DHKEM_P256_HKDF_SHA256 => 32,
103        KEM::DHKEM_P384_HKDF_SHA384 => 48,
104        KEM::DHKEM_P521_HKDF_SHA512 => 64,
105        KEM::DHKEM_X25519_HKDF_SHA256 => 32,
106        KEM::DHKEM_X448_HKDF_SHA512 => 64,
107    }
108}
109
110/// Get the length of the encoded encapsulated key.
111///
112/// See [`KEM`] for details.
113pub fn Nenc(kem_id: KEM) -> usize {
114    match kem_id {
115        KEM::DHKEM_P256_HKDF_SHA256 => 65,
116        KEM::DHKEM_P384_HKDF_SHA384 => 97,
117        KEM::DHKEM_P521_HKDF_SHA512 => 133,
118        KEM::DHKEM_X25519_HKDF_SHA256 => 32,
119        KEM::DHKEM_X448_HKDF_SHA512 => 56,
120    }
121}
122
123/// Get the length of the private key.
124///
125/// See [`KEM`] for details.
126pub fn Nsk(kem_id: KEM) -> usize {
127    match kem_id {
128        KEM::DHKEM_P256_HKDF_SHA256 => 32,
129        KEM::DHKEM_P384_HKDF_SHA384 => 48,
130        KEM::DHKEM_P521_HKDF_SHA512 => 66,
131        KEM::DHKEM_X25519_HKDF_SHA256 => 32,
132        KEM::DHKEM_X448_HKDF_SHA512 => 56,
133    }
134}
135
136/// Get the length of the encoded public key.
137///
138/// See [`KEM`] for details.
139pub fn Npk(kem_id: KEM) -> usize {
140    match kem_id {
141        KEM::DHKEM_P256_HKDF_SHA256 => 65,
142        KEM::DHKEM_P384_HKDF_SHA384 => 97,
143        KEM::DHKEM_P521_HKDF_SHA512 => 133,
144        KEM::DHKEM_X25519_HKDF_SHA256 => 32,
145        KEM::DHKEM_X448_HKDF_SHA512 => 56,
146    }
147}
148
149/// The length in bytes of a Diffie-Hellman shared secret produced by [`DH()`].
150///
151/// |        | [`Ndh`] |
152/// | ------ | ------- |
153/// | P-256  | 32      |
154/// | P-384  | 48      |
155/// | P-521  | 66      |
156/// | X25519 | 32      |
157/// | X448   | 56      |
158pub fn Ndh(kem_id: KEM) -> usize {
159    match kem_id {
160        KEM::DHKEM_P256_HKDF_SHA256 => 32,
161        KEM::DHKEM_P384_HKDF_SHA384 => 48,
162        KEM::DHKEM_P521_HKDF_SHA512 => 66,
163        KEM::DHKEM_X25519_HKDF_SHA256 => 32,
164        KEM::DHKEM_X448_HKDF_SHA512 => 56,
165    }
166}
167
168pub type PrivateKey = Vec<u8>;
169pub type PublicKey = Vec<u8>;
170pub type KeyPair = (PrivateKey, PublicKey);
171pub type PublicKeyIn = [u8];
172pub type PrivateKeyIn = [u8];
173pub type SharedSecret = Vec<u8>;
174pub type SerializedPublicKey = Vec<u8>;
175pub type Randomness = Vec<u8>;
176
177pub type EncapResult = Result<(SharedSecret, SerializedPublicKey), HpkeError>;
178
179// === Label ===
180
/// ASCII bytes of the label `"dkp_prk"`.
fn dkp_prk_label() -> Vec<u8> {
    b"dkp_prk".to_vec()
}
185
/// ASCII bytes of the label `"eae_prk"`.
fn eae_prk_label() -> Vec<u8> {
    b"eae_prk".to_vec()
}
190
/// ASCII bytes of the label `"sk"`.
fn sk_label() -> Vec<u8> {
    b"sk".to_vec()
}
195
/// ASCII bytes of the label `"candidate"`.
fn candidate_label() -> Vec<u8> {
    b"candidate".to_vec()
}
202
/// ASCII bytes of the label `"shared_secret"`.
fn shared_secret_label() -> Vec<u8> {
    b"shared_secret".to_vec()
}
210
/// Produce the empty byte string (`""` in the HPKE pseudo-code).
fn empty() -> Vec<u8> {
    Vec::new()
}
215
216/// Get the label for the KEM with the cipher suite ID.
217/// "KEM"
218fn suite_id(alg: KEM) -> Vec<u8> {
219    let mut suite_id = vec![0x4bu8, 0x45u8, 0x4du8]; // "KEM"
220    suite_id.extend_from_slice(&kem_value(alg).to_be_bytes());
221    suite_id
222}
223
224/// For the variants of DHKEM defined in this document, the size [`Nsecret`] of the
225/// KEM shared secret is equal to the output length of the hash function
226/// underlying the KDF. For P-256, P-384 and P-521, the size `Ndh` of the
227/// Diffie-Hellman shared secret is equal to 32, 48, and 66, respectively,
228/// corresponding to the x-coordinate of the resulting elliptic curve point.
229/// For X25519 and X448, the size [`Ndh`] of is equal to 32 and 56, respectively.
230fn shared_secret_from_dh(alg: KEM, mut secret: Vec<u8>) -> Vec<u8> {
231    match alg {
232        KEM::DHKEM_P256_HKDF_SHA256 => secret.drain(0..Ndh(alg)).collect(),
233        KEM::DHKEM_P384_HKDF_SHA384 => secret.drain(0..Ndh(alg)).collect(),
234        KEM::DHKEM_P521_HKDF_SHA512 => secret.drain(0..Ndh(alg)).collect(),
235        KEM::DHKEM_X25519_HKDF_SHA256 => secret,
236        KEM::DHKEM_X448_HKDF_SHA512 => secret,
237    }
238}
239
240/// Perform a non-interactive Diffie-Hellman exchange using the private key
241/// `skX` and public key `pkY` to produce a Diffie-Hellman shared
242/// secret of length `Ndh`. This function can raise a
243/// [`ValidationError`](`HpkeError::ValidationError`) as described in
244/// [validation](#validation-of-inputs-and-outputs).
245pub fn DH(alg: KEM, sk: &PrivateKeyIn, pk: &PublicKeyIn) -> Result<SharedSecret, HpkeError> {
246    match crate::ecdh::derive(kem_to_named_group(alg).try_into().unwrap(), pk, sk) {
247        Ok(secret) => HpkeBytesResult::Ok(shared_secret_from_dh(alg, secret)),
248        Err(_) => HpkeBytesResult::Err(HpkeError::ValidationError),
249    }
250}
251
252fn pk(alg: KEM, sk: &PrivateKeyIn) -> Result<PublicKey, HpkeError> {
253    match crate::kem::secret_to_public(kem_to_named_group(alg), sk) {
254        Ok(pk) => HpkeBytesResult::Ok(pk),
255        Err(_) => HpkeBytesResult::Err(HpkeError::ValidationError),
256    }
257}
258
259/// Prepend 0x04 to the byte sequence.
260fn nist_curve_to_uncompressed(pk: PublicKey) -> PublicKey {
261    let mut out = vec![0x04u8];
262    out.extend_from_slice(&pk);
263    out
264}
265
266/// Produce a byte string of length `Npk` encoding the public key `pkX`.
267///
268/// For P-256, P-384 and P-521, the [`SerializePublicKey()`] function of the
269/// KEM performs the uncompressed Elliptic-Curve-Point-to-Octet-String
270/// conversion according to [SECG]. [`DeserializePublicKey()`] performs the
271/// uncompressed Octet-String-to-Elliptic-Curve-Point conversion.
272///
273/// For X25519 and X448, the `SerializePublicKey()` and `DeserializePublicKey()`
274/// functions are the identity function, since these curves already use
275/// fixed-length byte strings for public keys.
276///
277/// Some deserialized public keys MUST be validated before they can be used.
278///
279/// [secg]: https://secg.org/sec1-v2.pdf
280pub fn SerializePublicKey(alg: KEM, pk: PublicKey) -> PublicKey {
281    match alg {
282        KEM::DHKEM_P256_HKDF_SHA256 => nist_curve_to_uncompressed(pk),
283        KEM::DHKEM_P384_HKDF_SHA384 => nist_curve_to_uncompressed(pk),
284        KEM::DHKEM_P521_HKDF_SHA512 => nist_curve_to_uncompressed(pk),
285        KEM::DHKEM_X25519_HKDF_SHA256 => pk,
286        KEM::DHKEM_X448_HKDF_SHA512 => pk,
287    }
288}
289
290/// Remove the leading 0x04 from the public key.
291fn nist_curve_from_uncompressed(pk: &PublicKeyIn) -> Vec<u8> {
292    if pk[0] == 0x04 {
293        pk[1..].to_vec()
294    } else {
295        pk.to_vec()
296    }
297}
298
299/// Parse a byte string of length `Npk` to recover a
300/// public key. This function can raise a `DeserializeError` error upon `pkXm`
301/// deserialization failure.
302pub fn DeserializePublicKey(alg: KEM, enc: &[u8]) -> HpkeBytesResult {
303    HpkeBytesResult::Ok(match alg {
304        KEM::DHKEM_P256_HKDF_SHA256 => nist_curve_from_uncompressed(enc),
305        KEM::DHKEM_P384_HKDF_SHA384 => nist_curve_from_uncompressed(enc),
306        KEM::DHKEM_P521_HKDF_SHA512 => nist_curve_from_uncompressed(enc),
307        KEM::DHKEM_X25519_HKDF_SHA256 => enc.to_vec(),
308        KEM::DHKEM_X448_HKDF_SHA512 => enc.to_vec(),
309    })
310}
311
312/// ```text
313/// def ExtractAndExpand(dh, kem_context):
314///   eae_prk = LabeledExtract("", "eae_prk", dh)
315///   shared_secret = LabeledExpand(eae_prk, "shared_secret",
316///                                 kem_context, Nsecret)
317///   return shared_secret
318/// ```
319fn ExtractAndExpand(
320    alg: KEM,
321    suite_id: Vec<u8>,
322    dh: SharedSecret,
323    kem_context: &[u8],
324) -> HpkeBytesResult {
325    let kdf = kdf_for_kem(alg);
326    let eae_prk = LabeledExtract(kdf, suite_id.clone(), &empty(), eae_prk_label(), &dh)?;
327    LabeledExpand(
328        kdf,
329        suite_id,
330        &eae_prk,
331        shared_secret_label(),
332        kem_context,
333        Nsecret(alg),
334    )
335}
336
/// Encode `counter` as a single octet (`I2OSP(counter, 1)`).
///
/// Only the lowest eight bits of `counter` are kept; larger values wrap.
fn I2OSP(counter: usize) -> Vec<u8> {
    let octet = counter as u8;
    octet.to_be_bytes().to_vec()
}
340
341/// For X25519 and X448, the `DeriveKeyPair()` function applies a KDF to the input:
342///
343/// ```text
344/// def DeriveKeyPair(ikm):
345///   dkp_prk = LabeledExtract("", "dkp_prk", ikm)
346///   sk = LabeledExpand(dkp_prk, "sk", "", Nsk)
347///   return (sk, pk(sk))
348/// ```
349pub fn DeriveKeyPairX(alg: KEM, ikm: &InputKeyMaterial) -> Result<KeyPair, HpkeError> {
350    let kdf = kdf_for_kem(alg);
351    let dkp_prk = LabeledExtract(kdf, suite_id(alg), &empty(), dkp_prk_label(), ikm)?;
352
353    let sk = LabeledExpand(kdf, suite_id(alg), &dkp_prk, sk_label(), &empty(), Nsk(alg))?;
354
355    match crate::kem::secret_to_public(kem_to_named_group(alg), &sk) {
356        Ok(pk) => Result::<KeyPair, HpkeError>::Ok((sk, pk)),
357        Err(_) => Result::<KeyPair, HpkeError>::Err(HpkeError::CryptoError),
358    }
359}
360
361/// ### DeriveKeyPair
362///
363/// The keys that [`DeriveKeyPair()`] produces have only as much entropy as the provided
364/// input keying material. For a given KEM, the `ikm` parameter given to [`DeriveKeyPair()`] SHOULD
365/// have length at least [`Nsk`], and SHOULD have at least [`Nsk`] bytes of entropy.
366///
367/// All invocations of KDF functions (such as [`LabeledExtract()`] or [`LabeledExpand()`]) in any
368/// DHKEM's [`DeriveKeyPair()`] function use the DHKEM's associated KDF (as opposed to
369/// the ciphersuite's KDF).
370///
371/// For P-256, P-384 and P-521, the [`DeriveKeyPair()`] function of the KEM performs
372/// rejection sampling over field elements.
373///
374/// ```text
375/// def DeriveKeyPair(ikm):
376///   dkp_prk = LabeledExtract("", "dkp_prk", ikm)
377///   sk = 0
378///   counter = 0
379///   while sk == 0 or sk >= order:
380///     if counter > 255:
381///       raise DeriveKeyPairError
382///     bytes = LabeledExpand(dkp_prk, "candidate",
383///                           I2OSP(counter, 1), Nsk)
384///     bytes[0] = bytes[0] & bitmask
385///     sk = OS2IP(bytes)
386///     counter = counter + 1
387///   return (sk, pk(sk))
388/// ```
389///
390/// `order` is the order of the curve being used (see section D.1.2 of [NISTCurves]), and
391/// is listed below for completeness.
392///
393/// ```text
394/// P-256:
395/// 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551
396///
397/// P-384:
398/// 0xffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf
399///   581a0db248b0a77aecec196accc52973
400///
401/// P-521:
402/// 0x01ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
403///   fa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409
404/// ```
405///
406/// `bitmask` is defined to be 0xFF for P-256 and P-384, and 0x01 for P-521.
407/// The precise likelihood of `DeriveKeyPair()` failing with DeriveKeyPairError
408/// depends on the group being used, but it is negligibly small in all cases.
409/// See [hpke errors](`mod@crate::hpke::errors`) for information about dealing with such failures.
410///
411/// For X25519 and X448, the [`DeriveKeyPair()`] function applies a KDF to the input:
412///
413/// ```text
414/// def DeriveKeyPair(ikm):
415///   dkp_prk = LabeledExtract("", "dkp_prk", ikm)
416///   sk = LabeledExpand(dkp_prk, "sk", "", Nsk)
417///   return (sk, pk(sk))
418/// ```
419///
420/// [NISTCurves]: https://doi.org/10.6028/nist.fips.186-4
421pub fn DeriveKeyPair(alg: KEM, ikm: &InputKeyMaterial) -> Result<KeyPair, HpkeError> {
422    let kdf = kdf_for_kem(alg);
423    let dkp_prk = LabeledExtract(kdf, suite_id(alg), &empty(), dkp_prk_label(), ikm)?;
424
425    let named_group = kem_to_named_group(alg);
426    let sk = if alg == KEM::DHKEM_X25519_HKDF_SHA256 || alg == KEM::DHKEM_X448_HKDF_SHA512 {
427        LabeledExpand(kdf, suite_id(alg), &dkp_prk, sk_label(), &empty(), 32)?
428    } else {
429        let mut bitmask = 0xFFu8;
430        if alg == KEM::DHKEM_P521_HKDF_SHA512 {
431            bitmask = 0x01u8;
432        }
433        let mut sk = Vec::new();
434        for counter in 0..256 {
435            if sk.len() == 0 {
436                // Only keep looking if we didn't find one.
437                let mut bytes = LabeledExpand(
438                    kdf,
439                    suite_id(alg),
440                    &dkp_prk,
441                    candidate_label(),
442                    &I2OSP(counter),
443                    32,
444                )?;
445                bytes[0] = bytes[0] & bitmask;
446                // This check ensure sk != 0 or sk < order
447                if crate::ecdh::validate_scalar(named_group.try_into().unwrap(), &bytes).is_ok() {
448                    sk = bytes;
449                }
450            }
451        }
452        sk
453    };
454    if sk.len() == 0 {
455        Result::<KeyPair, HpkeError>::Err(HpkeError::DeriveKeyPairError)
456    } else {
457        let pk = pk(alg, &sk)?;
458        Ok((sk, pk))
459    }
460}
461
462/// Randomized algorithm to generate a key pair `(skX, pkX)`.
463pub fn GenerateKeyPair(alg: KEM, randomness: Randomness) -> Result<KeyPair, HpkeError> {
464    if randomness.len() != Nsk(alg) {
465        Err(HpkeError::InvalidParameters)
466    } else {
467        DeriveKeyPair(alg, &randomness)
468    }
469}
470
471/// ```text
472/// def Encap(pkR):
473///   skE, pkE = GenerateKeyPair()
474///   dh = DH(skE, pkR)
475///   enc = SerializePublicKey(pkE)
476///
477///   pkRm = SerializePublicKey(pkR)
478///   kem_context = concat(enc, pkRm)
479///
480///   shared_secret = ExtractAndExpand(dh, kem_context)
481/// ```
482pub fn Encap(alg: KEM, pkR: &PublicKeyIn, randomness: Randomness) -> EncapResult {
483    let (skE, pkE) = GenerateKeyPair(alg, randomness)?;
484    let dh = DH(alg, &skE, pkR)?;
485    let enc = SerializePublicKey(alg, pkE);
486
487    let pkRm = SerializePublicKey(alg, pkR.to_vec());
488    let mut kem_context = enc.clone();
489    kem_context.extend_from_slice(&pkRm);
490
491    let shared_secret = ExtractAndExpand(alg, suite_id(alg), dh, &kem_context)?;
492    EncapResult::Ok((shared_secret, enc))
493}
494
495/// ```text
496/// def Decap(enc, skR):
497///   pkE = DeserializePublicKey(enc)
498///   dh = DH(skR, pkE)
499///
500///   pkRm = SerializePublicKey(pk(skR))
501///   kem_context = concat(enc, pkRm)
502///
503///   shared_secret = ExtractAndExpand(dh, kem_context)
504///   return shared_secret
505/// ```
506pub fn Decap(alg: KEM, enc: &[u8], skR: &PrivateKeyIn) -> Result<SharedSecret, HpkeError> {
507    let pkE = DeserializePublicKey(alg, enc)?;
508    let dh = DH(alg, skR, &pkE)?;
509
510    let pkR = pk(alg, skR)?;
511    let pkRm = SerializePublicKey(alg, pkR);
512    let mut kem_context = enc.to_vec();
513    kem_context.extend_from_slice(&pkRm);
514
515    ExtractAndExpand(alg, suite_id(alg), dh, &kem_context)
516}
517
518/// ```text
519/// def AuthEncap(pkR, skS):
520///   skE, pkE = GenerateKeyPair()
521///   dh = concat(DH(skE, pkR), DH(skS, pkR))
522///   enc = SerializePublicKey(pkE)
523///
524///   pkRm = SerializePublicKey(pkR)
525///   pkSm = SerializePublicKey(pk(skS))
526///   kem_context = concat(enc, pkRm, pkSm)
527///
528///   shared_secret = ExtractAndExpand(dh, kem_context)
529///   return shared_secret, enc
530/// ```
531pub fn AuthEncap(
532    alg: KEM,
533    pkR: &PublicKeyIn,
534    skS: &PrivateKeyIn,
535    randomness: Randomness,
536) -> EncapResult {
537    let (skE, pkE) = GenerateKeyPair(alg, randomness)?;
538    let dhE = DH(alg, &skE, pkR)?;
539    let dhS = DH(alg, skS, pkR)?;
540    let mut dh = dhE;
541    dh.extend_from_slice(&dhS);
542    let enc = SerializePublicKey(alg, pkE);
543
544    let pkRm = SerializePublicKey(alg, pkR.to_vec());
545    let pkS = pk(alg, skS)?;
546    let pkSm = SerializePublicKey(alg, pkS);
547    let mut kem_context = enc.clone();
548    kem_context.extend_from_slice(&pkRm);
549    kem_context.extend_from_slice(&pkSm);
550
551    let shared_secret = ExtractAndExpand(alg, suite_id(alg), dh, &kem_context)?;
552    EncapResult::Ok((shared_secret, enc))
553}
554
555/// ```text
556/// def AuthDecap(enc, skR, pkS):
557///   pkE = DeserializePublicKey(enc)
558///   dh = concat(DH(skR, pkE), DH(skR, pkS))
559///
560///   pkRm = SerializePublicKey(pk(skR))
561///   pkSm = SerializePublicKey(pkS)
562///   kem_context = concat(enc, pkRm, pkSm)
563///
564///   shared_secret = ExtractAndExpand(dh, kem_context)
565///   return shared_secret
566/// ```
567pub fn AuthDecap(
568    alg: KEM,
569    enc: &[u8],
570    skR: &PrivateKeyIn,
571    pkS: &PublicKeyIn,
572) -> Result<SharedSecret, HpkeError> {
573    let pkE = DeserializePublicKey(alg, enc)?;
574    let dhE = DH(alg, skR, &pkE)?;
575    let dhS = DH(alg, skR, pkS)?;
576    let mut dh = dhE;
577    dh.extend_from_slice(&dhS);
578
579    let pkR = pk(alg, skR)?;
580    let pkRm = SerializePublicKey(alg, pkR);
581    let pkSm = SerializePublicKey(alg, pkS.to_vec());
582    let mut kem_context = enc.to_vec();
583    kem_context.extend_from_slice(&pkRm);
584    kem_context.extend_from_slice(&pkSm);
585
586    ExtractAndExpand(alg, suite_id(alg), dh, &kem_context)
587}
588
/// Check [`DeriveKeyPair()`] for X25519 against the A.1.1. test vector.
#[test]
fn derive_x25519() {
    /// Decode a hex string with an even number of digits.
    fn from_hex(s: &str) -> Vec<u8> {
        debug_assert!(s.len() % 2 == 0);
        let mut bytes = Vec::with_capacity(s.len() / 2);
        for start in (0..s.len()).step_by(2) {
            let byte =
                u8::from_str_radix(&s[start..start + 2], 16).expect("Error parsing hex string");
            bytes.push(byte);
        }
        bytes
    }

    // A.1.1. test vector: (ikm, expected private key, expected public key)
    // for the ephemeral key pair and for the recipient key pair.
    let vectors = [
        (
            "7268600d403fce431561aef583ee1613527cff655c1343f29812e66706df3234",
            "52c4a758a802cd8b936eceea314432798d5baf2d7e9235dc084ab1b9cfa2f736",
            "37fda3567bdbd628e88668c3c8d7e97d1d1253b6d4ea6d44c150f741f1bf4431",
        ),
        (
            "6db9df30aa07dd42ee5e8181afdb977e538f5e1fec8a06223f33f7013e525037",
            "4612c550263fc8ad58375df3f557aac531d26850903e55a9f23f21d8534e8ac8",
            "3948cfe0ad1ddb695d780e59077195da6c56506b027329794ab02bca80815c4d",
        ),
    ];

    for (ikm, expected_sk, expected_pk) in vectors.iter() {
        let (secret, public) = DeriveKeyPair(KEM::DHKEM_X25519_HKDF_SHA256, &from_hex(ikm))
            .expect("Error deriving key pair");

        assert_eq!(from_hex(expected_sk), secret);
        assert_eq!(from_hex(expected_pk), public);
    }
}