rustls/msgs/handshake.rs

1use alloc::collections::BTreeSet;
2#[cfg(feature = "logging")]
3use alloc::string::String;
4use alloc::vec;
5use alloc::vec::Vec;
6use core::ops::Deref;
7use core::{fmt, iter};
8
9use pki_types::{CertificateDer, DnsName};
10
11#[cfg(feature = "tls12")]
12use crate::crypto::ActiveKeyExchange;
13use crate::crypto::SecureRandom;
14use crate::enums::{
15    CertificateCompressionAlgorithm, CertificateType, CipherSuite, EchClientHelloType,
16    HandshakeType, ProtocolVersion, SignatureScheme,
17};
18use crate::error::InvalidMessage;
19#[cfg(feature = "tls12")]
20use crate::ffdhe_groups::FfdheGroup;
21use crate::log::warn;
22use crate::msgs::base::{MaybeEmpty, NonEmpty, Payload, PayloadU8, PayloadU16, PayloadU24};
23use crate::msgs::codec::{self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement};
24use crate::msgs::enums::{
25    CertificateStatusType, ClientCertificateType, Compression, ECCurveType, ECPointFormat,
26    EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest, NamedGroup,
27    PskKeyExchangeMode, ServerNameType,
28};
29use crate::rand;
30use crate::sync::Arc;
31use crate::verify::DigitallySignedStruct;
32use crate::x509::wrap_in_sequence;
33
/// Define a newtype around a length-prefixed payload type.
///
/// Many TLS message fields are opaque byte strings whose only required operation is access to
/// the raw bytes. This macro declares the tuple struct (deriving `Clone` and `Debug`) around the
/// given wrapper type (typically `PayloadU8` or `PayloadU16`). It also generates:
///
/// - `From<Vec<u8>>`, which builds the wrapped value with its `new` constructor;
/// - `AsRef<[u8]>`, which exposes the wrapped bytes;
/// - `Codec`, which delegates encoding and decoding to the wrapped type.
macro_rules! wrapped_payload(
  ($(#[$attr:meta])* $vis:vis struct $name:ident, $wrapped:ident$(<$param:ty>)?,) => {
    $(#[$attr])*
    #[derive(Clone, Debug)]
    $vis struct $name($wrapped$(<$param>)?);

    impl From<Vec<u8>> for $name {
        fn from(bytes: Vec<u8>) -> Self {
            let inner = $wrapped::new(bytes);
            Self(inner)
        }
    }

    impl AsRef<[u8]> for $name {
        fn as_ref(&self) -> &[u8] {
            let Self(inner) = self;
            inner.0.as_slice()
        }
    }

    impl Codec<'_> for $name {
        fn encode(&self, bytes: &mut Vec<u8>) {
            let Self(inner) = self;
            inner.encode(bytes);
        }

        fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
            $wrapped::read(r).map(Self)
        }
    }
  }
);
68
69#[derive(Clone, Copy, Eq, PartialEq)]
70pub(crate) struct Random(pub(crate) [u8; 32]);
71
72impl fmt::Debug for Random {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        super::base::hex(f, &self.0)
75    }
76}
77
78static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
79    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
80    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
81]);
82
83static ZERO_RANDOM: Random = Random([0u8; 32]);
84
85impl Codec<'_> for Random {
86    fn encode(&self, bytes: &mut Vec<u8>) {
87        bytes.extend_from_slice(&self.0);
88    }
89
90    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
91        let Some(bytes) = r.take(32) else {
92            return Err(InvalidMessage::MissingData("Random"));
93        };
94
95        let mut opaque = [0; 32];
96        opaque.clone_from_slice(bytes);
97        Ok(Self(opaque))
98    }
99}
100
101impl Random {
102    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
103        let mut data = [0u8; 32];
104        secure_random.fill(&mut data)?;
105        Ok(Self(data))
106    }
107}
108
109impl From<[u8; 32]> for Random {
110    #[inline]
111    fn from(bytes: [u8; 32]) -> Self {
112        Self(bytes)
113    }
114}
115
116#[derive(Copy, Clone)]
117pub(crate) struct SessionId {
118    len: usize,
119    data: [u8; 32],
120}
121
122impl fmt::Debug for SessionId {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        super::base::hex(f, &self.data[..self.len])
125    }
126}
127
/// Two session IDs are equal when their lengths match and their first `len` bytes match.
///
/// Once the lengths agree, the byte comparison ORs together the XOR of every byte pair rather
/// than returning at the first mismatch, so the loop always visits all `len` bytes.
impl PartialEq for SessionId {
    fn eq(&self, other: &Self) -> bool {
        // Different lengths can never be equal, so return early.
        if self.len != other.len {
            return false;
        }

        // Any differing byte leaves at least one bit set in `diff`.
        let mut diff = 0u8;
        for i in 0..self.len {
            diff |= self.data[i] ^ other.data[i];
        }

        diff == 0u8
    }
}
142
143impl Codec<'_> for SessionId {
144    fn encode(&self, bytes: &mut Vec<u8>) {
145        debug_assert!(self.len <= 32);
146        bytes.push(self.len as u8);
147        bytes.extend_from_slice(self.as_ref());
148    }
149
150    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
151        let len = u8::read(r)? as usize;
152        if len > 32 {
153            return Err(InvalidMessage::TrailingData("SessionID"));
154        }
155
156        let Some(bytes) = r.take(len) else {
157            return Err(InvalidMessage::MissingData("SessionID"));
158        };
159
160        let mut out = [0u8; 32];
161        out[..len].clone_from_slice(&bytes[..len]);
162        Ok(Self { data: out, len })
163    }
164}
165
166impl SessionId {
167    pub(crate) fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
168        let mut data = [0u8; 32];
169        secure_random.fill(&mut data)?;
170        Ok(Self { data, len: 32 })
171    }
172
173    pub(crate) fn empty() -> Self {
174        Self {
175            data: [0u8; 32],
176            len: 0,
177        }
178    }
179
180    #[cfg(feature = "tls12")]
181    pub(crate) fn is_empty(&self) -> bool {
182        self.len == 0
183    }
184}
185
186impl AsRef<[u8]> for SessionId {
187    fn as_ref(&self) -> &[u8] {
188        &self.data[..self.len]
189    }
190}
191
192#[derive(Clone, Debug, PartialEq)]
193pub struct UnknownExtension {
194    pub(crate) typ: ExtensionType,
195    pub(crate) payload: Payload<'static>,
196}
197
198impl UnknownExtension {
199    fn encode(&self, bytes: &mut Vec<u8>) {
200        self.payload.encode(bytes);
201    }
202
203    fn read(typ: ExtensionType, r: &mut Reader<'_>) -> Self {
204        let payload = Payload::read(r).into_owned();
205        Self { typ, payload }
206    }
207}
208
209/// RFC8422: `ECPointFormat ec_point_format_list<1..2^8-1>`
210impl TlsListElement for ECPointFormat {
211    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
212        empty_error: InvalidMessage::IllegalEmptyList("ECPointFormats"),
213    };
214}
215
216/// RFC8422: `NamedCurve named_curve_list<2..2^16-1>`
217impl TlsListElement for NamedGroup {
218    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
219        empty_error: InvalidMessage::IllegalEmptyList("NamedGroups"),
220    };
221}
222
223/// RFC8446: `SignatureScheme supported_signature_algorithms<2..2^16-2>;`
224impl TlsListElement for SignatureScheme {
225    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
226        empty_error: InvalidMessage::NoSignatureSchemes,
227    };
228}
229
230#[derive(Clone, Debug)]
231pub(crate) enum ServerNamePayload<'a> {
232    /// A successfully decoded value:
233    SingleDnsName(DnsName<'a>),
234
235    /// A DNS name which was actually an IP address
236    IpAddress,
237
238    /// A successfully decoded, but syntactically-invalid value.
239    Invalid,
240}
241
242impl ServerNamePayload<'_> {
243    fn into_owned(self) -> ServerNamePayload<'static> {
244        match self {
245            Self::SingleDnsName(d) => ServerNamePayload::SingleDnsName(d.to_owned()),
246            Self::IpAddress => ServerNamePayload::IpAddress,
247            Self::Invalid => ServerNamePayload::Invalid,
248        }
249    }
250
251    /// RFC6066: `ServerName server_name_list<1..2^16-1>`
252    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
253        empty_error: InvalidMessage::IllegalEmptyList("ServerNames"),
254    };
255}
256
/// Simplified encoding/decoding for a `ServerName` extension payload to/from `DnsName`
///
/// This is possible because:
///
/// - the spec (RFC6066) disallows multiple names for a given name type
/// - name types other than ServerNameType::HostName are not defined, and they and
///   any data that follows them cannot be skipped over.
///
/// Encoding: `SingleDnsName` is written as a list holding one `host_name` entry. The
/// `IpAddress` and `Invalid` variants write the list framing with no entries.
///
/// Decoding: each `host_name` entry is classified by `HostNamePayload::read`. A second
/// `host_name` entry is rejected with `InvalidServerName`. An entry of any other name type ends
/// parsing and discards the rest of the list. A list that yields no `host_name` entry decodes
/// as `Invalid`.
impl<'a> Codec<'a> for ServerNamePayload<'a> {
    fn encode(&self, bytes: &mut Vec<u8>) {
        // The list framing is started unconditionally, so non-DNS variants still emit it.
        let server_name_list = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);

        let ServerNamePayload::SingleDnsName(dns_name) = self else {
            return;
        };

        // One entry: name type, u16 name length, then the name bytes.
        // NOTE(review): the `as u16` cast assumes DNS names are far shorter than u16::MAX.
        ServerNameType::HostName.encode(server_name_list.buf);
        let name_slice = dns_name.as_ref().as_bytes();
        (name_slice.len() as u16).encode(server_name_list.buf);
        server_name_list
            .buf
            .extend_from_slice(name_slice);
    }

    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
        // The classification of the host_name entry seen so far, if any.
        let mut found = None;

        let len = Self::SIZE_LEN.read(r)?;
        let mut sub = r.sub(len)?;

        while sub.any_left() {
            let typ = ServerNameType::read(&mut sub)?;

            let payload = match typ {
                ServerNameType::HostName => HostNamePayload::read(&mut sub)?,
                _ => {
                    // Consume remainder of extension bytes.  Since the length of the item
                    // is an unknown encoding, we cannot continue.
                    sub.rest();
                    break;
                }
            };

            // "The ServerNameList MUST NOT contain more than one name of
            // the same name_type." - RFC6066
            if found.is_some() {
                warn!("Illegal SNI extension: duplicate host_name received");
                return Err(InvalidMessage::InvalidServerName);
            }

            // The `_invalid` bindings are only read by `warn!`, which may be compiled out
            // when the "logging" feature is off; the leading underscore avoids unused warnings.
            found = match payload {
                HostNamePayload::HostName(dns_name) => {
                    Some(Self::SingleDnsName(dns_name.to_owned()))
                }

                HostNamePayload::IpAddress(_invalid) => {
                    warn!(
                        "Illegal SNI extension: ignoring IP address presented as hostname ({_invalid:?})"
                    );
                    Some(Self::IpAddress)
                }

                HostNamePayload::Invalid(_invalid) => {
                    warn!(
                        "Illegal SNI hostname received {:?}",
                        String::from_utf8_lossy(&_invalid.0)
                    );
                    Some(Self::Invalid)
                }
            };
        }

        // A list with no host_name entry is reported the same way as an unparseable name.
        Ok(found.unwrap_or(Self::Invalid))
    }
}
331
332impl<'a> From<&DnsName<'a>> for ServerNamePayload<'static> {
333    fn from(value: &DnsName<'a>) -> Self {
334        Self::SingleDnsName(trim_hostname_trailing_dot_for_sni(value))
335    }
336}
337
338#[derive(Clone, Debug)]
339pub(crate) enum HostNamePayload {
340    HostName(DnsName<'static>),
341    IpAddress(PayloadU16<NonEmpty>),
342    Invalid(PayloadU16<NonEmpty>),
343}
344
345impl HostNamePayload {
346    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
347        use pki_types::ServerName;
348        let raw = PayloadU16::<NonEmpty>::read(r)?;
349
350        match ServerName::try_from(raw.0.as_slice()) {
351            Ok(ServerName::DnsName(d)) => Ok(Self::HostName(d.to_owned())),
352            Ok(ServerName::IpAddress(_)) => Ok(Self::IpAddress(raw)),
353            Ok(_) | Err(_) => Ok(Self::Invalid(raw)),
354        }
355    }
356}
357
358wrapped_payload!(
359    /// RFC7301: `opaque ProtocolName<1..2^8-1>;`
360    pub(crate) struct ProtocolName, PayloadU8<NonEmpty>,
361);
362
363impl PartialEq for ProtocolName {
364    fn eq(&self, other: &Self) -> bool {
365        self.0 == other.0
366    }
367}
368
369impl Deref for ProtocolName {
370    type Target = [u8];
371
372    fn deref(&self) -> &Self::Target {
373        self.as_ref()
374    }
375}
376
377/// RFC7301: `ProtocolName protocol_name_list<2..2^16-1>`
378impl TlsListElement for ProtocolName {
379    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
380        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
381    };
382}
383
384/// RFC7301 encodes a single protocol name as `Vec<ProtocolName>`
385#[derive(Clone, Debug)]
386pub(crate) struct SingleProtocolName(ProtocolName);
387
388impl SingleProtocolName {
389    pub(crate) fn new(single: ProtocolName) -> Self {
390        Self(single)
391    }
392
393    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
394        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
395    };
396}
397
398impl Codec<'_> for SingleProtocolName {
399    fn encode(&self, bytes: &mut Vec<u8>) {
400        let body = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);
401        self.0.encode(body.buf);
402    }
403
404    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
405        let len = Self::SIZE_LEN.read(reader)?;
406        let mut sub = reader.sub(len)?;
407
408        let item = ProtocolName::read(&mut sub)?;
409
410        if sub.any_left() {
411            Err(InvalidMessage::TrailingData("SingleProtocolName"))
412        } else {
413            Ok(Self(item))
414        }
415    }
416}
417
418impl AsRef<[u8]> for SingleProtocolName {
419    fn as_ref(&self) -> &[u8] {
420        self.0.as_ref()
421    }
422}
423
424// --- TLS 1.3 Key shares ---
425#[derive(Clone, Debug)]
426pub(crate) struct KeyShareEntry {
427    pub(crate) group: NamedGroup,
428    /// RFC8446: `opaque key_exchange<1..2^16-1>;`
429    pub(crate) payload: PayloadU16<NonEmpty>,
430}
431
432impl KeyShareEntry {
433    pub(crate) fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self {
434        Self {
435            group,
436            payload: PayloadU16::new(payload.into()),
437        }
438    }
439}
440
441impl Codec<'_> for KeyShareEntry {
442    fn encode(&self, bytes: &mut Vec<u8>) {
443        self.group.encode(bytes);
444        self.payload.encode(bytes);
445    }
446
447    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
448        let group = NamedGroup::read(r)?;
449        let payload = PayloadU16::read(r)?;
450
451        Ok(Self { group, payload })
452    }
453}
454
455// --- TLS 1.3 PresharedKey offers ---
456#[derive(Clone, Debug)]
457pub(crate) struct PresharedKeyIdentity {
458    /// RFC8446: `opaque identity<1..2^16-1>;`
459    pub(crate) identity: PayloadU16<NonEmpty>,
460    pub(crate) obfuscated_ticket_age: u32,
461}
462
463impl PresharedKeyIdentity {
464    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
465        Self {
466            identity: PayloadU16::new(id),
467            obfuscated_ticket_age: age,
468        }
469    }
470}
471
472impl Codec<'_> for PresharedKeyIdentity {
473    fn encode(&self, bytes: &mut Vec<u8>) {
474        self.identity.encode(bytes);
475        self.obfuscated_ticket_age.encode(bytes);
476    }
477
478    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
479        Ok(Self {
480            identity: PayloadU16::read(r)?,
481            obfuscated_ticket_age: u32::read(r)?,
482        })
483    }
484}
485
486/// RFC8446: `PskIdentity identities<7..2^16-1>;`
487impl TlsListElement for PresharedKeyIdentity {
488    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
489        empty_error: InvalidMessage::IllegalEmptyList("PskIdentities"),
490    };
491}
492
493wrapped_payload!(
494    /// RFC8446: `opaque PskBinderEntry<32..255>;`
495    pub(crate) struct PresharedKeyBinder, PayloadU8<NonEmpty>,
496);
497
498/// RFC8446: `PskBinderEntry binders<33..2^16-1>;`
499impl TlsListElement for PresharedKeyBinder {
500    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
501        empty_error: InvalidMessage::IllegalEmptyList("PskBinders"),
502    };
503}
504
505#[derive(Clone, Debug)]
506pub(crate) struct PresharedKeyOffer {
507    pub(crate) identities: Vec<PresharedKeyIdentity>,
508    pub(crate) binders: Vec<PresharedKeyBinder>,
509}
510
511impl PresharedKeyOffer {
512    /// Make a new one with one entry.
513    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
514        Self {
515            identities: vec![id],
516            binders: vec![PresharedKeyBinder::from(binder)],
517        }
518    }
519}
520
521impl Codec<'_> for PresharedKeyOffer {
522    fn encode(&self, bytes: &mut Vec<u8>) {
523        self.identities.encode(bytes);
524        self.binders.encode(bytes);
525    }
526
527    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
528        Ok(Self {
529            identities: Vec::read(r)?,
530            binders: Vec::read(r)?,
531        })
532    }
533}
534
535// --- RFC6066 certificate status request ---
536wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,);
537
538/// RFC6066: `ResponderID responder_id_list<0..2^16-1>;`
539impl TlsListElement for ResponderId {
540    const SIZE_LEN: ListLength = ListLength::U16;
541}
542
543#[derive(Clone, Debug)]
544pub(crate) struct OcspCertificateStatusRequest {
545    pub(crate) responder_ids: Vec<ResponderId>,
546    pub(crate) extensions: PayloadU16,
547}
548
549impl Codec<'_> for OcspCertificateStatusRequest {
550    fn encode(&self, bytes: &mut Vec<u8>) {
551        CertificateStatusType::OCSP.encode(bytes);
552        self.responder_ids.encode(bytes);
553        self.extensions.encode(bytes);
554    }
555
556    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
557        Ok(Self {
558            responder_ids: Vec::read(r)?,
559            extensions: PayloadU16::read(r)?,
560        })
561    }
562}
563
564#[derive(Clone, Debug)]
565pub(crate) enum CertificateStatusRequest {
566    Ocsp(OcspCertificateStatusRequest),
567    Unknown((CertificateStatusType, Payload<'static>)),
568}
569
570impl Codec<'_> for CertificateStatusRequest {
571    fn encode(&self, bytes: &mut Vec<u8>) {
572        match self {
573            Self::Ocsp(r) => r.encode(bytes),
574            Self::Unknown((typ, payload)) => {
575                typ.encode(bytes);
576                payload.encode(bytes);
577            }
578        }
579    }
580
581    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
582        let typ = CertificateStatusType::read(r)?;
583
584        match typ {
585            CertificateStatusType::OCSP => {
586                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
587                Ok(Self::Ocsp(ocsp_req))
588            }
589            _ => {
590                let data = Payload::read(r).into_owned();
591                Ok(Self::Unknown((typ, data)))
592            }
593        }
594    }
595}
596
597impl CertificateStatusRequest {
598    pub(crate) fn build_ocsp() -> Self {
599        let ocsp = OcspCertificateStatusRequest {
600            responder_ids: Vec::new(),
601            extensions: PayloadU16::empty(),
602        };
603        Self::Ocsp(ocsp)
604    }
605}
606
607// ---
608
609/// RFC8446: `PskKeyExchangeMode ke_modes<1..255>;`
610impl TlsListElement for PskKeyExchangeMode {
611    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
612        empty_error: InvalidMessage::IllegalEmptyList("PskKeyExchangeModes"),
613    };
614}
615
616/// RFC8446: `KeyShareEntry client_shares<0..2^16-1>;`
617impl TlsListElement for KeyShareEntry {
618    const SIZE_LEN: ListLength = ListLength::U16;
619}
620
621/// The body of the `SupportedVersions` extension when it appears in a
622/// `ClientHello`
623///
624/// This is documented as a preference-order vector, but we (as a server)
625/// ignore the preference of the client.
626///
627/// RFC8446: `ProtocolVersion versions<2..254>;`
628#[derive(Clone, Copy, Debug, Default)]
629pub(crate) struct SupportedProtocolVersions {
630    pub(crate) tls13: bool,
631    pub(crate) tls12: bool,
632}
633
634impl SupportedProtocolVersions {
635    /// Return true if `filter` returns true for any enabled version.
636    pub(crate) fn any(&self, filter: impl Fn(ProtocolVersion) -> bool) -> bool {
637        if self.tls13 && filter(ProtocolVersion::TLSv1_3) {
638            return true;
639        }
640        if self.tls12 && filter(ProtocolVersion::TLSv1_2) {
641            return true;
642        }
643        false
644    }
645
646    const LIST_LENGTH: ListLength = ListLength::NonZeroU8 {
647        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
648    };
649}
650
651impl Codec<'_> for SupportedProtocolVersions {
652    fn encode(&self, bytes: &mut Vec<u8>) {
653        let inner = LengthPrefixedBuffer::new(Self::LIST_LENGTH, bytes);
654        if self.tls13 {
655            ProtocolVersion::TLSv1_3.encode(inner.buf);
656        }
657        if self.tls12 {
658            ProtocolVersion::TLSv1_2.encode(inner.buf);
659        }
660    }
661
662    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
663        let len = Self::LIST_LENGTH.read(reader)?;
664        let mut sub = reader.sub(len)?;
665
666        let mut tls12 = false;
667        let mut tls13 = false;
668
669        while sub.any_left() {
670            match ProtocolVersion::read(&mut sub)? {
671                ProtocolVersion::TLSv1_3 => tls13 = true,
672                ProtocolVersion::TLSv1_2 => tls12 = true,
673                _ => continue,
674            };
675        }
676
677        Ok(Self { tls13, tls12 })
678    }
679}
680
681impl TlsListElement for ProtocolVersion {
682    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
683        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
684    };
685}
686
687/// RFC7250: `CertificateType client_certificate_types<1..2^8-1>;`
688///
689/// Ditto `CertificateType server_certificate_types<1..2^8-1>;`
690impl TlsListElement for CertificateType {
691    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
692        empty_error: InvalidMessage::IllegalEmptyList("CertificateTypes"),
693    };
694}
695
696/// RFC8879: `CertificateCompressionAlgorithm algorithms<2..2^8-2>;`
697impl TlsListElement for CertificateCompressionAlgorithm {
698    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
699        empty_error: InvalidMessage::IllegalEmptyList("CertificateCompressionAlgorithms"),
700    };
701}
702
703#[derive(Clone, Debug)]
704pub(crate) enum ClientExtension {
705    EcPointFormats(Vec<ECPointFormat>),
706    NamedGroups(Vec<NamedGroup>),
707    SignatureAlgorithms(Vec<SignatureScheme>),
708    ServerName(ServerNamePayload<'static>),
709    SessionTicket(ClientSessionTicket),
710    Protocols(Vec<ProtocolName>),
711    SupportedVersions(SupportedProtocolVersions),
712    KeyShare(Vec<KeyShareEntry>),
713    PresharedKeyModes(Vec<PskKeyExchangeMode>),
714    PresharedKey(PresharedKeyOffer),
715    Cookie(PayloadU16<NonEmpty>),
716    ExtendedMasterSecretRequest,
717    CertificateStatusRequest(CertificateStatusRequest),
718    ServerCertTypes(Vec<CertificateType>),
719    ClientCertTypes(Vec<CertificateType>),
720    TransportParameters(Vec<u8>),
721    TransportParametersDraft(Vec<u8>),
722    EarlyData,
723    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
724    EncryptedClientHello(EncryptedClientHello),
725    EncryptedClientHelloOuterExtensions(Vec<ExtensionType>),
726    AuthorityNames(Vec<DistinguishedName>),
727    Unknown(UnknownExtension),
728}
729
730impl ClientExtension {
731    pub(crate) fn ext_type(&self) -> ExtensionType {
732        match self {
733            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
734            Self::NamedGroups(_) => ExtensionType::EllipticCurves,
735            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
736            Self::ServerName(_) => ExtensionType::ServerName,
737            Self::SessionTicket(_) => ExtensionType::SessionTicket,
738            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
739            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
740            Self::KeyShare(_) => ExtensionType::KeyShare,
741            Self::PresharedKeyModes(_) => ExtensionType::PSKKeyExchangeModes,
742            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
743            Self::Cookie(_) => ExtensionType::Cookie,
744            Self::ExtendedMasterSecretRequest => ExtensionType::ExtendedMasterSecret,
745            Self::CertificateStatusRequest(_) => ExtensionType::StatusRequest,
746            Self::ClientCertTypes(_) => ExtensionType::ClientCertificateType,
747            Self::ServerCertTypes(_) => ExtensionType::ServerCertificateType,
748            Self::TransportParameters(_) => ExtensionType::TransportParameters,
749            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
750            Self::EarlyData => ExtensionType::EarlyData,
751            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
752            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
753            Self::EncryptedClientHelloOuterExtensions(_) => {
754                ExtensionType::EncryptedClientHelloOuterExtensions
755            }
756            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
757            Self::Unknown(r) => r.typ,
758        }
759    }
760}
761
/// Wire encoding of a single `ClientHello` extension: a type code followed by a
/// u16-length-prefixed body.
impl Codec<'_> for ClientExtension {
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.ext_type().encode(bytes);

        // Everything written to `nested.buf` becomes the u16-length-prefixed body.
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
        match self {
            Self::EcPointFormats(r) => r.encode(nested.buf),
            Self::NamedGroups(r) => r.encode(nested.buf),
            Self::SignatureAlgorithms(r) => r.encode(nested.buf),
            Self::ServerName(r) => r.encode(nested.buf),
            Self::SessionTicket(r) => r.encode(nested.buf),
            // These extensions have an empty body.
            Self::ExtendedMasterSecretRequest | Self::EarlyData => {}
            Self::Protocols(r) => r.encode(nested.buf),
            Self::SupportedVersions(r) => r.encode(nested.buf),
            Self::KeyShare(r) => r.encode(nested.buf),
            Self::PresharedKeyModes(r) => r.encode(nested.buf),
            Self::PresharedKey(r) => r.encode(nested.buf),
            Self::Cookie(r) => r.encode(nested.buf),
            Self::CertificateStatusRequest(r) => r.encode(nested.buf),
            Self::ClientCertTypes(r) => r.encode(nested.buf),
            Self::ServerCertTypes(r) => r.encode(nested.buf),
            // Transport parameters are opaque here and copied verbatim.
            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
                nested.buf.extend_from_slice(r);
            }
            Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf),
            Self::EncryptedClientHello(r) => r.encode(nested.buf),
            Self::EncryptedClientHelloOuterExtensions(r) => r.encode(nested.buf),
            Self::AuthorityNames(r) => r.encode(nested.buf),
            Self::Unknown(r) => r.encode(nested.buf),
        }
    }

    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let typ = ExtensionType::read(r)?;
        let len = u16::read(r)? as usize;
        // Each body is parsed from a sub-reader bounded by its declared length.
        let mut sub = r.sub(len)?;

        let ext = match typ {
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
            ExtensionType::EllipticCurves => Self::NamedGroups(Vec::read(&mut sub)?),
            ExtensionType::SignatureAlgorithms => Self::SignatureAlgorithms(Vec::read(&mut sub)?),
            ExtensionType::ServerName => {
                Self::ServerName(ServerNamePayload::read(&mut sub)?.into_owned())
            }
            ExtensionType::SessionTicket => {
                Self::SessionTicket(ClientSessionTicket::read(&mut sub)?)
            }
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
            ExtensionType::SupportedVersions => {
                Self::SupportedVersions(SupportedProtocolVersions::read(&mut sub)?)
            }
            ExtensionType::KeyShare => Self::KeyShare(Vec::read(&mut sub)?),
            ExtensionType::PSKKeyExchangeModes => Self::PresharedKeyModes(Vec::read(&mut sub)?),
            ExtensionType::PreSharedKey => Self::PresharedKey(PresharedKeyOffer::read(&mut sub)?),
            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
            // With a non-empty body this guard fails, and the extension falls through to the
            // `Unknown` arm below instead of being rejected.
            ExtensionType::ExtendedMasterSecret if !sub.any_left() => {
                Self::ExtendedMasterSecretRequest
            }
            ExtensionType::ClientCertificateType => Self::ClientCertTypes(Vec::read(&mut sub)?),
            ExtensionType::ServerCertificateType => Self::ServerCertTypes(Vec::read(&mut sub)?),
            ExtensionType::StatusRequest => {
                let csr = CertificateStatusRequest::read(&mut sub)?;
                Self::CertificateStatusRequest(csr)
            }
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
            ExtensionType::TransportParametersDraft => {
                Self::TransportParametersDraft(sub.rest().to_vec())
            }
            // Same fall-through behaviour as `ExtendedMasterSecret` above.
            ExtensionType::EarlyData if !sub.any_left() => Self::EarlyData,
            ExtensionType::CompressCertificate => {
                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
            }
            ExtensionType::EncryptedClientHelloOuterExtensions => {
                Self::EncryptedClientHelloOuterExtensions(Vec::read(&mut sub)?)
            }
            // An empty list of authority names is rejected explicitly here.
            ExtensionType::CertificateAuthorities => Self::AuthorityNames({
                let items = Vec::read(&mut sub)?;
                if items.is_empty() {
                    return Err(InvalidMessage::IllegalEmptyList("DistinguishedNames"));
                }
                items
            }),
            // Every other type, including `EncryptedClientHello` (not parsed here), is kept as
            // raw bytes.
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
        };

        // Bytes left unread inside the declared body length are an error.
        sub.expect_empty("ClientExtension")
            .map(|_| ext)
    }
}
851
852fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
853    let dns_name_str = dns_name.as_ref();
854
855    // RFC6066: "The hostname is represented as a byte string using
856    // ASCII encoding without a trailing dot"
857    if dns_name_str.ends_with('.') {
858        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
859        DnsName::try_from(trimmed)
860            .unwrap()
861            .to_owned()
862    } else {
863        dns_name.to_owned()
864    }
865}
866
867#[derive(Clone, Debug)]
868pub(crate) enum ClientSessionTicket {
869    Request,
870    Offer(Payload<'static>),
871}
872
873impl<'a> Codec<'a> for ClientSessionTicket {
874    fn encode(&self, bytes: &mut Vec<u8>) {
875        match self {
876            Self::Request => (),
877            Self::Offer(p) => p.encode(bytes),
878        }
879    }
880
881    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
882        Ok(match r.left() {
883            0 => Self::Request,
884            _ => Self::Offer(Payload::read(r).into_owned()),
885        })
886    }
887}
888
889#[derive(Clone, Debug)]
890pub(crate) enum ServerExtension {
891    EcPointFormats(Vec<ECPointFormat>),
892    ServerNameAck,
893    SessionTicketAck,
894    RenegotiationInfo(PayloadU8),
895    Protocols(SingleProtocolName),
896    KeyShare(KeyShareEntry),
897    PresharedKey(u16),
898    ExtendedMasterSecretAck,
899    CertificateStatusAck,
900    ServerCertType(CertificateType),
901    ClientCertType(CertificateType),
902    SupportedVersions(ProtocolVersion),
903    TransportParameters(Vec<u8>),
904    TransportParametersDraft(Vec<u8>),
905    EarlyData,
906    EncryptedClientHello(ServerEncryptedClientHello),
907    Unknown(UnknownExtension),
908}
909
910impl ServerExtension {
911    pub(crate) fn ext_type(&self) -> ExtensionType {
912        match self {
913            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
914            Self::ServerNameAck => ExtensionType::ServerName,
915            Self::SessionTicketAck => ExtensionType::SessionTicket,
916            Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo,
917            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
918            Self::KeyShare(_) => ExtensionType::KeyShare,
919            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
920            Self::ClientCertType(_) => ExtensionType::ClientCertificateType,
921            Self::ServerCertType(_) => ExtensionType::ServerCertificateType,
922            Self::ExtendedMasterSecretAck => ExtensionType::ExtendedMasterSecret,
923            Self::CertificateStatusAck => ExtensionType::StatusRequest,
924            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
925            Self::TransportParameters(_) => ExtensionType::TransportParameters,
926            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
927            Self::EarlyData => ExtensionType::EarlyData,
928            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
929            Self::Unknown(r) => r.typ,
930        }
931    }
932}
933
934impl Codec<'_> for ServerExtension {
935    fn encode(&self, bytes: &mut Vec<u8>) {
936        self.ext_type().encode(bytes);
937
938        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
939        match self {
940            Self::EcPointFormats(r) => r.encode(nested.buf),
941            Self::ServerNameAck
942            | Self::SessionTicketAck
943            | Self::ExtendedMasterSecretAck
944            | Self::CertificateStatusAck
945            | Self::EarlyData => {}
946            Self::RenegotiationInfo(r) => r.encode(nested.buf),
947            Self::Protocols(r) => r.encode(nested.buf),
948            Self::KeyShare(r) => r.encode(nested.buf),
949            Self::PresharedKey(r) => r.encode(nested.buf),
950            Self::ClientCertType(r) => r.encode(nested.buf),
951            Self::ServerCertType(r) => r.encode(nested.buf),
952            Self::SupportedVersions(r) => r.encode(nested.buf),
953            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
954                nested.buf.extend_from_slice(r);
955            }
956            Self::EncryptedClientHello(r) => r.encode(nested.buf),
957            Self::Unknown(r) => r.encode(nested.buf),
958        }
959    }
960
961    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
962        let typ = ExtensionType::read(r)?;
963        let len = u16::read(r)? as usize;
964        let mut sub = r.sub(len)?;
965
966        let ext = match typ {
967            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
968            ExtensionType::ServerName => Self::ServerNameAck,
969            ExtensionType::SessionTicket => Self::SessionTicketAck,
970            ExtensionType::StatusRequest => Self::CertificateStatusAck,
971            ExtensionType::RenegotiationInfo => Self::RenegotiationInfo(PayloadU8::read(&mut sub)?),
972            ExtensionType::ALProtocolNegotiation => {
973                Self::Protocols(SingleProtocolName::read(&mut sub)?)
974            }
975            ExtensionType::ClientCertificateType => {
976                Self::ClientCertType(CertificateType::read(&mut sub)?)
977            }
978            ExtensionType::ServerCertificateType => {
979                Self::ServerCertType(CertificateType::read(&mut sub)?)
980            }
981            ExtensionType::KeyShare => Self::KeyShare(KeyShareEntry::read(&mut sub)?),
982            ExtensionType::PreSharedKey => Self::PresharedKey(u16::read(&mut sub)?),
983            ExtensionType::ExtendedMasterSecret => Self::ExtendedMasterSecretAck,
984            ExtensionType::SupportedVersions => {
985                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
986            }
987            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
988            ExtensionType::TransportParametersDraft => {
989                Self::TransportParametersDraft(sub.rest().to_vec())
990            }
991            ExtensionType::EarlyData => Self::EarlyData,
992            ExtensionType::EncryptedClientHello => {
993                Self::EncryptedClientHello(ServerEncryptedClientHello::read(&mut sub)?)
994            }
995            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
996        };
997
998        sub.expect_empty("ServerExtension")
999            .map(|_| ext)
1000    }
1001}
1002
1003impl ServerExtension {
1004    #[cfg(feature = "tls12")]
1005    pub(crate) fn make_empty_renegotiation_info() -> Self {
1006        let empty = Vec::new();
1007        Self::RenegotiationInfo(PayloadU8::new(empty))
1008    }
1009}
1010
/// The body of a ClientHello handshake message.
///
/// Fields are declared in the order they are written by `payload_encode`.
#[derive(Clone, Debug)]
pub(crate) struct ClientHelloPayload {
    pub(crate) client_version: ProtocolVersion,
    pub(crate) random: Random,
    pub(crate) session_id: SessionId,
    pub(crate) cipher_suites: Vec<CipherSuite>,
    pub(crate) compression_methods: Vec<Compression>,
    /// Extensions in list order; order matters (e.g. `check_psk_ext_is_last`).
    pub(crate) extensions: Vec<ClientExtension>,
}
1020
1021impl ClientHelloPayload {
1022    pub(crate) fn ech_inner_encoding(&self, to_compress: Vec<ExtensionType>) -> Vec<u8> {
1023        let mut bytes = Vec::new();
1024        self.payload_encode(&mut bytes, Encoding::EchInnerHello { to_compress });
1025        bytes
1026    }
1027
1028    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1029        self.client_version.encode(bytes);
1030        self.random.encode(bytes);
1031
1032        match purpose {
1033            // SessionID is required to be empty in the encoded inner client hello.
1034            Encoding::EchInnerHello { .. } => SessionId::empty().encode(bytes),
1035            _ => self.session_id.encode(bytes),
1036        }
1037
1038        self.cipher_suites.encode(bytes);
1039        self.compression_methods.encode(bytes);
1040
1041        let to_compress = match purpose {
1042            // Compressed extensions must be replaced in the encoded inner client hello.
1043            Encoding::EchInnerHello { to_compress } if !to_compress.is_empty() => to_compress,
1044            _ => {
1045                if !self.extensions.is_empty() {
1046                    self.extensions.encode(bytes);
1047                }
1048                return;
1049            }
1050        };
1051
1052        // Safety: not empty check in match guard.
1053        let first_compressed_type = *to_compress.first().unwrap();
1054
1055        // Compressed extensions are in a contiguous range and must be replaced
1056        // with a marker extension.
1057        let compressed_start_idx = self
1058            .extensions
1059            .iter()
1060            .position(|ext| ext.ext_type() == first_compressed_type);
1061        let compressed_end_idx = compressed_start_idx.map(|start| start + to_compress.len());
1062        let marker_ext = ClientExtension::EncryptedClientHelloOuterExtensions(to_compress);
1063
1064        let exts = self
1065            .extensions
1066            .iter()
1067            .enumerate()
1068            .filter_map(|(i, ext)| {
1069                if Some(i) == compressed_start_idx {
1070                    Some(&marker_ext)
1071                } else if Some(i) > compressed_start_idx && Some(i) < compressed_end_idx {
1072                    None
1073                } else {
1074                    Some(ext)
1075                }
1076            });
1077
1078        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1079        for ext in exts {
1080            ext.encode(nested.buf);
1081        }
1082    }
1083
1084    /// Returns true if there is more than one extension of a given
1085    /// type.
1086    pub(crate) fn has_duplicate_extension(&self) -> bool {
1087        has_duplicates::<_, _, u16>(
1088            self.extensions
1089                .iter()
1090                .map(|ext| ext.ext_type()),
1091        )
1092    }
1093
1094    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&ClientExtension> {
1095        self.extensions
1096            .iter()
1097            .find(|x| x.ext_type() == ext)
1098    }
1099
1100    pub(crate) fn sni_extension(&self) -> Option<&ServerNamePayload<'_>> {
1101        let ext = self.find_extension(ExtensionType::ServerName)?;
1102        match ext {
1103            ClientExtension::ServerName(req) => Some(req),
1104            _ => None,
1105        }
1106    }
1107
1108    pub(crate) fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
1109        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
1110        match ext {
1111            ClientExtension::SignatureAlgorithms(req) => Some(req),
1112            _ => None,
1113        }
1114    }
1115
1116    pub(crate) fn namedgroups_extension(&self) -> Option<&[NamedGroup]> {
1117        let ext = self.find_extension(ExtensionType::EllipticCurves)?;
1118        match ext {
1119            ClientExtension::NamedGroups(req) => Some(req),
1120            _ => None,
1121        }
1122    }
1123
1124    #[cfg(feature = "tls12")]
1125    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1126        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1127        match ext {
1128            ClientExtension::EcPointFormats(req) => Some(req),
1129            _ => None,
1130        }
1131    }
1132
1133    pub(crate) fn server_certificate_extension(&self) -> Option<&[CertificateType]> {
1134        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
1135        match ext {
1136            ClientExtension::ServerCertTypes(req) => Some(req),
1137            _ => None,
1138        }
1139    }
1140
1141    pub(crate) fn client_certificate_extension(&self) -> Option<&[CertificateType]> {
1142        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
1143        match ext {
1144            ClientExtension::ClientCertTypes(req) => Some(req),
1145            _ => None,
1146        }
1147    }
1148
1149    pub(crate) fn alpn_extension(&self) -> Option<&Vec<ProtocolName>> {
1150        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
1151        match ext {
1152            ClientExtension::Protocols(req) => Some(req),
1153            _ => None,
1154        }
1155    }
1156
1157    pub(crate) fn quic_params_extension(&self) -> Option<Vec<u8>> {
1158        let ext = self
1159            .find_extension(ExtensionType::TransportParameters)
1160            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
1161        match ext {
1162            ClientExtension::TransportParameters(bytes)
1163            | ClientExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()),
1164            _ => None,
1165        }
1166    }
1167
1168    #[cfg(feature = "tls12")]
1169    pub(crate) fn ticket_extension(&self) -> Option<&ClientExtension> {
1170        self.find_extension(ExtensionType::SessionTicket)
1171    }
1172
1173    pub(crate) fn versions_extension(&self) -> Option<SupportedProtocolVersions> {
1174        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1175        match ext {
1176            ClientExtension::SupportedVersions(vers) => Some(*vers),
1177            _ => None,
1178        }
1179    }
1180
1181    pub(crate) fn keyshare_extension(&self) -> Option<&[KeyShareEntry]> {
1182        let ext = self.find_extension(ExtensionType::KeyShare)?;
1183        match ext {
1184            ClientExtension::KeyShare(shares) => Some(shares),
1185            _ => None,
1186        }
1187    }
1188
1189    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
1190        self.keyshare_extension()
1191            .map(|entries| {
1192                has_duplicates::<_, _, u16>(
1193                    entries
1194                        .iter()
1195                        .map(|kse| u16::from(kse.group)),
1196                )
1197            })
1198            .unwrap_or_default()
1199    }
1200
1201    pub(crate) fn psk(&self) -> Option<&PresharedKeyOffer> {
1202        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1203        match ext {
1204            ClientExtension::PresharedKey(psk) => Some(psk),
1205            _ => None,
1206        }
1207    }
1208
1209    pub(crate) fn check_psk_ext_is_last(&self) -> bool {
1210        self.extensions
1211            .last()
1212            .is_some_and(|ext| ext.ext_type() == ExtensionType::PreSharedKey)
1213    }
1214
1215    pub(crate) fn psk_modes(&self) -> Option<&[PskKeyExchangeMode]> {
1216        let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
1217        match ext {
1218            ClientExtension::PresharedKeyModes(psk_modes) => Some(psk_modes),
1219            _ => None,
1220        }
1221    }
1222
1223    pub(crate) fn psk_mode_offered(&self, mode: PskKeyExchangeMode) -> bool {
1224        self.psk_modes()
1225            .map(|modes| modes.contains(&mode))
1226            .unwrap_or(false)
1227    }
1228
1229    pub(crate) fn set_psk_binder(&mut self, binder: impl Into<Vec<u8>>) {
1230        let last_extension = self.extensions.last_mut();
1231        if let Some(ClientExtension::PresharedKey(offer)) = last_extension {
1232            offer.binders[0] = PresharedKeyBinder::from(binder.into());
1233        }
1234    }
1235
1236    #[cfg(feature = "tls12")]
1237    pub(crate) fn ems_support_offered(&self) -> bool {
1238        self.find_extension(ExtensionType::ExtendedMasterSecret)
1239            .is_some()
1240    }
1241
1242    pub(crate) fn early_data_extension_offered(&self) -> bool {
1243        self.find_extension(ExtensionType::EarlyData)
1244            .is_some()
1245    }
1246
1247    pub(crate) fn certificate_compression_extension(
1248        &self,
1249    ) -> Option<&[CertificateCompressionAlgorithm]> {
1250        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
1251        match ext {
1252            ClientExtension::CertificateCompressionAlgorithms(algs) => Some(algs),
1253            _ => None,
1254        }
1255    }
1256
1257    pub(crate) fn has_certificate_compression_extension_with_duplicates(&self) -> bool {
1258        if let Some(algs) = self.certificate_compression_extension() {
1259            has_duplicates::<_, _, u16>(algs.iter().cloned())
1260        } else {
1261            false
1262        }
1263    }
1264
1265    pub(crate) fn certificate_authorities_extension(&self) -> Option<&[DistinguishedName]> {
1266        match self.find_extension(ExtensionType::CertificateAuthorities)? {
1267            ClientExtension::AuthorityNames(ext) => Some(ext),
1268            _ => unreachable!("extension type checked"),
1269        }
1270    }
1271}
1272
1273impl Codec<'_> for ClientHelloPayload {
1274    fn encode(&self, bytes: &mut Vec<u8>) {
1275        self.payload_encode(bytes, Encoding::Standard)
1276    }
1277
1278    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1279        let mut ret = Self {
1280            client_version: ProtocolVersion::read(r)?,
1281            random: Random::read(r)?,
1282            session_id: SessionId::read(r)?,
1283            cipher_suites: Vec::read(r)?,
1284            compression_methods: Vec::read(r)?,
1285            extensions: Vec::new(),
1286        };
1287
1288        if r.any_left() {
1289            ret.extensions = Vec::read(r)?;
1290        }
1291
1292        match (r.any_left(), ret.extensions.is_empty()) {
1293            (true, _) => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
1294            (_, true) => Err(InvalidMessage::MissingData("ClientHelloPayload")),
1295            _ => Ok(ret),
1296        }
1297    }
1298}
1299
/// RFC8446: `CipherSuite cipher_suites<2..2^16-2>;`
///
/// The list uses a non-zero u16 length; an empty list is configured to report
/// `IllegalEmptyList("CipherSuites")`.
impl TlsListElement for CipherSuite {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("CipherSuites"),
    };
}

/// RFC5246: `CompressionMethod compression_methods<1..2^8-1>;`
///
/// The list uses a non-zero u8 length; an empty list is configured to report
/// `IllegalEmptyList("Compressions")`.
impl TlsListElement for Compression {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("Compressions"),
    };
}

/// A list of client extensions uses a u16 length prefix; unlike the lists
/// above, no non-empty requirement is configured here.
impl TlsListElement for ClientExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}

/// draft-ietf-tls-esni-17: `ExtensionType OuterExtensions<2..254>;`
///
/// Encoded with a non-zero u8 length prefix.
impl TlsListElement for ExtensionType {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ExtensionTypes"),
    };
}
1324
1325#[derive(Clone, Debug)]
1326pub(crate) enum HelloRetryExtension {
1327    KeyShare(NamedGroup),
1328    Cookie(PayloadU16<NonEmpty>),
1329    SupportedVersions(ProtocolVersion),
1330    EchHelloRetryRequest(Vec<u8>),
1331    Unknown(UnknownExtension),
1332}
1333
1334impl HelloRetryExtension {
1335    pub(crate) fn ext_type(&self) -> ExtensionType {
1336        match self {
1337            Self::KeyShare(_) => ExtensionType::KeyShare,
1338            Self::Cookie(_) => ExtensionType::Cookie,
1339            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
1340            Self::EchHelloRetryRequest(_) => ExtensionType::EncryptedClientHello,
1341            Self::Unknown(r) => r.typ,
1342        }
1343    }
1344}
1345
1346impl Codec<'_> for HelloRetryExtension {
1347    fn encode(&self, bytes: &mut Vec<u8>) {
1348        self.ext_type().encode(bytes);
1349
1350        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1351        match self {
1352            Self::KeyShare(r) => r.encode(nested.buf),
1353            Self::Cookie(r) => r.encode(nested.buf),
1354            Self::SupportedVersions(r) => r.encode(nested.buf),
1355            Self::EchHelloRetryRequest(r) => {
1356                nested.buf.extend_from_slice(r);
1357            }
1358            Self::Unknown(r) => r.encode(nested.buf),
1359        }
1360    }
1361
1362    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1363        let typ = ExtensionType::read(r)?;
1364        let len = u16::read(r)? as usize;
1365        let mut sub = r.sub(len)?;
1366
1367        let ext = match typ {
1368            ExtensionType::KeyShare => Self::KeyShare(NamedGroup::read(&mut sub)?),
1369            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
1370            ExtensionType::SupportedVersions => {
1371                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
1372            }
1373            ExtensionType::EncryptedClientHello => Self::EchHelloRetryRequest(sub.rest().to_vec()),
1374            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1375        };
1376
1377        sub.expect_empty("HelloRetryExtension")
1378            .map(|_| ext)
1379    }
1380}
1381
1382impl TlsListElement for HelloRetryExtension {
1383    const SIZE_LEN: ListLength = ListLength::U16;
1384}
1385
1386#[derive(Clone, Debug)]
1387pub(crate) struct HelloRetryRequest {
1388    pub(crate) legacy_version: ProtocolVersion,
1389    pub(crate) session_id: SessionId,
1390    pub(crate) cipher_suite: CipherSuite,
1391    pub(crate) extensions: Vec<HelloRetryExtension>,
1392}
1393
1394impl Codec<'_> for HelloRetryRequest {
1395    fn encode(&self, bytes: &mut Vec<u8>) {
1396        self.payload_encode(bytes, Encoding::Standard)
1397    }
1398
1399    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1400        let session_id = SessionId::read(r)?;
1401        let cipher_suite = CipherSuite::read(r)?;
1402        let compression = Compression::read(r)?;
1403
1404        if compression != Compression::Null {
1405            return Err(InvalidMessage::UnsupportedCompression);
1406        }
1407
1408        Ok(Self {
1409            legacy_version: ProtocolVersion::Unknown(0),
1410            session_id,
1411            cipher_suite,
1412            extensions: Vec::read(r)?,
1413        })
1414    }
1415}
1416
1417impl HelloRetryRequest {
1418    /// Returns true if there is more than one extension of a given
1419    /// type.
1420    pub(crate) fn has_duplicate_extension(&self) -> bool {
1421        has_duplicates::<_, _, u16>(
1422            self.extensions
1423                .iter()
1424                .map(|ext| ext.ext_type()),
1425        )
1426    }
1427
1428    pub(crate) fn has_unknown_extension(&self) -> bool {
1429        self.extensions.iter().any(|ext| {
1430            ext.ext_type() != ExtensionType::KeyShare
1431                && ext.ext_type() != ExtensionType::SupportedVersions
1432                && ext.ext_type() != ExtensionType::Cookie
1433                && ext.ext_type() != ExtensionType::EncryptedClientHello
1434        })
1435    }
1436
1437    fn find_extension(&self, ext: ExtensionType) -> Option<&HelloRetryExtension> {
1438        self.extensions
1439            .iter()
1440            .find(|x| x.ext_type() == ext)
1441    }
1442
1443    pub(crate) fn requested_key_share_group(&self) -> Option<NamedGroup> {
1444        let ext = self.find_extension(ExtensionType::KeyShare)?;
1445        match ext {
1446            HelloRetryExtension::KeyShare(grp) => Some(*grp),
1447            _ => None,
1448        }
1449    }
1450
1451    pub(crate) fn cookie(&self) -> Option<&PayloadU16<NonEmpty>> {
1452        let ext = self.find_extension(ExtensionType::Cookie)?;
1453        match ext {
1454            HelloRetryExtension::Cookie(ck) => Some(ck),
1455            _ => None,
1456        }
1457    }
1458
1459    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1460        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1461        match ext {
1462            HelloRetryExtension::SupportedVersions(ver) => Some(*ver),
1463            _ => None,
1464        }
1465    }
1466
1467    pub(crate) fn ech(&self) -> Option<&Vec<u8>> {
1468        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
1469        match ext {
1470            HelloRetryExtension::EchHelloRetryRequest(ech) => Some(ech),
1471            _ => None,
1472        }
1473    }
1474
1475    fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1476        self.legacy_version.encode(bytes);
1477        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1478        self.session_id.encode(bytes);
1479        self.cipher_suite.encode(bytes);
1480        Compression::Null.encode(bytes);
1481
1482        match purpose {
1483            // For the purpose of ECH confirmation, the Encrypted Client Hello extension
1484            // must have its payload replaced by 8 zero bytes.
1485            //
1486            // See draft-ietf-tls-esni-18 7.2.1:
1487            // <https://datatracker.ietf.org/doc/html/draft-ietf-tls-esni-18#name-sending-helloretryrequest-2>
1488            Encoding::EchConfirmation => {
1489                let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1490                for ext in &self.extensions {
1491                    match ext.ext_type() {
1492                        ExtensionType::EncryptedClientHello => {
1493                            HelloRetryExtension::EchHelloRetryRequest(vec![0u8; 8])
1494                                .encode(extensions.buf);
1495                        }
1496                        _ => {
1497                            ext.encode(extensions.buf);
1498                        }
1499                    }
1500                }
1501            }
1502            _ => {
1503                self.extensions.encode(bytes);
1504            }
1505        }
1506    }
1507}
1508
1509#[derive(Clone, Debug)]
1510pub(crate) struct ServerHelloPayload {
1511    pub(crate) legacy_version: ProtocolVersion,
1512    pub(crate) random: Random,
1513    pub(crate) session_id: SessionId,
1514    pub(crate) cipher_suite: CipherSuite,
1515    pub(crate) compression_method: Compression,
1516    pub(crate) extensions: Vec<ServerExtension>,
1517}
1518
1519impl Codec<'_> for ServerHelloPayload {
1520    fn encode(&self, bytes: &mut Vec<u8>) {
1521        self.payload_encode(bytes, Encoding::Standard)
1522    }
1523
1524    // minus version and random, which have already been read.
1525    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1526        let session_id = SessionId::read(r)?;
1527        let suite = CipherSuite::read(r)?;
1528        let compression = Compression::read(r)?;
1529
1530        // RFC5246:
1531        // "The presence of extensions can be detected by determining whether
1532        //  there are bytes following the compression_method field at the end of
1533        //  the ServerHello."
1534        let extensions = if r.any_left() { Vec::read(r)? } else { vec![] };
1535
1536        let ret = Self {
1537            legacy_version: ProtocolVersion::Unknown(0),
1538            random: ZERO_RANDOM,
1539            session_id,
1540            cipher_suite: suite,
1541            compression_method: compression,
1542            extensions,
1543        };
1544
1545        r.expect_empty("ServerHelloPayload")
1546            .map(|_| ret)
1547    }
1548}
1549
1550impl HasServerExtensions for ServerHelloPayload {
1551    fn extensions(&self) -> &[ServerExtension] {
1552        &self.extensions
1553    }
1554}
1555
1556impl ServerHelloPayload {
1557    pub(crate) fn key_share(&self) -> Option<&KeyShareEntry> {
1558        let ext = self.find_extension(ExtensionType::KeyShare)?;
1559        match ext {
1560            ServerExtension::KeyShare(share) => Some(share),
1561            _ => None,
1562        }
1563    }
1564
1565    pub(crate) fn psk_index(&self) -> Option<u16> {
1566        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1567        match ext {
1568            ServerExtension::PresharedKey(index) => Some(*index),
1569            _ => None,
1570        }
1571    }
1572
1573    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1574        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1575        match ext {
1576            ServerExtension::EcPointFormats(fmts) => Some(fmts),
1577            _ => None,
1578        }
1579    }
1580
1581    #[cfg(feature = "tls12")]
1582    pub(crate) fn ems_support_acked(&self) -> bool {
1583        self.find_extension(ExtensionType::ExtendedMasterSecret)
1584            .is_some()
1585    }
1586
1587    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1588        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1589        match ext {
1590            ServerExtension::SupportedVersions(vers) => Some(*vers),
1591            _ => None,
1592        }
1593    }
1594
1595    fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
1596        self.legacy_version.encode(bytes);
1597
1598        match encoding {
1599            // When encoding a ServerHello for ECH confirmation, the random value
1600            // has the last 8 bytes zeroed out.
1601            Encoding::EchConfirmation => {
1602                // Indexing safety: self.random is 32 bytes long by definition.
1603                let rand_vec = self.random.get_encoding();
1604                bytes.extend_from_slice(&rand_vec.as_slice()[..24]);
1605                bytes.extend_from_slice(&[0u8; 8]);
1606            }
1607            _ => self.random.encode(bytes),
1608        }
1609
1610        self.session_id.encode(bytes);
1611        self.cipher_suite.encode(bytes);
1612        self.compression_method.encode(bytes);
1613
1614        if !self.extensions.is_empty() {
1615            self.extensions.encode(bytes);
1616        }
1617    }
1618}
1619
1620#[derive(Clone, Default, Debug)]
1621pub(crate) struct CertificateChain<'a>(pub(crate) Vec<CertificateDer<'a>>);
1622
1623impl CertificateChain<'_> {
1624    pub(crate) fn into_owned(self) -> CertificateChain<'static> {
1625        CertificateChain(
1626            self.0
1627                .into_iter()
1628                .map(|c| c.into_owned())
1629                .collect(),
1630        )
1631    }
1632}
1633
1634impl<'a> Codec<'a> for CertificateChain<'a> {
1635    fn encode(&self, bytes: &mut Vec<u8>) {
1636        Vec::encode(&self.0, bytes)
1637    }
1638
1639    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1640        Vec::read(r).map(Self)
1641    }
1642}
1643
1644impl<'a> Deref for CertificateChain<'a> {
1645    type Target = [CertificateDer<'a>];
1646
1647    fn deref(&self) -> &[CertificateDer<'a>] {
1648        &self.0
1649    }
1650}
1651
/// A list of certificates uses a u24 length prefix, bounded by
/// [`CERTIFICATE_MAX_SIZE_LIMIT`]; `CertificatePayloadTooLarge` is the error
/// configured for a length above that bound.
impl TlsListElement for CertificateDer<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 {
        max: CERTIFICATE_MAX_SIZE_LIMIT,
        error: InvalidMessage::CertificatePayloadTooLarge,
    };
}

/// TLS has a 16MB size limit on any handshake message,
/// plus a 16MB limit on any given certificate.
///
/// We contract that to 64KB to limit the amount of memory allocation
/// that is directly controllable by the peer.
// 0x1_0000 == 65536 bytes.
pub(crate) const CERTIFICATE_MAX_SIZE_LIMIT: usize = 0x1_0000;
1665
1666#[derive(Debug)]
1667pub(crate) enum CertificateExtension<'a> {
1668    CertificateStatus(CertificateStatus<'a>),
1669    Unknown(UnknownExtension),
1670}
1671
1672impl CertificateExtension<'_> {
1673    pub(crate) fn ext_type(&self) -> ExtensionType {
1674        match self {
1675            Self::CertificateStatus(_) => ExtensionType::StatusRequest,
1676            Self::Unknown(r) => r.typ,
1677        }
1678    }
1679
1680    pub(crate) fn cert_status(&self) -> Option<&[u8]> {
1681        match self {
1682            Self::CertificateStatus(cs) => Some(cs.ocsp_response.0.bytes()),
1683            _ => None,
1684        }
1685    }
1686
1687    pub(crate) fn into_owned(self) -> CertificateExtension<'static> {
1688        match self {
1689            Self::CertificateStatus(st) => CertificateExtension::CertificateStatus(st.into_owned()),
1690            Self::Unknown(unk) => CertificateExtension::Unknown(unk),
1691        }
1692    }
1693}
1694
1695impl<'a> Codec<'a> for CertificateExtension<'a> {
1696    fn encode(&self, bytes: &mut Vec<u8>) {
1697        self.ext_type().encode(bytes);
1698
1699        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1700        match self {
1701            Self::CertificateStatus(r) => r.encode(nested.buf),
1702            Self::Unknown(r) => r.encode(nested.buf),
1703        }
1704    }
1705
1706    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1707        let typ = ExtensionType::read(r)?;
1708        let len = u16::read(r)? as usize;
1709        let mut sub = r.sub(len)?;
1710
1711        let ext = match typ {
1712            ExtensionType::StatusRequest => {
1713                let st = CertificateStatus::read(&mut sub)?;
1714                Self::CertificateStatus(st)
1715            }
1716            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1717        };
1718
1719        sub.expect_empty("CertificateExtension")
1720            .map(|_| ext)
1721    }
1722}
1723
1724impl TlsListElement for CertificateExtension<'_> {
1725    const SIZE_LEN: ListLength = ListLength::U16;
1726}
1727
1728#[derive(Debug)]
1729pub(crate) struct CertificateEntry<'a> {
1730    pub(crate) cert: CertificateDer<'a>,
1731    pub(crate) exts: Vec<CertificateExtension<'a>>,
1732}
1733
1734impl<'a> Codec<'a> for CertificateEntry<'a> {
1735    fn encode(&self, bytes: &mut Vec<u8>) {
1736        self.cert.encode(bytes);
1737        self.exts.encode(bytes);
1738    }
1739
1740    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1741        Ok(Self {
1742            cert: CertificateDer::read(r)?,
1743            exts: Vec::read(r)?,
1744        })
1745    }
1746}
1747
1748impl<'a> CertificateEntry<'a> {
1749    pub(crate) fn new(cert: CertificateDer<'a>) -> Self {
1750        Self {
1751            cert,
1752            exts: Vec::new(),
1753        }
1754    }
1755
1756    pub(crate) fn into_owned(self) -> CertificateEntry<'static> {
1757        CertificateEntry {
1758            cert: self.cert.into_owned(),
1759            exts: self
1760                .exts
1761                .into_iter()
1762                .map(CertificateExtension::into_owned)
1763                .collect(),
1764        }
1765    }
1766
1767    pub(crate) fn has_duplicate_extension(&self) -> bool {
1768        has_duplicates::<_, _, u16>(
1769            self.exts
1770                .iter()
1771                .map(|ext| ext.ext_type()),
1772        )
1773    }
1774
1775    pub(crate) fn has_unknown_extension(&self) -> bool {
1776        self.exts
1777            .iter()
1778            .any(|ext| ext.ext_type() != ExtensionType::StatusRequest)
1779    }
1780
1781    pub(crate) fn ocsp_response(&self) -> Option<&[u8]> {
1782        self.exts
1783            .iter()
1784            .find(|ext| ext.ext_type() == ExtensionType::StatusRequest)
1785            .and_then(CertificateExtension::cert_status)
1786    }
1787}
1788
1789impl TlsListElement for CertificateEntry<'_> {
1790    const SIZE_LEN: ListLength = ListLength::U24 {
1791        max: CERTIFICATE_MAX_SIZE_LIMIT,
1792        error: InvalidMessage::CertificatePayloadTooLarge,
1793    };
1794}
1795
1796#[derive(Debug)]
1797pub(crate) struct CertificatePayloadTls13<'a> {
1798    pub(crate) context: PayloadU8,
1799    pub(crate) entries: Vec<CertificateEntry<'a>>,
1800}
1801
1802impl<'a> Codec<'a> for CertificatePayloadTls13<'a> {
1803    fn encode(&self, bytes: &mut Vec<u8>) {
1804        self.context.encode(bytes);
1805        self.entries.encode(bytes);
1806    }
1807
1808    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1809        Ok(Self {
1810            context: PayloadU8::read(r)?,
1811            entries: Vec::read(r)?,
1812        })
1813    }
1814}
1815
1816impl<'a> CertificatePayloadTls13<'a> {
1817    pub(crate) fn new(
1818        certs: impl Iterator<Item = &'a CertificateDer<'a>>,
1819        ocsp_response: Option<&'a [u8]>,
1820    ) -> Self {
1821        Self {
1822            context: PayloadU8::empty(),
1823            entries: certs
1824                // zip certificate iterator with `ocsp_response` followed by
1825                // an infinite-length iterator of `None`.
1826                .zip(
1827                    ocsp_response
1828                        .into_iter()
1829                        .map(Some)
1830                        .chain(iter::repeat(None)),
1831                )
1832                .map(|(cert, ocsp)| {
1833                    let mut e = CertificateEntry::new(cert.clone());
1834                    if let Some(ocsp) = ocsp {
1835                        e.exts
1836                            .push(CertificateExtension::CertificateStatus(
1837                                CertificateStatus::new(ocsp),
1838                            ));
1839                    }
1840                    e
1841                })
1842                .collect(),
1843        }
1844    }
1845
1846    pub(crate) fn into_owned(self) -> CertificatePayloadTls13<'static> {
1847        CertificatePayloadTls13 {
1848            context: self.context,
1849            entries: self
1850                .entries
1851                .into_iter()
1852                .map(CertificateEntry::into_owned)
1853                .collect(),
1854        }
1855    }
1856
1857    pub(crate) fn any_entry_has_duplicate_extension(&self) -> bool {
1858        for entry in &self.entries {
1859            if entry.has_duplicate_extension() {
1860                return true;
1861            }
1862        }
1863
1864        false
1865    }
1866
1867    pub(crate) fn any_entry_has_unknown_extension(&self) -> bool {
1868        for entry in &self.entries {
1869            if entry.has_unknown_extension() {
1870                return true;
1871            }
1872        }
1873
1874        false
1875    }
1876
1877    pub(crate) fn any_entry_has_extension(&self) -> bool {
1878        for entry in &self.entries {
1879            if !entry.exts.is_empty() {
1880                return true;
1881            }
1882        }
1883
1884        false
1885    }
1886
1887    pub(crate) fn end_entity_ocsp(&self) -> &[u8] {
1888        self.entries
1889            .first()
1890            .and_then(CertificateEntry::ocsp_response)
1891            .unwrap_or_default()
1892    }
1893
1894    pub(crate) fn into_certificate_chain(self) -> CertificateChain<'a> {
1895        CertificateChain(
1896            self.entries
1897                .into_iter()
1898                .map(|e| e.cert)
1899                .collect(),
1900        )
1901    }
1902}
1903
/// Describes the supported key exchange mechanisms.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum KeyExchangeAlgorithm {
    /// Diffie-Hellman key exchange (with only known parameters as defined in [RFC 7919]).
    ///
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
    DHE,
    /// Key exchange performed via elliptic curve Diffie-Hellman.
    ECDHE,
}

/// Both `KeyExchangeAlgorithm` variants, with ECDHE listed first.
pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] =
    &[KeyExchangeAlgorithm::ECDHE, KeyExchangeAlgorithm::DHE];
1918
1919// We don't support arbitrary curves.  It's a terrible
1920// idea and unnecessary attack surface.  Please,
1921// get a grip.
1922#[derive(Debug)]
1923pub(crate) struct EcParameters {
1924    pub(crate) curve_type: ECCurveType,
1925    pub(crate) named_group: NamedGroup,
1926}
1927
1928impl Codec<'_> for EcParameters {
1929    fn encode(&self, bytes: &mut Vec<u8>) {
1930        self.curve_type.encode(bytes);
1931        self.named_group.encode(bytes);
1932    }
1933
1934    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1935        let ct = ECCurveType::read(r)?;
1936        if ct != ECCurveType::NamedCurve {
1937            return Err(InvalidMessage::UnsupportedCurveType);
1938        }
1939
1940        let grp = NamedGroup::read(r)?;
1941
1942        Ok(Self {
1943            curve_type: ct,
1944            named_group: grp,
1945        })
1946    }
1947}
1948
/// Decoding for key exchange messages whose layout depends on a
/// caller-supplied [`KeyExchangeAlgorithm`]. `Codec::read` has no way to
/// receive that algorithm.
#[cfg(feature = "tls12")]
pub(crate) trait KxDecode<'a>: fmt::Debug + Sized {
    /// Decodes a key exchange message from `r`, interpreting it according to `algo`.
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
}
1954
/// The key exchange parameters sent by a client, for either ECDHE or DHE.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) enum ClientKeyExchangeParams {
    Ecdh(ClientEcdhParams),
    Dh(ClientDhParams),
}

#[cfg(feature = "tls12")]
impl ClientKeyExchangeParams {
    /// Returns the client's public key share, whichever variant this is.
    pub(crate) fn pub_key(&self) -> &[u8] {
        match self {
            Self::Dh(dh) => dh.public.0.as_slice(),
            Self::Ecdh(ecdh) => ecdh.public.0.as_slice(),
        }
    }

    /// Encodes the inner parameters only.
    ///
    /// No tag identifying the variant is written, so a reader must already
    /// know the algorithm (see `KxDecode`).
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Dh(dh) => dh.encode(buf),
            Self::Ecdh(ecdh) => ecdh.encode(buf),
        }
    }
}

#[cfg(feature = "tls12")]
impl KxDecode<'_> for ClientKeyExchangeParams {
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        let params = match algo {
            KeyExchangeAlgorithm::DHE => Self::Dh(ClientDhParams::read(r)?),
            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ClientEcdhParams::read(r)?),
        };
        Ok(params)
    }
}
1989
/// A client's ECDHE public point.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientEcdhParams {
    /// RFC4492: `opaque point <1..2^8-1>;`
    pub(crate) public: PayloadU8<NonEmpty>,
}

#[cfg(feature = "tls12")]
impl Codec<'_> for ClientEcdhParams {
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads the point. The `NonEmpty` marker on the payload type is what
    /// rejects an empty value.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        Ok(Self {
            public: PayloadU8::read(r)?,
        })
    }
}

/// A client's DHE public value.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientDhParams {
    /// RFC5246: `opaque dh_Yc<1..2^16-1>;`
    pub(crate) public: PayloadU16<NonEmpty>,
}

#[cfg(feature = "tls12")]
impl Codec<'_> for ClientDhParams {
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads the public value. The `NonEmpty` marker on the payload type is
    /// what rejects an empty value.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let public = PayloadU16::read(r)?;
        Ok(Self { public })
    }
}
2028
2029#[derive(Debug)]
2030pub(crate) struct ServerEcdhParams {
2031    pub(crate) curve_params: EcParameters,
2032    /// RFC4492: `opaque point <1..2^8-1>;`
2033    pub(crate) public: PayloadU8<NonEmpty>,
2034}
2035
2036impl ServerEcdhParams {
2037    #[cfg(feature = "tls12")]
2038    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2039        Self {
2040            curve_params: EcParameters {
2041                curve_type: ECCurveType::NamedCurve,
2042                named_group: kx.group(),
2043            },
2044            public: PayloadU8::new(kx.pub_key().to_vec()),
2045        }
2046    }
2047}
2048
2049impl Codec<'_> for ServerEcdhParams {
2050    fn encode(&self, bytes: &mut Vec<u8>) {
2051        self.curve_params.encode(bytes);
2052        self.public.encode(bytes);
2053    }
2054
2055    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2056        let cp = EcParameters::read(r)?;
2057        let pb = PayloadU8::read(r)?;
2058
2059        Ok(Self {
2060            curve_params: cp,
2061            public: pb,
2062        })
2063    }
2064}
2065
2066#[derive(Debug)]
2067#[allow(non_snake_case)]
2068pub(crate) struct ServerDhParams {
2069    /// RFC5246: `opaque dh_p<1..2^16-1>;`
2070    pub(crate) dh_p: PayloadU16<NonEmpty>,
2071    /// RFC5246: `opaque dh_g<1..2^16-1>;`
2072    pub(crate) dh_g: PayloadU16<NonEmpty>,
2073    /// RFC5246: `opaque dh_Ys<1..2^16-1>;`
2074    pub(crate) dh_Ys: PayloadU16<NonEmpty>,
2075}
2076
2077impl ServerDhParams {
2078    #[cfg(feature = "tls12")]
2079    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2080        let Some(params) = kx.ffdhe_group() else {
2081            panic!("invalid NamedGroup for DHE key exchange: {:?}", kx.group());
2082        };
2083
2084        Self {
2085            dh_p: PayloadU16::new(params.p.to_vec()),
2086            dh_g: PayloadU16::new(params.g.to_vec()),
2087            dh_Ys: PayloadU16::new(kx.pub_key().to_vec()),
2088        }
2089    }
2090
2091    #[cfg(feature = "tls12")]
2092    pub(crate) fn as_ffdhe_group(&self) -> FfdheGroup<'_> {
2093        FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0)
2094    }
2095}
2096
2097impl Codec<'_> for ServerDhParams {
2098    fn encode(&self, bytes: &mut Vec<u8>) {
2099        self.dh_p.encode(bytes);
2100        self.dh_g.encode(bytes);
2101        self.dh_Ys.encode(bytes);
2102    }
2103
2104    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2105        Ok(Self {
2106            dh_p: PayloadU16::read(r)?,
2107            dh_g: PayloadU16::read(r)?,
2108            dh_Ys: PayloadU16::read(r)?,
2109        })
2110    }
2111}
2112
2113#[allow(dead_code)]
2114#[derive(Debug)]
2115pub(crate) enum ServerKeyExchangeParams {
2116    Ecdh(ServerEcdhParams),
2117    Dh(ServerDhParams),
2118}
2119
2120impl ServerKeyExchangeParams {
2121    #[cfg(feature = "tls12")]
2122    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2123        match kx.group().key_exchange_algorithm() {
2124            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)),
2125            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)),
2126        }
2127    }
2128
2129    #[cfg(feature = "tls12")]
2130    pub(crate) fn pub_key(&self) -> &[u8] {
2131        match self {
2132            Self::Ecdh(ecdh) => &ecdh.public.0,
2133            Self::Dh(dh) => &dh.dh_Ys.0,
2134        }
2135    }
2136
2137    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
2138        match self {
2139            Self::Ecdh(ecdh) => ecdh.encode(buf),
2140            Self::Dh(dh) => dh.encode(buf),
2141        }
2142    }
2143}
2144
2145#[cfg(feature = "tls12")]
2146impl KxDecode<'_> for ServerKeyExchangeParams {
2147    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
2148        use KeyExchangeAlgorithm::*;
2149        Ok(match algo {
2150            ECDHE => Self::Ecdh(ServerEcdhParams::read(r)?),
2151            DHE => Self::Dh(ServerDhParams::read(r)?),
2152        })
2153    }
2154}
2155
2156#[derive(Debug)]
2157pub(crate) struct ServerKeyExchange {
2158    pub(crate) params: ServerKeyExchangeParams,
2159    pub(crate) dss: DigitallySignedStruct,
2160}
2161
2162impl ServerKeyExchange {
2163    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
2164        self.params.encode(buf);
2165        self.dss.encode(buf);
2166    }
2167}
2168
2169#[derive(Debug)]
2170pub(crate) enum ServerKeyExchangePayload {
2171    Known(ServerKeyExchange),
2172    Unknown(Payload<'static>),
2173}
2174
2175impl From<ServerKeyExchange> for ServerKeyExchangePayload {
2176    fn from(value: ServerKeyExchange) -> Self {
2177        Self::Known(value)
2178    }
2179}
2180
2181impl Codec<'_> for ServerKeyExchangePayload {
2182    fn encode(&self, bytes: &mut Vec<u8>) {
2183        match self {
2184            Self::Known(x) => x.encode(bytes),
2185            Self::Unknown(x) => x.encode(bytes),
2186        }
2187    }
2188
2189    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2190        // read as Unknown, fully parse when we know the
2191        // KeyExchangeAlgorithm
2192        Ok(Self::Unknown(Payload::read(r).into_owned()))
2193    }
2194}
2195
2196impl ServerKeyExchangePayload {
2197    #[cfg(feature = "tls12")]
2198    pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> {
2199        if let Self::Unknown(unk) = self {
2200            let mut rd = Reader::init(unk.bytes());
2201
2202            let result = ServerKeyExchange {
2203                params: ServerKeyExchangeParams::decode(&mut rd, kxa).ok()?,
2204                dss: DigitallySignedStruct::read(&mut rd).ok()?,
2205            };
2206
2207            if !rd.any_left() {
2208                return Some(result);
2209            };
2210        }
2211
2212        None
2213    }
2214}
2215
2216// -- EncryptedExtensions (TLS1.3 only) --
2217
2218impl TlsListElement for ServerExtension {
2219    const SIZE_LEN: ListLength = ListLength::U16;
2220}
2221
2222pub(crate) trait HasServerExtensions {
2223    fn extensions(&self) -> &[ServerExtension];
2224
2225    /// Returns true if there is more than one extension of a given
2226    /// type.
2227    fn has_duplicate_extension(&self) -> bool {
2228        has_duplicates::<_, _, u16>(
2229            self.extensions()
2230                .iter()
2231                .map(|ext| ext.ext_type()),
2232        )
2233    }
2234
2235    fn find_extension(&self, ext: ExtensionType) -> Option<&ServerExtension> {
2236        self.extensions()
2237            .iter()
2238            .find(|x| x.ext_type() == ext)
2239    }
2240
2241    fn alpn_protocol(&self) -> Option<&ProtocolName> {
2242        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
2243        match ext {
2244            ServerExtension::Protocols(protos) => Some(&protos.0),
2245            _ => None,
2246        }
2247    }
2248
2249    fn server_cert_type(&self) -> Option<&CertificateType> {
2250        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
2251        match ext {
2252            ServerExtension::ServerCertType(req) => Some(req),
2253            _ => None,
2254        }
2255    }
2256
2257    fn client_cert_type(&self) -> Option<&CertificateType> {
2258        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
2259        match ext {
2260            ServerExtension::ClientCertType(req) => Some(req),
2261            _ => None,
2262        }
2263    }
2264
2265    fn quic_params_extension(&self) -> Option<Vec<u8>> {
2266        let ext = self
2267            .find_extension(ExtensionType::TransportParameters)
2268            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
2269        match ext {
2270            ServerExtension::TransportParameters(bytes)
2271            | ServerExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()),
2272            _ => None,
2273        }
2274    }
2275
2276    fn server_ech_extension(&self) -> Option<ServerEncryptedClientHello> {
2277        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
2278        match ext {
2279            ServerExtension::EncryptedClientHello(ech) => Some(ech.clone()),
2280            _ => None,
2281        }
2282    }
2283
2284    fn early_data_extension_offered(&self) -> bool {
2285        self.find_extension(ExtensionType::EarlyData)
2286            .is_some()
2287    }
2288}
2289
2290impl HasServerExtensions for Vec<ServerExtension> {
2291    fn extensions(&self) -> &[ServerExtension] {
2292        self
2293    }
2294}
2295
2296/// RFC5246: `ClientCertificateType certificate_types<1..2^8-1>;`
2297impl TlsListElement for ClientCertificateType {
2298    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
2299        empty_error: InvalidMessage::IllegalEmptyList("ClientCertificateTypes"),
2300    };
2301}
2302
2303wrapped_payload!(
2304    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
2305    ///
2306    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
2307    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
2308    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2309    ///
2310    /// ```ignore
2311    /// for name in distinguished_names {
2312    ///     use x509_parser::prelude::FromDer;
2313    ///     println!("{}", x509_parser::x509::X509Name::from_der(&name.0)?.1);
2314    /// }
2315    /// ```
2316    ///
2317    /// The TLS encoding is defined in RFC5246: `opaque DistinguishedName<1..2^16-1>;`
2318    pub struct DistinguishedName,
2319    PayloadU16<NonEmpty>,
2320);
2321
2322impl DistinguishedName {
2323    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
2324    ///
2325    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2326    ///
2327    /// ```ignore
2328    /// use x509_parser::prelude::FromDer;
2329    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
2330    /// ```
2331    pub fn in_sequence(bytes: &[u8]) -> Self {
2332        Self(PayloadU16::new(wrap_in_sequence(bytes)))
2333    }
2334}
2335
2336/// RFC8446: `DistinguishedName authorities<3..2^16-1>;` however,
2337/// RFC5246: `DistinguishedName certificate_authorities<0..2^16-1>;`
2338impl TlsListElement for DistinguishedName {
2339    const SIZE_LEN: ListLength = ListLength::U16;
2340}
2341
2342#[derive(Debug)]
2343pub(crate) struct CertificateRequestPayload {
2344    pub(crate) certtypes: Vec<ClientCertificateType>,
2345    pub(crate) sigschemes: Vec<SignatureScheme>,
2346    pub(crate) canames: Vec<DistinguishedName>,
2347}
2348
2349impl Codec<'_> for CertificateRequestPayload {
2350    fn encode(&self, bytes: &mut Vec<u8>) {
2351        self.certtypes.encode(bytes);
2352        self.sigschemes.encode(bytes);
2353        self.canames.encode(bytes);
2354    }
2355
2356    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2357        let certtypes = Vec::read(r)?;
2358        let sigschemes = Vec::read(r)?;
2359        let canames = Vec::read(r)?;
2360
2361        if sigschemes.is_empty() {
2362            warn!("meaningless CertificateRequest message");
2363            Err(InvalidMessage::NoSignatureSchemes)
2364        } else {
2365            Ok(Self {
2366                certtypes,
2367                sigschemes,
2368                canames,
2369            })
2370        }
2371    }
2372}
2373
2374#[derive(Debug)]
2375pub(crate) enum CertReqExtension {
2376    SignatureAlgorithms(Vec<SignatureScheme>),
2377    AuthorityNames(Vec<DistinguishedName>),
2378    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
2379    Unknown(UnknownExtension),
2380}
2381
2382impl CertReqExtension {
2383    pub(crate) fn ext_type(&self) -> ExtensionType {
2384        match self {
2385            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
2386            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
2387            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
2388            Self::Unknown(r) => r.typ,
2389        }
2390    }
2391}
2392
2393impl Codec<'_> for CertReqExtension {
2394    fn encode(&self, bytes: &mut Vec<u8>) {
2395        self.ext_type().encode(bytes);
2396
2397        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2398        match self {
2399            Self::SignatureAlgorithms(r) => r.encode(nested.buf),
2400            Self::AuthorityNames(r) => r.encode(nested.buf),
2401            Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf),
2402            Self::Unknown(r) => r.encode(nested.buf),
2403        }
2404    }
2405
2406    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2407        let typ = ExtensionType::read(r)?;
2408        let len = u16::read(r)? as usize;
2409        let mut sub = r.sub(len)?;
2410
2411        let ext = match typ {
2412            ExtensionType::SignatureAlgorithms => {
2413                let schemes = Vec::read(&mut sub)?;
2414                if schemes.is_empty() {
2415                    return Err(InvalidMessage::NoSignatureSchemes);
2416                }
2417                Self::SignatureAlgorithms(schemes)
2418            }
2419            ExtensionType::CertificateAuthorities => {
2420                let cas = Vec::read(&mut sub)?;
2421                if cas.is_empty() {
2422                    return Err(InvalidMessage::IllegalEmptyList("DistinguishedNames"));
2423                }
2424                Self::AuthorityNames(cas)
2425            }
2426            ExtensionType::CompressCertificate => {
2427                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
2428            }
2429            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2430        };
2431
2432        sub.expect_empty("CertReqExtension")
2433            .map(|_| ext)
2434    }
2435}
2436
2437impl TlsListElement for CertReqExtension {
2438    const SIZE_LEN: ListLength = ListLength::U16;
2439}
2440
2441#[derive(Debug)]
2442pub(crate) struct CertificateRequestPayloadTls13 {
2443    pub(crate) context: PayloadU8,
2444    pub(crate) extensions: Vec<CertReqExtension>,
2445}
2446
2447impl Codec<'_> for CertificateRequestPayloadTls13 {
2448    fn encode(&self, bytes: &mut Vec<u8>) {
2449        self.context.encode(bytes);
2450        self.extensions.encode(bytes);
2451    }
2452
2453    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2454        let context = PayloadU8::read(r)?;
2455        let extensions = Vec::read(r)?;
2456
2457        Ok(Self {
2458            context,
2459            extensions,
2460        })
2461    }
2462}
2463
2464impl CertificateRequestPayloadTls13 {
2465    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&CertReqExtension> {
2466        self.extensions
2467            .iter()
2468            .find(|x| x.ext_type() == ext)
2469    }
2470
2471    pub(crate) fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
2472        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
2473        match ext {
2474            CertReqExtension::SignatureAlgorithms(sa) => Some(sa),
2475            _ => None,
2476        }
2477    }
2478
2479    pub(crate) fn authorities_extension(&self) -> Option<&[DistinguishedName]> {
2480        let ext = self.find_extension(ExtensionType::CertificateAuthorities)?;
2481        match ext {
2482            CertReqExtension::AuthorityNames(an) => Some(an),
2483            _ => None,
2484        }
2485    }
2486
2487    pub(crate) fn certificate_compression_extension(
2488        &self,
2489    ) -> Option<&[CertificateCompressionAlgorithm]> {
2490        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
2491        match ext {
2492            CertReqExtension::CertificateCompressionAlgorithms(comps) => Some(comps),
2493            _ => None,
2494        }
2495    }
2496}
2497
2498// -- NewSessionTicket --
2499#[derive(Debug)]
2500pub(crate) struct NewSessionTicketPayload {
2501    pub(crate) lifetime_hint: u32,
2502    // Tickets can be large (KB), so we deserialise this straight
2503    // into an Arc, so it can be passed directly into the client's
2504    // session object without copying.
2505    pub(crate) ticket: Arc<PayloadU16>,
2506}
2507
2508impl NewSessionTicketPayload {
2509    #[cfg(feature = "tls12")]
2510    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
2511        Self {
2512            lifetime_hint,
2513            ticket: Arc::new(PayloadU16::new(ticket)),
2514        }
2515    }
2516}
2517
2518impl Codec<'_> for NewSessionTicketPayload {
2519    fn encode(&self, bytes: &mut Vec<u8>) {
2520        self.lifetime_hint.encode(bytes);
2521        self.ticket.encode(bytes);
2522    }
2523
2524    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2525        let lifetime = u32::read(r)?;
2526        let ticket = Arc::new(PayloadU16::read(r)?);
2527
2528        Ok(Self {
2529            lifetime_hint: lifetime,
2530            ticket,
2531        })
2532    }
2533}
2534
// -- NewSessionTicket extensions and payload (TLS1.3) --
2536#[derive(Debug)]
2537pub(crate) enum NewSessionTicketExtension {
2538    EarlyData(u32),
2539    Unknown(UnknownExtension),
2540}
2541
2542impl NewSessionTicketExtension {
2543    pub(crate) fn ext_type(&self) -> ExtensionType {
2544        match self {
2545            Self::EarlyData(_) => ExtensionType::EarlyData,
2546            Self::Unknown(r) => r.typ,
2547        }
2548    }
2549}
2550
2551impl Codec<'_> for NewSessionTicketExtension {
2552    fn encode(&self, bytes: &mut Vec<u8>) {
2553        self.ext_type().encode(bytes);
2554
2555        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2556        match self {
2557            Self::EarlyData(r) => r.encode(nested.buf),
2558            Self::Unknown(r) => r.encode(nested.buf),
2559        }
2560    }
2561
2562    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2563        let typ = ExtensionType::read(r)?;
2564        let len = u16::read(r)? as usize;
2565        let mut sub = r.sub(len)?;
2566
2567        let ext = match typ {
2568            ExtensionType::EarlyData => Self::EarlyData(u32::read(&mut sub)?),
2569            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2570        };
2571
2572        sub.expect_empty("NewSessionTicketExtension")
2573            .map(|_| ext)
2574    }
2575}
2576
2577impl TlsListElement for NewSessionTicketExtension {
2578    const SIZE_LEN: ListLength = ListLength::U16;
2579}
2580
2581#[derive(Debug)]
2582pub(crate) struct NewSessionTicketPayloadTls13 {
2583    pub(crate) lifetime: u32,
2584    pub(crate) age_add: u32,
2585    pub(crate) nonce: PayloadU8,
2586    pub(crate) ticket: Arc<PayloadU16>,
2587    pub(crate) exts: Vec<NewSessionTicketExtension>,
2588}
2589
2590impl NewSessionTicketPayloadTls13 {
2591    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
2592        Self {
2593            lifetime,
2594            age_add,
2595            nonce: PayloadU8::new(nonce),
2596            ticket: Arc::new(PayloadU16::new(ticket)),
2597            exts: vec![],
2598        }
2599    }
2600
2601    pub(crate) fn has_duplicate_extension(&self) -> bool {
2602        has_duplicates::<_, _, u16>(
2603            self.exts
2604                .iter()
2605                .map(|ext| ext.ext_type()),
2606        )
2607    }
2608
2609    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&NewSessionTicketExtension> {
2610        self.exts
2611            .iter()
2612            .find(|x| x.ext_type() == ext)
2613    }
2614
2615    pub(crate) fn max_early_data_size(&self) -> Option<u32> {
2616        let ext = self.find_extension(ExtensionType::EarlyData)?;
2617        match ext {
2618            NewSessionTicketExtension::EarlyData(sz) => Some(*sz),
2619            _ => None,
2620        }
2621    }
2622}
2623
2624impl Codec<'_> for NewSessionTicketPayloadTls13 {
2625    fn encode(&self, bytes: &mut Vec<u8>) {
2626        self.lifetime.encode(bytes);
2627        self.age_add.encode(bytes);
2628        self.nonce.encode(bytes);
2629        self.ticket.encode(bytes);
2630        self.exts.encode(bytes);
2631    }
2632
2633    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2634        let lifetime = u32::read(r)?;
2635        let age_add = u32::read(r)?;
2636        let nonce = PayloadU8::read(r)?;
2637        // nb. RFC8446: `opaque ticket<1..2^16-1>;`
2638        let ticket = Arc::new(match PayloadU16::<NonEmpty>::read(r) {
2639            Err(InvalidMessage::IllegalEmptyValue) => Err(InvalidMessage::EmptyTicketValue),
2640            Err(err) => Err(err),
2641            Ok(pl) => Ok(PayloadU16::new(pl.0)),
2642        }?);
2643        let exts = Vec::read(r)?;
2644
2645        Ok(Self {
2646            lifetime,
2647            age_add,
2648            nonce,
2649            ticket,
2650            exts,
2651        })
2652    }
2653}
2654
2655// -- RFC6066 certificate status types
2656
2657/// Only supports OCSP
2658#[derive(Debug)]
2659pub(crate) struct CertificateStatus<'a> {
2660    pub(crate) ocsp_response: PayloadU24<'a>,
2661}
2662
2663impl<'a> Codec<'a> for CertificateStatus<'a> {
2664    fn encode(&self, bytes: &mut Vec<u8>) {
2665        CertificateStatusType::OCSP.encode(bytes);
2666        self.ocsp_response.encode(bytes);
2667    }
2668
2669    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2670        let typ = CertificateStatusType::read(r)?;
2671
2672        match typ {
2673            CertificateStatusType::OCSP => Ok(Self {
2674                ocsp_response: PayloadU24::read(r)?,
2675            }),
2676            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2677        }
2678    }
2679}
2680
2681impl<'a> CertificateStatus<'a> {
2682    pub(crate) fn new(ocsp: &'a [u8]) -> Self {
2683        CertificateStatus {
2684            ocsp_response: PayloadU24(Payload::Borrowed(ocsp)),
2685        }
2686    }
2687
2688    #[cfg(feature = "tls12")]
2689    pub(crate) fn into_inner(self) -> Vec<u8> {
2690        self.ocsp_response.0.into_vec()
2691    }
2692
2693    pub(crate) fn into_owned(self) -> CertificateStatus<'static> {
2694        CertificateStatus {
2695            ocsp_response: self.ocsp_response.into_owned(),
2696        }
2697    }
2698}
2699
2700// -- RFC8879 compressed certificates
2701
/// Body of a `CompressedCertificate` handshake message (RFC 8879).
#[derive(Debug)]
pub(crate) struct CompressedCertificatePayload<'a> {
    /// Algorithm the certificate message was compressed with.
    pub(crate) alg: CertificateCompressionAlgorithm,
    /// Declared length of the decompressed data; held as a `u32` but read and
    /// written on the wire as a 24-bit integer (`codec::u24`).
    pub(crate) uncompressed_len: u32,
    /// The compressed bytes.
    pub(crate) compressed: PayloadU24<'a>,
}
2708
2709impl<'a> Codec<'a> for CompressedCertificatePayload<'a> {
2710    fn encode(&self, bytes: &mut Vec<u8>) {
2711        self.alg.encode(bytes);
2712        codec::u24(self.uncompressed_len).encode(bytes);
2713        self.compressed.encode(bytes);
2714    }
2715
2716    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2717        Ok(Self {
2718            alg: CertificateCompressionAlgorithm::read(r)?,
2719            uncompressed_len: codec::u24::read(r)?.0,
2720            compressed: PayloadU24::read(r)?,
2721        })
2722    }
2723}
2724
2725impl CompressedCertificatePayload<'_> {
2726    fn into_owned(self) -> CompressedCertificatePayload<'static> {
2727        CompressedCertificatePayload {
2728            compressed: self.compressed.into_owned(),
2729            ..self
2730        }
2731    }
2732
2733    pub(crate) fn as_borrowed(&self) -> CompressedCertificatePayload<'_> {
2734        CompressedCertificatePayload {
2735            alg: self.alg,
2736            uncompressed_len: self.uncompressed_len,
2737            compressed: PayloadU24(Payload::Borrowed(self.compressed.0.bytes())),
2738        }
2739    }
2740}
2741
/// The body of a handshake message, decoded according to its `HandshakeType`
/// (and, for some types, the protocol version; see
/// `HandshakeMessagePayload::read_version`).
#[derive(Debug)]
pub(crate) enum HandshakePayload<'a> {
    /// Empty-bodied `HelloRequest`.
    HelloRequest,
    ClientHello(ClientHelloPayload),
    ServerHello(ServerHelloPayload),
    /// Travels on the wire as a `ServerHello` whose random equals
    /// `HELLO_RETRY_REQUEST_RANDOM`.
    HelloRetryRequest(HelloRetryRequest),
    /// `Certificate` body as parsed for any version other than TLS 1.3.
    Certificate(CertificateChain<'a>),
    /// `Certificate` body as parsed for TLS 1.3.
    CertificateTls13(CertificatePayloadTls13<'a>),
    /// RFC 8879 compressed certificate.
    CompressedCertificate(CompressedCertificatePayload<'a>),
    ServerKeyExchange(ServerKeyExchangePayload),
    /// `CertificateRequest` body as parsed for any version other than TLS 1.3.
    CertificateRequest(CertificateRequestPayload),
    /// `CertificateRequest` body as parsed for TLS 1.3.
    CertificateRequestTls13(CertificateRequestPayloadTls13),
    CertificateVerify(DigitallySignedStruct),
    /// Empty-bodied `ServerHelloDone`.
    ServerHelloDone,
    /// Empty-bodied `EndOfEarlyData`.
    EndOfEarlyData,
    /// Kept as opaque bytes.
    ClientKeyExchange(Payload<'a>),
    /// `NewSessionTicket` body as parsed for any version other than TLS 1.3.
    NewSessionTicket(NewSessionTicketPayload),
    /// `NewSessionTicket` body as parsed for TLS 1.3.
    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
    EncryptedExtensions(Vec<ServerExtension>),
    KeyUpdate(KeyUpdateRequest),
    /// Kept as opaque bytes.
    Finished(Payload<'a>),
    CertificateStatus(CertificateStatus<'a>),
    /// Synthetic message built locally; never accepted from the wire.
    MessageHash(Payload<'a>),
    /// A handshake type with no dedicated variant, together with its raw body.
    Unknown((HandshakeType, Payload<'a>)),
}
2767
impl HandshakePayload<'_> {
    /// Appends only the message body to `bytes`.
    ///
    /// The handshake type and u24 length prefix are written separately by
    /// `HandshakeMessagePayload::payload_encode`.
    fn encode(&self, bytes: &mut Vec<u8>) {
        use self::HandshakePayload::*;
        match self {
            // These messages have empty bodies.
            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
            ClientHello(x) => x.encode(bytes),
            ServerHello(x) => x.encode(bytes),
            HelloRetryRequest(x) => x.encode(bytes),
            Certificate(x) => x.encode(bytes),
            CertificateTls13(x) => x.encode(bytes),
            CompressedCertificate(x) => x.encode(bytes),
            ServerKeyExchange(x) => x.encode(bytes),
            ClientKeyExchange(x) => x.encode(bytes),
            CertificateRequest(x) => x.encode(bytes),
            CertificateRequestTls13(x) => x.encode(bytes),
            CertificateVerify(x) => x.encode(bytes),
            NewSessionTicket(x) => x.encode(bytes),
            NewSessionTicketTls13(x) => x.encode(bytes),
            EncryptedExtensions(x) => x.encode(bytes),
            KeyUpdate(x) => x.encode(bytes),
            Finished(x) => x.encode(bytes),
            CertificateStatus(x) => x.encode(bytes),
            MessageHash(x) => x.encode(bytes),
            Unknown((_, x)) => x.encode(bytes),
        }
    }

    /// Returns the logical handshake type of this body.
    ///
    /// Version-specific variants (e.g. `Certificate` and `CertificateTls13`) map to
    /// the same type. `HelloRetryRequest` is reported as its own type; see
    /// `wire_handshake_type` for the value actually written on the wire.
    pub(crate) fn handshake_type(&self) -> HandshakeType {
        use self::HandshakePayload::*;
        match self {
            HelloRequest => HandshakeType::HelloRequest,
            ClientHello(_) => HandshakeType::ClientHello,
            ServerHello(_) => HandshakeType::ServerHello,
            HelloRetryRequest(_) => HandshakeType::HelloRetryRequest,
            Certificate(_) | CertificateTls13(_) => HandshakeType::Certificate,
            CompressedCertificate(_) => HandshakeType::CompressedCertificate,
            ServerKeyExchange(_) => HandshakeType::ServerKeyExchange,
            CertificateRequest(_) | CertificateRequestTls13(_) => HandshakeType::CertificateRequest,
            CertificateVerify(_) => HandshakeType::CertificateVerify,
            ServerHelloDone => HandshakeType::ServerHelloDone,
            EndOfEarlyData => HandshakeType::EndOfEarlyData,
            ClientKeyExchange(_) => HandshakeType::ClientKeyExchange,
            NewSessionTicket(_) | NewSessionTicketTls13(_) => HandshakeType::NewSessionTicket,
            EncryptedExtensions(_) => HandshakeType::EncryptedExtensions,
            KeyUpdate(_) => HandshakeType::KeyUpdate,
            Finished(_) => HandshakeType::Finished,
            CertificateStatus(_) => HandshakeType::CertificateStatus,
            MessageHash(_) => HandshakeType::MessageHash,
            Unknown((t, _)) => *t,
        }
    }

    /// Returns the handshake type written on the wire for this body.
    fn wire_handshake_type(&self) -> HandshakeType {
        match self.handshake_type() {
            // A `HelloRetryRequest` appears on the wire as a `ServerHello` with a magic `random` value.
            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
            other => other,
        }
    }

    /// Converts into a `'static` payload.
    ///
    /// Variants that carry a lifetime have `into_owned` called on their contents;
    /// all other variants are moved through unchanged.
    fn into_owned(self) -> HandshakePayload<'static> {
        use HandshakePayload::*;

        match self {
            HelloRequest => HelloRequest,
            ClientHello(x) => ClientHello(x),
            ServerHello(x) => ServerHello(x),
            HelloRetryRequest(x) => HelloRetryRequest(x),
            Certificate(x) => Certificate(x.into_owned()),
            CertificateTls13(x) => CertificateTls13(x.into_owned()),
            CompressedCertificate(x) => CompressedCertificate(x.into_owned()),
            ServerKeyExchange(x) => ServerKeyExchange(x),
            CertificateRequest(x) => CertificateRequest(x),
            CertificateRequestTls13(x) => CertificateRequestTls13(x),
            CertificateVerify(x) => CertificateVerify(x),
            ServerHelloDone => ServerHelloDone,
            EndOfEarlyData => EndOfEarlyData,
            ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()),
            NewSessionTicket(x) => NewSessionTicket(x),
            NewSessionTicketTls13(x) => NewSessionTicketTls13(x),
            EncryptedExtensions(x) => EncryptedExtensions(x),
            KeyUpdate(x) => KeyUpdate(x),
            Finished(x) => Finished(x.into_owned()),
            CertificateStatus(x) => CertificateStatus(x.into_owned()),
            MessageHash(x) => MessageHash(x.into_owned()),
            Unknown((t, x)) => Unknown((t, x.into_owned())),
        }
    }
}
2857
/// A complete handshake message.
///
/// On the wire this is the handshake type, a u24 body length, then the body held in
/// the wrapped `HandshakePayload`.
#[derive(Debug)]
pub struct HandshakeMessagePayload<'a>(pub(crate) HandshakePayload<'a>);
2860
2861impl<'a> Codec<'a> for HandshakeMessagePayload<'a> {
2862    fn encode(&self, bytes: &mut Vec<u8>) {
2863        self.payload_encode(bytes, Encoding::Standard);
2864    }
2865
2866    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2867        Self::read_version(r, ProtocolVersion::TLSv1_2)
2868    }
2869}
2870
impl<'a> HandshakeMessagePayload<'a> {
    /// Reads one handshake message from `r`.
    ///
    /// The message is a `HandshakeType`, a u24 body length, then the body. For
    /// `Certificate`, `CertificateRequest` and `NewSessionTicket`, the TLS 1.3 body
    /// layout is used only when `vers` is `TLSv1_3`; otherwise the other layout is used.
    ///
    /// # Errors
    ///
    /// Fails if the body cannot be decoded for its type, if body bytes remain after
    /// decoding, or if the type is `MessageHash` or `HelloRetryRequest`. Neither of
    /// those two may appear on the wire under its own type.
    pub(crate) fn read_version(
        r: &mut Reader<'a>,
        vers: ProtocolVersion,
    ) -> Result<Self, InvalidMessage> {
        let typ = HandshakeType::read(r)?;
        let len = codec::u24::read(r)?.0 as usize;
        // All body parsing happens within this `len`-bounded sub-reader.
        let mut sub = r.sub(len)?;

        let payload = match typ {
            // A HelloRequest with a non-empty body does not match this guarded arm and
            // falls through to the `Unknown` arm at the end.
            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
            HandshakeType::ClientHello => {
                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
            }
            HandshakeType::ServerHello => {
                // ServerHello and HelloRetryRequest share this prefix; the `random`
                // value decides which of the two the rest of the body is parsed as.
                let version = ProtocolVersion::read(&mut sub)?;
                let random = Random::read(&mut sub)?;

                if random == HELLO_RETRY_REQUEST_RANDOM {
                    // Only `legacy_version` is filled in from the shared prefix here.
                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
                    hrr.legacy_version = version;
                    HandshakePayload::HelloRetryRequest(hrr)
                } else {
                    let mut shp = ServerHelloPayload::read(&mut sub)?;
                    shp.legacy_version = version;
                    shp.random = random;
                    HandshakePayload::ServerHello(shp)
                }
            }
            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificatePayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateTls13(p)
            }
            HandshakeType::Certificate => {
                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
            }
            HandshakeType::ServerKeyExchange => {
                let p = ServerKeyExchangePayload::read(&mut sub)?;
                HandshakePayload::ServerKeyExchange(p)
            }
            HandshakeType::ServerHelloDone => {
                // Checked here (as well as below) to report a message-specific error.
                sub.expect_empty("ServerHelloDone")?;
                HandshakePayload::ServerHelloDone
            }
            HandshakeType::ClientKeyExchange => {
                // Kept as opaque bytes.
                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
            }
            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateRequestTls13(p)
            }
            HandshakeType::CertificateRequest => {
                let p = CertificateRequestPayload::read(&mut sub)?;
                HandshakePayload::CertificateRequest(p)
            }
            HandshakeType::CompressedCertificate => HandshakePayload::CompressedCertificate(
                CompressedCertificatePayload::read(&mut sub)?,
            ),
            HandshakeType::CertificateVerify => {
                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
            }
            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
                HandshakePayload::NewSessionTicketTls13(p)
            }
            HandshakeType::NewSessionTicket => {
                let p = NewSessionTicketPayload::read(&mut sub)?;
                HandshakePayload::NewSessionTicket(p)
            }
            HandshakeType::EncryptedExtensions => {
                HandshakePayload::EncryptedExtensions(Vec::read(&mut sub)?)
            }
            HandshakeType::KeyUpdate => {
                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
            }
            HandshakeType::EndOfEarlyData => {
                sub.expect_empty("EndOfEarlyData")?;
                HandshakePayload::EndOfEarlyData
            }
            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
            HandshakeType::CertificateStatus => {
                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
            }
            HandshakeType::MessageHash => {
                // does not appear on the wire
                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
            }
            HandshakeType::HelloRetryRequest => {
                // not legal on wire
                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
            }
            _ => HandshakePayload::Unknown((typ, Payload::read(&mut sub))),
        };

        // Reject bytes left over after the type-specific parser.
        sub.expect_empty("HandshakeMessagePayload")
            .map(|_| Self(payload))
    }

    /// Returns this message's standard encoding with the trailing PSK binders
    /// (`total_binder_length` bytes) cut off, i.e. the prefix used for binder
    /// computation.
    pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> {
        let mut ret = self.get_encoding();
        let ret_len = ret.len() - self.total_binder_length();
        ret.truncate(ret_len);
        ret
    }

    /// Returns the length of the encoded PSK binders at the end of a ClientHello.
    ///
    /// Returns 0 if this is not a ClientHello, or if its last extension is not a
    /// `PresharedKey` offer. Only the last extension is inspected, because
    /// `encoding_for_binder_signing` truncates from the very end of the encoding.
    pub(crate) fn total_binder_length(&self) -> usize {
        match &self.0 {
            HandshakePayload::ClientHello(ch) => match ch.extensions.last() {
                Some(ClientExtension::PresharedKey(offer)) => {
                    // Measure by re-encoding; whether this includes the list's own
                    // length prefix depends on the binders type's `Codec` impl.
                    let mut binders_encoding = Vec::new();
                    offer
                        .binders
                        .encode(&mut binders_encoding);
                    binders_encoding.len()
                }
                _ => 0,
            },
            _ => 0,
        }
    }

    /// Writes the full handshake message: wire type, u24 length prefix, then body.
    ///
    /// `encoding` only affects `ServerHello` and `HelloRetryRequest` bodies; all other
    /// bodies use their standard encoding.
    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
        // output type, length, and encoded payload
        self.0
            .wire_handshake_type()
            .encode(bytes);

        let nested = LengthPrefixedBuffer::new(
            ListLength::U24 {
                max: usize::MAX,
                error: InvalidMessage::MessageTooLarge,
            },
            bytes,
        );

        match &self.0 {
            // for Server Hello and HelloRetryRequest payloads we need to encode the payload
            // differently based on the purpose of the encoding.
            HandshakePayload::ServerHello(payload) => payload.payload_encode(nested.buf, encoding),
            HandshakePayload::HelloRetryRequest(payload) => {
                payload.payload_encode(nested.buf, encoding)
            }

            // All other payload types are encoded the same regardless of purpose.
            _ => self.0.encode(nested.buf),
        }
    }

    /// Builds a synthetic `MessageHash` message whose body is a copy of `hash`.
    ///
    /// `read_version` refuses this type, so such messages only exist locally.
    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
        Self(HandshakePayload::MessageHash(Payload::new(hash.to_vec())))
    }

    /// Converts into a `'static` message by owning any borrowed body data.
    pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> {
        HandshakeMessagePayload(self.0.into_owned())
    }
}
3027
/// An HPKE symmetric cipher suite: the KDF and AEAD identifier pair.
///
/// Encoded as the KDF identifier followed by the AEAD identifier.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct HpkeSymmetricCipherSuite {
    /// HPKE key derivation function identifier.
    pub kdf_id: HpkeKdf,
    /// HPKE AEAD algorithm identifier.
    pub aead_id: HpkeAead,
}
3033
3034impl Codec<'_> for HpkeSymmetricCipherSuite {
3035    fn encode(&self, bytes: &mut Vec<u8>) {
3036        self.kdf_id.encode(bytes);
3037        self.aead_id.encode(bytes);
3038    }
3039
3040    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3041        Ok(Self {
3042            kdf_id: HpkeKdf::read(r)?,
3043            aead_id: HpkeAead::read(r)?,
3044        })
3045    }
3046}
3047
/// draft-ietf-tls-esni-24: `HpkeSymmetricCipherSuite cipher_suites<4..2^16-4>;`
///
/// The list carries a u16 length prefix and must be non-empty: an empty list fails
/// to decode with `IllegalEmptyList("HpkeSymmetricCipherSuites")`.
impl TlsListElement for HpkeSymmetricCipherSuite {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("HpkeSymmetricCipherSuites"),
    };
}
3054
/// An HPKE public key configuration, as carried in an ECH config's `key_config`.
#[derive(Clone, Debug, PartialEq)]
pub struct HpkeKeyConfig {
    /// Identifier for this configuration (a single byte on the wire).
    pub config_id: u8,
    /// HPKE key encapsulation mechanism the public key is for.
    pub kem_id: HpkeKem,
    /// draft-ietf-tls-esni-24: `opaque HpkePublicKey<1..2^16-1>;`
    pub public_key: PayloadU16<NonEmpty>,
    /// Symmetric cipher suites usable with this key; decoding rejects an empty list.
    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
}
3063
3064impl Codec<'_> for HpkeKeyConfig {
3065    fn encode(&self, bytes: &mut Vec<u8>) {
3066        self.config_id.encode(bytes);
3067        self.kem_id.encode(bytes);
3068        self.public_key.encode(bytes);
3069        self.symmetric_cipher_suites
3070            .encode(bytes);
3071    }
3072
3073    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3074        Ok(Self {
3075            config_id: u8::read(r)?,
3076            kem_id: HpkeKem::read(r)?,
3077            public_key: PayloadU16::read(r)?,
3078            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
3079        })
3080    }
3081}
3082
/// The contents of a V18 ECH configuration.
#[derive(Clone, Debug, PartialEq)]
pub struct EchConfigContents {
    /// HPKE key configuration (public key, KEM and cipher suites).
    pub key_config: HpkeKeyConfig,
    /// Maximum name length advertised by the config (one byte on the wire).
    pub maximum_name_length: u8,
    /// The config's public name, validated as a DNS name when decoded.
    pub public_name: DnsName<'static>,
    /// ECH config extensions; no extension types are currently interpreted.
    pub extensions: Vec<EchConfigExtension>,
}
3090
3091impl EchConfigContents {
3092    /// Returns true if there is more than one extension of a given
3093    /// type.
3094    pub(crate) fn has_duplicate_extension(&self) -> bool {
3095        has_duplicates::<_, _, u16>(
3096            self.extensions
3097                .iter()
3098                .map(|ext| ext.ext_type()),
3099        )
3100    }
3101
3102    /// Returns true if there is at least one mandatory unsupported extension.
3103    pub(crate) fn has_unknown_mandatory_extension(&self) -> bool {
3104        self.extensions
3105            .iter()
3106            // An extension is considered mandatory if the high bit of its type is set.
3107            .any(|ext| {
3108                matches!(ext.ext_type(), ExtensionType::Unknown(_))
3109                    && u16::from(ext.ext_type()) & 0x8000 != 0
3110            })
3111    }
3112}
3113
3114impl Codec<'_> for EchConfigContents {
3115    fn encode(&self, bytes: &mut Vec<u8>) {
3116        self.key_config.encode(bytes);
3117        self.maximum_name_length.encode(bytes);
3118        let dns_name = &self.public_name.borrow();
3119        PayloadU8::<MaybeEmpty>::encode_slice(dns_name.as_ref().as_ref(), bytes);
3120        self.extensions.encode(bytes);
3121    }
3122
3123    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3124        Ok(Self {
3125            key_config: HpkeKeyConfig::read(r)?,
3126            maximum_name_length: u8::read(r)?,
3127            public_name: {
3128                DnsName::try_from(
3129                    PayloadU8::<MaybeEmpty>::read(r)?
3130                        .0
3131                        .as_slice(),
3132                )
3133                .map_err(|_| InvalidMessage::InvalidServerName)?
3134                .to_owned()
3135            },
3136            extensions: Vec::read(r)?,
3137        })
3138    }
3139}
3140
/// An encrypted client hello (ECH) config.
///
/// On the wire this is a version, a u16 length, then a version-specific body.
/// Bodies of unrecognised versions are kept as opaque bytes, so they can be
/// re-encoded exactly as they were read.
#[derive(Clone, Debug, PartialEq)]
pub enum EchConfigPayload {
    /// A recognized V18 ECH configuration.
    V18(EchConfigContents),
    /// An unknown version ECH configuration.
    Unknown {
        /// The version value as read from the wire.
        version: EchVersion,
        /// The opaque, undecoded body.
        contents: PayloadU16,
    },
}
3152
/// Lists of ECH configs carry a u16 length prefix.
impl TlsListElement for EchConfigPayload {
    const SIZE_LEN: ListLength = ListLength::U16;
}
3156
3157impl Codec<'_> for EchConfigPayload {
3158    fn encode(&self, bytes: &mut Vec<u8>) {
3159        match self {
3160            Self::V18(c) => {
3161                // Write the version, the length, and the contents.
3162                EchVersion::V18.encode(bytes);
3163                let inner = LengthPrefixedBuffer::new(ListLength::U16, bytes);
3164                c.encode(inner.buf);
3165            }
3166            Self::Unknown { version, contents } => {
3167                // Unknown configuration versions are opaque.
3168                version.encode(bytes);
3169                contents.encode(bytes);
3170            }
3171        }
3172    }
3173
3174    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3175        let version = EchVersion::read(r)?;
3176        let length = u16::read(r)?;
3177        let mut contents = r.sub(length as usize)?;
3178
3179        Ok(match version {
3180            EchVersion::V18 => Self::V18(EchConfigContents::read(&mut contents)?),
3181            _ => {
3182                // Note: we don't PayloadU16::read() here because we've already read the length prefix.
3183                let data = PayloadU16::new(contents.rest().into());
3184                Self::Unknown {
3185                    version,
3186                    contents: data,
3187                }
3188            }
3189        })
3190    }
3191}
3192
/// An extension inside an ECH config.
///
/// No ECH config extension types are currently recognised, so every extension is
/// decoded as `Unknown`.
#[derive(Clone, Debug, PartialEq)]
pub enum EchConfigExtension {
    Unknown(UnknownExtension),
}
3197
3198impl EchConfigExtension {
3199    pub(crate) fn ext_type(&self) -> ExtensionType {
3200        match self {
3201            Self::Unknown(r) => r.typ,
3202        }
3203    }
3204}
3205
3206impl Codec<'_> for EchConfigExtension {
3207    fn encode(&self, bytes: &mut Vec<u8>) {
3208        self.ext_type().encode(bytes);
3209
3210        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
3211        match self {
3212            Self::Unknown(r) => r.encode(nested.buf),
3213        }
3214    }
3215
3216    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3217        let typ = ExtensionType::read(r)?;
3218        let len = u16::read(r)? as usize;
3219        let mut sub = r.sub(len)?;
3220
3221        #[allow(clippy::match_single_binding)] // Future-proofing.
3222        let ext = match typ {
3223            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
3224        };
3225
3226        sub.expect_empty("EchConfigExtension")
3227            .map(|_| ext)
3228    }
3229}
3230
/// Lists of ECH config extensions carry a u16 length prefix.
impl TlsListElement for EchConfigExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
3234
/// Representation of the `ECHClientHello` client extension specified in
/// [draft-ietf-tls-esni Section 5].
///
/// On the wire this is an `EchClientHelloType` tag. For the outer type the tag is
/// followed by the outer payload; for the inner type nothing follows.
///
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
#[derive(Clone, Debug)]
pub(crate) enum EncryptedClientHello {
    /// An `ECHClientHello` with type [EchClientHelloType::ClientHelloOuter].
    Outer(EncryptedClientHelloOuter),
    /// An empty `ECHClientHello` with type [EchClientHelloType::ClientHelloInner].
    ///
    /// This variant has no payload.
    Inner,
}
3248
3249impl Codec<'_> for EncryptedClientHello {
3250    fn encode(&self, bytes: &mut Vec<u8>) {
3251        match self {
3252            Self::Outer(payload) => {
3253                EchClientHelloType::ClientHelloOuter.encode(bytes);
3254                payload.encode(bytes);
3255            }
3256            Self::Inner => {
3257                EchClientHelloType::ClientHelloInner.encode(bytes);
3258                // Empty payload.
3259            }
3260        }
3261    }
3262
3263    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3264        match EchClientHelloType::read(r)? {
3265            EchClientHelloType::ClientHelloOuter => {
3266                Ok(Self::Outer(EncryptedClientHelloOuter::read(r)?))
3267            }
3268            EchClientHelloType::ClientHelloInner => Ok(Self::Inner),
3269            _ => Err(InvalidMessage::InvalidContentType),
3270        }
3271    }
3272}
3273
/// Representation of the ECHClientHello extension with type outer specified in
/// [draft-ietf-tls-esni Section 5].
///
/// Fields are encoded in declaration order.
///
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
#[derive(Clone, Debug)]
pub(crate) struct EncryptedClientHelloOuter {
    /// The cipher suite used to encrypt ClientHelloInner. Must match a value from
    /// ECHConfigContents.cipher_suites list.
    pub cipher_suite: HpkeSymmetricCipherSuite,
    /// The ECHConfigContents.key_config.config_id for the chosen ECHConfig.
    pub config_id: u8,
    /// The HPKE encapsulated key, used by servers to decrypt the corresponding payload field.
    /// This field is empty in a ClientHelloOuter sent in response to a HelloRetryRequest.
    pub enc: PayloadU16,
    /// The serialized and encrypted ClientHelloInner structure, encrypted using HPKE.
    /// Decoding rejects an empty payload (`PayloadU16<NonEmpty>`).
    pub payload: PayloadU16<NonEmpty>,
}
3291
3292impl Codec<'_> for EncryptedClientHelloOuter {
3293    fn encode(&self, bytes: &mut Vec<u8>) {
3294        self.cipher_suite.encode(bytes);
3295        self.config_id.encode(bytes);
3296        self.enc.encode(bytes);
3297        self.payload.encode(bytes);
3298    }
3299
3300    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3301        Ok(Self {
3302            cipher_suite: HpkeSymmetricCipherSuite::read(r)?,
3303            config_id: u8::read(r)?,
3304            enc: PayloadU16::read(r)?,
3305            payload: PayloadU16::read(r)?,
3306        })
3307    }
3308}
3309
/// Representation of the ECHEncryptedExtensions extension specified in
/// [draft-ietf-tls-esni Section 5].
///
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
#[derive(Clone, Debug)]
pub(crate) struct ServerEncryptedClientHello {
    /// ECH configs sent by the server for the client to retry with, encoded as a
    /// u16-length-prefixed list.
    pub(crate) retry_configs: Vec<EchConfigPayload>,
}
3318
3319impl Codec<'_> for ServerEncryptedClientHello {
3320    fn encode(&self, bytes: &mut Vec<u8>) {
3321        self.retry_configs.encode(bytes);
3322    }
3323
3324    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3325        Ok(Self {
3326            retry_configs: Vec::<EchConfigPayload>::read(r)?,
3327        })
3328    }
3329}
3330
/// The method of encoding to use for a handshake message.
///
/// In some cases a handshake message may be encoded differently depending on the purpose
/// the encoded message is being used for. For example, a [ServerHelloPayload] may be encoded
/// with the last 8 bytes of the random zeroed out when being encoded for ECH confirmation.
///
/// Within `HandshakeMessagePayload::payload_encode`, only `ServerHello` and
/// `HelloRetryRequest` bodies receive this value; every other message type uses its
/// standard `Codec` encoding there.
pub(crate) enum Encoding {
    /// Standard RFC 8446 encoding.
    Standard,
    /// Encoding for ECH confirmation.
    EchConfirmation,
    /// Encoding for ECH inner client hello.
    ///
    /// NOTE(review): `to_compress` presumably lists the extension types to be replaced
    /// by an outer-extensions reference; confirm against the ClientHello encoder.
    EchInnerHello { to_compress: Vec<ExtensionType> },
}
3344
/// Reports whether `iter` yields two items that are equal after conversion into `T`.
///
/// Items are converted with `Into<T>` and remembered in an ordered set. Iteration
/// stops at the first repeat, so later items are never pulled from the iterator.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
    let mut seen: BTreeSet<T> = BTreeSet::new();
    iter.into_iter()
        .any(|item| !seen.insert(item.into()))
}
3356
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ech_config_dupe_exts() {
        // Two copies of the same non-mandatory (high bit clear) extension type.
        let mut config = config_template();
        config
            .extensions
            .extend([opaque_extension(0x42), opaque_extension(0x42)]);

        assert!(config.has_duplicate_extension());
        assert!(!config.has_unknown_mandatory_extension());
    }

    #[test]
    fn test_ech_config_mandatory_exts() {
        // A single extension whose type has the high ("mandatory") bit set.
        let mut config = config_template();
        config
            .extensions
            .push(opaque_extension(0x42 | 0x8000));

        assert!(!config.has_duplicate_extension());
        assert!(config.has_unknown_mandatory_extension());
    }

    /// Builds an unrecognised ECH config extension of type `typ` with a one-byte body.
    fn opaque_extension(typ: u16) -> EchConfigExtension {
        EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(typ),
            payload: Payload::new(vec![0x42]),
        })
    }

    /// A minimal ECH config with placeholder key material and no extensions.
    fn config_template() -> EchConfigContents {
        let suite = HpkeSymmetricCipherSuite {
            kdf_id: HpkeKdf::HKDF_SHA256,
            aead_id: HpkeAead::AES_128_GCM,
        };
        let key_config = HpkeKeyConfig {
            config_id: 0,
            kem_id: HpkeKem::DHKEM_P256_HKDF_SHA256,
            public_key: PayloadU16::new(b"xxx".into()),
            symmetric_cipher_suites: vec![suite],
        };
        EchConfigContents {
            key_config,
            maximum_name_length: 0,
            public_name: DnsName::try_from("example.com").unwrap(),
            extensions: vec![],
        }
    }
}