hyper_util/client/legacy/connect/
http.rs

1use std::error::Error as StdError;
2use std::fmt;
3use std::future::Future;
4use std::io;
5use std::marker::PhantomData;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{self, Poll};
10use std::time::Duration;
11
12use futures_core::ready;
13use futures_util::future::Either;
14use http::uri::{Scheme, Uri};
15use pin_project_lite::pin_project;
16use socket2::TcpKeepalive;
17use tokio::net::{TcpSocket, TcpStream};
18use tokio::time::Sleep;
19use tracing::{debug, trace, warn};
20
21use super::dns::{self, resolve, GaiResolver, Resolve};
22use super::{Connected, Connection};
23use crate::rt::TokioIo;
24
25/// A connector for the `http` scheme.
26///
27/// Performs DNS resolution in a thread pool, and then connects over TCP.
28///
29/// # Note
30///
31/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
32/// transport information such as the remote socket address used.
33#[derive(Clone)]
34pub struct HttpConnector<R = GaiResolver> {
35    config: Arc<Config>,
36    resolver: R,
37}
38
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # fn doc(res: http::Response<()>) {
/// use hyper_util::client::legacy::connect::HttpInfo;
///
/// // res = http::Response
/// if let Some(info) = res.extensions().get::<HttpInfo>() {
///     println!("remote addr = {}", info.remote_addr());
///     println!("local addr = {}", info.local_addr());
/// }
/// # }
/// ```
///
/// # Note
///
/// If a different connector is used besides [`HttpConnector`], this value
/// will not exist in the extensions. Consult that specific connector to see
/// what "extra" information it might provide to responses.
///
/// Even with [`HttpConnector`], the value is only attached when both the peer
/// address and the local address of the TCP stream could be read.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    /// Address of the remote peer the stream is connected to.
    remote_addr: SocketAddr,
    /// Local address of the stream's socket.
    local_addr: SocketAddr,
}
67
68#[derive(Clone)]
69struct Config {
70    connect_timeout: Option<Duration>,
71    enforce_http: bool,
72    happy_eyeballs_timeout: Option<Duration>,
73    tcp_keepalive_config: TcpKeepaliveConfig,
74    local_address_ipv4: Option<Ipv4Addr>,
75    local_address_ipv6: Option<Ipv6Addr>,
76    nodelay: bool,
77    reuse_address: bool,
78    send_buffer_size: Option<usize>,
79    recv_buffer_size: Option<usize>,
80    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
81    interface: Option<String>,
82    #[cfg(any(
83        target_os = "illumos",
84        target_os = "ios",
85        target_os = "macos",
86        target_os = "solaris",
87        target_os = "tvos",
88        target_os = "visionos",
89        target_os = "watchos",
90    ))]
91    interface: Option<std::ffi::CString>,
92    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
93    tcp_user_timeout: Option<Duration>,
94}
95
96#[derive(Default, Debug, Clone, Copy)]
97struct TcpKeepaliveConfig {
98    time: Option<Duration>,
99    interval: Option<Duration>,
100    retries: Option<u32>,
101}
102
103impl TcpKeepaliveConfig {
104    /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration.
105    fn into_tcpkeepalive(self) -> Option<TcpKeepalive> {
106        let mut dirty = false;
107        let mut ka = TcpKeepalive::new();
108        if let Some(time) = self.time {
109            ka = ka.with_time(time);
110            dirty = true
111        }
112        if let Some(interval) = self.interval {
113            ka = Self::ka_with_interval(ka, interval, &mut dirty)
114        };
115        if let Some(retries) = self.retries {
116            ka = Self::ka_with_retries(ka, retries, &mut dirty)
117        };
118        if dirty {
119            Some(ka)
120        } else {
121            None
122        }
123    }
124
125    #[cfg(
126        // See https://docs.rs/socket2/0.5.8/src/socket2/lib.rs.html#511-525
127        any(
128            target_os = "android",
129            target_os = "dragonfly",
130            target_os = "freebsd",
131            target_os = "fuchsia",
132            target_os = "illumos",
133            target_os = "ios",
134            target_os = "visionos",
135            target_os = "linux",
136            target_os = "macos",
137            target_os = "netbsd",
138            target_os = "tvos",
139            target_os = "watchos",
140            target_os = "windows",
141        )
142    )]
143    fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
144        *dirty = true;
145        ka.with_interval(interval)
146    }
147
148    #[cfg(not(
149         // See https://docs.rs/socket2/0.5.8/src/socket2/lib.rs.html#511-525
150        any(
151            target_os = "android",
152            target_os = "dragonfly",
153            target_os = "freebsd",
154            target_os = "fuchsia",
155            target_os = "illumos",
156            target_os = "ios",
157            target_os = "visionos",
158            target_os = "linux",
159            target_os = "macos",
160            target_os = "netbsd",
161            target_os = "tvos",
162            target_os = "watchos",
163            target_os = "windows",
164        )
165    ))]
166    fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
167        ka // no-op as keepalive interval is not supported on this platform
168    }
169
170    #[cfg(
171        // See https://docs.rs/socket2/0.5.8/src/socket2/lib.rs.html#557-570
172        any(
173            target_os = "android",
174            target_os = "dragonfly",
175            target_os = "freebsd",
176            target_os = "fuchsia",
177            target_os = "illumos",
178            target_os = "ios",
179            target_os = "visionos",
180            target_os = "linux",
181            target_os = "macos",
182            target_os = "netbsd",
183            target_os = "tvos",
184            target_os = "watchos",
185        )
186    )]
187    fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
188        *dirty = true;
189        ka.with_retries(retries)
190    }
191
192    #[cfg(not(
193        // See https://docs.rs/socket2/0.5.8/src/socket2/lib.rs.html#557-570
194        any(
195            target_os = "android",
196            target_os = "dragonfly",
197            target_os = "freebsd",
198            target_os = "fuchsia",
199            target_os = "illumos",
200            target_os = "ios",
201            target_os = "visionos",
202            target_os = "linux",
203            target_os = "macos",
204            target_os = "netbsd",
205            target_os = "tvos",
206            target_os = "watchos",
207        )
208    ))]
209    fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
210        ka // no-op as keepalive retries is not supported on this platform
211    }
212}
213
214// ===== impl HttpConnector =====
215
216impl HttpConnector {
217    /// Construct a new HttpConnector.
218    pub fn new() -> HttpConnector {
219        HttpConnector::new_with_resolver(GaiResolver::new())
220    }
221}
222
223impl<R> HttpConnector<R> {
224    /// Construct a new HttpConnector.
225    ///
226    /// Takes a [`Resolver`](crate::client::legacy::connect::dns#resolvers-are-services) to handle DNS lookups.
227    pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
228        HttpConnector {
229            config: Arc::new(Config {
230                connect_timeout: None,
231                enforce_http: true,
232                happy_eyeballs_timeout: Some(Duration::from_millis(300)),
233                tcp_keepalive_config: TcpKeepaliveConfig::default(),
234                local_address_ipv4: None,
235                local_address_ipv6: None,
236                nodelay: false,
237                reuse_address: false,
238                send_buffer_size: None,
239                recv_buffer_size: None,
240                #[cfg(any(
241                    target_os = "android",
242                    target_os = "fuchsia",
243                    target_os = "illumos",
244                    target_os = "ios",
245                    target_os = "linux",
246                    target_os = "macos",
247                    target_os = "solaris",
248                    target_os = "tvos",
249                    target_os = "visionos",
250                    target_os = "watchos",
251                ))]
252                interface: None,
253                #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
254                tcp_user_timeout: None,
255            }),
256            resolver,
257        }
258    }
259
260    /// Option to enforce all `Uri`s have the `http` scheme.
261    ///
262    /// Enabled by default.
263    #[inline]
264    pub fn enforce_http(&mut self, is_enforced: bool) {
265        self.config_mut().enforce_http = is_enforced;
266    }
267
268    /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration
269    /// to remain idle before sending TCP keepalive probes.
270    ///
271    /// If `None`, keepalive is disabled.
272    ///
273    /// Default is `None`.
274    #[inline]
275    pub fn set_keepalive(&mut self, time: Option<Duration>) {
276        self.config_mut().tcp_keepalive_config.time = time;
277    }
278
279    /// Set the duration between two successive TCP keepalive retransmissions,
280    /// if acknowledgement to the previous keepalive transmission is not received.
281    #[inline]
282    pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) {
283        self.config_mut().tcp_keepalive_config.interval = interval;
284    }
285
286    /// Set the number of retransmissions to be carried out before declaring that remote end is not available.
287    #[inline]
288    pub fn set_keepalive_retries(&mut self, retries: Option<u32>) {
289        self.config_mut().tcp_keepalive_config.retries = retries;
290    }
291
292    /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
293    ///
294    /// Default is `false`.
295    #[inline]
296    pub fn set_nodelay(&mut self, nodelay: bool) {
297        self.config_mut().nodelay = nodelay;
298    }
299
300    /// Sets the value of the SO_SNDBUF option on the socket.
301    #[inline]
302    pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
303        self.config_mut().send_buffer_size = size;
304    }
305
306    /// Sets the value of the SO_RCVBUF option on the socket.
307    #[inline]
308    pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
309        self.config_mut().recv_buffer_size = size;
310    }
311
312    /// Set that all sockets are bound to the configured address before connection.
313    ///
314    /// If `None`, the sockets will not be bound.
315    ///
316    /// Default is `None`.
317    #[inline]
318    pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
319        let (v4, v6) = match addr {
320            Some(IpAddr::V4(a)) => (Some(a), None),
321            Some(IpAddr::V6(a)) => (None, Some(a)),
322            _ => (None, None),
323        };
324
325        let cfg = self.config_mut();
326
327        cfg.local_address_ipv4 = v4;
328        cfg.local_address_ipv6 = v6;
329    }
330
331    /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
332    /// preferences) before connection.
333    #[inline]
334    pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
335        let cfg = self.config_mut();
336
337        cfg.local_address_ipv4 = Some(addr_ipv4);
338        cfg.local_address_ipv6 = Some(addr_ipv6);
339    }
340
341    /// Set the connect timeout.
342    ///
343    /// If a domain resolves to multiple IP addresses, the timeout will be
344    /// evenly divided across them.
345    ///
346    /// Default is `None`.
347    #[inline]
348    pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
349        self.config_mut().connect_timeout = dur;
350    }
351
352    /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
353    ///
354    /// If hostname resolves to both IPv4 and IPv6 addresses and connection
355    /// cannot be established using preferred address family before timeout
356    /// elapses, then connector will in parallel attempt connection using other
357    /// address family.
358    ///
359    /// If `None`, parallel connection attempts are disabled.
360    ///
361    /// Default is 300 milliseconds.
362    ///
363    /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
364    #[inline]
365    pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
366        self.config_mut().happy_eyeballs_timeout = dur;
367    }
368
369    /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
370    ///
371    /// Default is `false`.
372    #[inline]
373    pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
374        self.config_mut().reuse_address = reuse_address;
375        self
376    }
377
378    /// Sets the name of the interface to bind sockets produced by this
379    /// connector.
380    ///
381    /// On Linux, this sets the `SO_BINDTODEVICE` option on this socket (see
382    /// [`man 7 socket`] for details). On macOS (and macOS-derived systems like
383    /// iOS), illumos, and Solaris, this will instead use the `IP_BOUND_IF`
384    /// socket option (see [`man 7p ip`]).
385    ///
386    /// If a socket is bound to an interface, only packets received from that particular
387    /// interface are processed by the socket. Note that this only works for some socket
388    /// types, particularly `AF_INET`` sockets.
389    ///
390    /// On Linux it can be used to specify a [VRF], but the binary needs
391    /// to either have `CAP_NET_RAW` or to be run as root.
392    ///
393    /// This function is only available on the following operating systems:
394    /// - Linux, including Android
395    /// - Fuchsia
396    /// - illumos and Solaris
397    /// - macOS, iOS, visionOS, watchOS, and tvOS
398    ///
399    /// [VRF]: https://www.kernel.org/doc/Documentation/networking/vrf.txt
400    /// [`man 7 socket`] https://man7.org/linux/man-pages/man7/socket.7.html
401    /// [`man 7p ip`]: https://docs.oracle.com/cd/E86824_01/html/E54777/ip-7p.html
402    #[cfg(any(
403        target_os = "android",
404        target_os = "fuchsia",
405        target_os = "illumos",
406        target_os = "ios",
407        target_os = "linux",
408        target_os = "macos",
409        target_os = "solaris",
410        target_os = "tvos",
411        target_os = "visionos",
412        target_os = "watchos",
413    ))]
414    #[inline]
415    pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
416        let interface = interface.into();
417        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
418        {
419            self.config_mut().interface = Some(interface);
420        }
421        #[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))]
422        {
423            let interface = std::ffi::CString::new(interface)
424                .expect("interface name should not have nulls in it");
425            self.config_mut().interface = Some(interface);
426        }
427        self
428    }
429
430    /// Sets the value of the TCP_USER_TIMEOUT option on the socket.
431    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
432    #[inline]
433    pub fn set_tcp_user_timeout(&mut self, time: Option<Duration>) {
434        self.config_mut().tcp_user_timeout = time;
435    }
436
437    // private
438
439    fn config_mut(&mut self) -> &mut Config {
440        // If the are HttpConnector clones, this will clone the inner
441        // config. So mutating the config won't ever affect previous
442        // clones.
443        Arc::make_mut(&mut self.config)
444    }
445}
446
/// Error message used when `enforce_http` is on and the URI scheme is not `http`.
const INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Error message used when `enforce_http` is off and the URI has no scheme.
const INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Error message used when the URI has no host.
const INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
450
451// R: Debug required for now to allow adding it to debug output later...
452impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
453    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
454        f.debug_struct("HttpConnector").finish()
455    }
456}
457
458impl<R> tower_service::Service<Uri> for HttpConnector<R>
459where
460    R: Resolve + Clone + Send + Sync + 'static,
461    R::Future: Send,
462{
463    type Response = TokioIo<TcpStream>;
464    type Error = ConnectError;
465    type Future = HttpConnecting<R>;
466
467    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
468        ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
469        Poll::Ready(Ok(()))
470    }
471
472    fn call(&mut self, dst: Uri) -> Self::Future {
473        let mut self_ = self.clone();
474        HttpConnecting {
475            fut: Box::pin(async move { self_.call_async(dst).await }),
476            _marker: PhantomData,
477        }
478    }
479}
480
481fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
482    trace!(
483        "Http::connect; scheme={:?}, host={:?}, port={:?}",
484        dst.scheme(),
485        dst.host(),
486        dst.port(),
487    );
488
489    if config.enforce_http {
490        if dst.scheme() != Some(&Scheme::HTTP) {
491            return Err(ConnectError {
492                msg: INVALID_NOT_HTTP,
493                addr: None,
494                cause: None,
495            });
496        }
497    } else if dst.scheme().is_none() {
498        return Err(ConnectError {
499            msg: INVALID_MISSING_SCHEME,
500            addr: None,
501            cause: None,
502        });
503    }
504
505    let host = match dst.host() {
506        Some(s) => s,
507        None => {
508            return Err(ConnectError {
509                msg: INVALID_MISSING_HOST,
510                addr: None,
511                cause: None,
512            })
513        }
514    };
515    let port = match dst.port() {
516        Some(port) => port.as_u16(),
517        None => {
518            if dst.scheme() == Some(&Scheme::HTTPS) {
519                443
520            } else {
521                80
522            }
523        }
524    };
525
526    Ok((host, port))
527}
528
529impl<R> HttpConnector<R>
530where
531    R: Resolve,
532{
533    async fn call_async(&mut self, dst: Uri) -> Result<TokioIo<TcpStream>, ConnectError> {
534        let config = &self.config;
535
536        let (host, port) = get_host_port(config, &dst)?;
537        let host = host.trim_start_matches('[').trim_end_matches(']');
538
539        // If the host is already an IP addr (v4 or v6),
540        // skip resolving the dns and start connecting right away.
541        let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
542            addrs
543        } else {
544            let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
545                .await
546                .map_err(ConnectError::dns)?;
547            let addrs = addrs
548                .map(|mut addr| {
549                    set_port(&mut addr, port, dst.port().is_some());
550
551                    addr
552                })
553                .collect();
554            dns::SocketAddrs::new(addrs)
555        };
556
557        let c = ConnectingTcp::new(addrs, config);
558
559        let sock = c.connect().await?;
560
561        if let Err(e) = sock.set_nodelay(config.nodelay) {
562            warn!("tcp set_nodelay error: {}", e);
563        }
564
565        Ok(TokioIo::new(sock))
566    }
567}
568
569impl Connection for TcpStream {
570    fn connected(&self) -> Connected {
571        let connected = Connected::new();
572        if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) {
573            connected.extra(HttpInfo {
574                remote_addr,
575                local_addr,
576            })
577        } else {
578            connected
579        }
580    }
581}
582
583#[cfg(unix)]
584impl Connection for tokio::net::UnixStream {
585    fn connected(&self) -> Connected {
586        Connected::new()
587    }
588}
589
590#[cfg(windows)]
591impl Connection for tokio::net::windows::named_pipe::NamedPipeClient {
592    fn connected(&self) -> Connected {
593        Connected::new()
594    }
595}
596
597// Implement `Connection` for generic `TokioIo<T>` so that external crates can
598// implement their own `HttpConnector` with `TokioIo<CustomTcpStream>`.
599impl<T> Connection for TokioIo<T>
600where
601    T: Connection,
602{
603    fn connected(&self) -> Connected {
604        self.inner().connected()
605    }
606}
607
608impl HttpInfo {
609    /// Get the remote address of the transport used.
610    pub fn remote_addr(&self) -> SocketAddr {
611        self.remote_addr
612    }
613
614    /// Get the local address of the transport used.
615    pub fn local_addr(&self) -> SocketAddr {
616        self.local_addr
617    }
618}
619
620pin_project! {
621    // Not publicly exported (so missing_docs doesn't trigger).
622    //
623    // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
624    // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
625    // (and thus we can change the type in the future).
626    #[must_use = "futures do nothing unless polled"]
627    #[allow(missing_debug_implementations)]
628    pub struct HttpConnecting<R> {
629        #[pin]
630        fut: BoxConnecting,
631        _marker: PhantomData<R>,
632    }
633}
634
635type ConnectResult = Result<TokioIo<TcpStream>, ConnectError>;
636type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
637
638impl<R: Resolve> Future for HttpConnecting<R> {
639    type Output = ConnectResult;
640
641    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
642        self.project().fut.poll(cx)
643    }
644}
645
// Not publicly exported (so missing_docs doesn't trigger).
pub struct ConnectError {
    /// Short static description; this is all `Display` prints.
    msg: &'static str,
    /// Remote address the failure relates to, when known.
    addr: Option<SocketAddr>,
    /// Underlying error, exposed through `Error::source`.
    cause: Option<Box<dyn StdError + Send + Sync>>,
}

impl ConnectError {
    /// Builds an error with a message and an underlying cause, but no address.
    fn new<E>(msg: &'static str, cause: E) -> ConnectError
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        let cause: Box<dyn StdError + Send + Sync> = cause.into();
        ConnectError {
            msg,
            addr: None,
            cause: Some(cause),
        }
    }

    /// Wraps a resolver failure as a `"dns error"`.
    fn dns<E>(cause: E) -> ConnectError
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        Self::new("dns error", cause)
    }

    /// Returns a closure for `map_err` that attaches `msg` to the error it receives.
    fn m<E>(msg: &'static str) -> impl FnOnce(E) -> ConnectError
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        move |cause: E| Self::new(msg, cause)
    }
}

impl fmt::Debug for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("ConnectError");
        tuple.field(&self.msg);
        if let Some(addr) = &self.addr {
            tuple.field(addr);
        }
        if let Some(cause) = &self.cause {
            tuple.field(cause);
        }
        tuple.finish()
    }
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.msg)
    }
}

impl StdError for ConnectError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        self.cause
            .as_deref()
            .map(|cause| cause as &(dyn StdError + 'static))
    }
}
705
706struct ConnectingTcp<'a> {
707    preferred: ConnectingTcpRemote,
708    fallback: Option<ConnectingTcpFallback>,
709    config: &'a Config,
710}
711
712impl<'a> ConnectingTcp<'a> {
713    fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
714        if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
715            let (preferred_addrs, fallback_addrs) = remote_addrs
716                .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
717            if fallback_addrs.is_empty() {
718                return ConnectingTcp {
719                    preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
720                    fallback: None,
721                    config,
722                };
723            }
724
725            ConnectingTcp {
726                preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
727                fallback: Some(ConnectingTcpFallback {
728                    delay: tokio::time::sleep(fallback_timeout),
729                    remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
730                }),
731                config,
732            }
733        } else {
734            ConnectingTcp {
735                preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
736                fallback: None,
737                config,
738            }
739        }
740    }
741}
742
743struct ConnectingTcpFallback {
744    delay: Sleep,
745    remote: ConnectingTcpRemote,
746}
747
748struct ConnectingTcpRemote {
749    addrs: dns::SocketAddrs,
750    connect_timeout: Option<Duration>,
751}
752
753impl ConnectingTcpRemote {
754    fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
755        let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32));
756
757        Self {
758            addrs,
759            connect_timeout,
760        }
761    }
762}
763
764impl ConnectingTcpRemote {
765    async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
766        let mut err = None;
767        for addr in &mut self.addrs {
768            debug!("connecting to {}", addr);
769            match connect(&addr, config, self.connect_timeout)?.await {
770                Ok(tcp) => {
771                    debug!("connected to {}", addr);
772                    return Ok(tcp);
773                }
774                Err(mut e) => {
775                    trace!("connect error for {}: {:?}", addr, e);
776                    e.addr = Some(addr);
777                    // only return the first error, we assume it's the most relevant
778                    if err.is_none() {
779                        err = Some(e);
780                    }
781                }
782            }
783        }
784
785        match err {
786            Some(e) => Err(e),
787            None => Err(ConnectError::new(
788                "tcp connect error",
789                std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
790            )),
791        }
792    }
793}
794
795fn bind_local_address(
796    socket: &socket2::Socket,
797    dst_addr: &SocketAddr,
798    local_addr_ipv4: &Option<Ipv4Addr>,
799    local_addr_ipv6: &Option<Ipv6Addr>,
800) -> io::Result<()> {
801    match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
802        (SocketAddr::V4(_), Some(addr), _) => {
803            socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
804        }
805        (SocketAddr::V6(_), _, Some(addr)) => {
806            socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
807        }
808        _ => {
809            if cfg!(windows) {
810                // Windows requires a socket be bound before calling connect
811                let any: SocketAddr = match *dst_addr {
812                    SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
813                    SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
814                };
815                socket.bind(&any.into())?;
816            }
817        }
818    }
819
820    Ok(())
821}
822
823fn connect(
824    addr: &SocketAddr,
825    config: &Config,
826    connect_timeout: Option<Duration>,
827) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
828    // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
829    // keepalive timeout, it would be nice to use that instead of socket2,
830    // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
831    use socket2::{Domain, Protocol, Socket, Type};
832
833    let domain = Domain::for_address(*addr);
834    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
835        .map_err(ConnectError::m("tcp open error"))?;
836
837    // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
838    // responsible for ensuring O_NONBLOCK is set.
839    socket
840        .set_nonblocking(true)
841        .map_err(ConnectError::m("tcp set_nonblocking error"))?;
842
843    if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() {
844        if let Err(e) = socket.set_tcp_keepalive(tcp_keepalive) {
845            warn!("tcp set_keepalive error: {}", e);
846        }
847    }
848
849    // That this only works for some socket types, particularly AF_INET sockets.
850    #[cfg(any(
851        target_os = "android",
852        target_os = "fuchsia",
853        target_os = "illumos",
854        target_os = "ios",
855        target_os = "linux",
856        target_os = "macos",
857        target_os = "solaris",
858        target_os = "tvos",
859        target_os = "visionos",
860        target_os = "watchos",
861    ))]
862    if let Some(interface) = &config.interface {
863        // On Linux-like systems, set the interface to bind using
864        // `SO_BINDTODEVICE`.
865        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
866        socket
867            .bind_device(Some(interface.as_bytes()))
868            .map_err(ConnectError::m("tcp bind interface error"))?;
869
870        // On macOS-like and Solaris-like systems, we instead use `IP_BOUND_IF`.
871        // This socket option desires an integer index for the interface, so we
872        // must first determine the index of the requested interface name using
873        // `if_nametoindex`.
874        #[cfg(any(
875            target_os = "illumos",
876            target_os = "ios",
877            target_os = "macos",
878            target_os = "solaris",
879            target_os = "tvos",
880            target_os = "visionos",
881            target_os = "watchos",
882        ))]
883        {
884            let idx = unsafe { libc::if_nametoindex(interface.as_ptr()) };
885            let idx = std::num::NonZeroU32::new(idx).ok_or_else(|| {
886                // If the index is 0, check errno and return an I/O error.
887                ConnectError::new(
888                    "error converting interface name to index",
889                    io::Error::last_os_error(),
890                )
891            })?;
892            // Different setsockopt calls are necessary depending on whether the
893            // address is IPv4 or IPv6.
894            match addr {
895                SocketAddr::V4(_) => socket.bind_device_by_index_v4(Some(idx)),
896                SocketAddr::V6(_) => socket.bind_device_by_index_v6(Some(idx)),
897            }
898            .map_err(ConnectError::m("tcp bind interface error"))?;
899        }
900    }
901
902    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
903    if let Some(tcp_user_timeout) = &config.tcp_user_timeout {
904        if let Err(e) = socket.set_tcp_user_timeout(Some(*tcp_user_timeout)) {
905            warn!("tcp set_tcp_user_timeout error: {}", e);
906        }
907    }
908
909    bind_local_address(
910        &socket,
911        addr,
912        &config.local_address_ipv4,
913        &config.local_address_ipv6,
914    )
915    .map_err(ConnectError::m("tcp bind local error"))?;
916
917    #[cfg(unix)]
918    let socket = unsafe {
919        // Safety: `from_raw_fd` is only safe to call if ownership of the raw
920        // file descriptor is transferred. Since we call `into_raw_fd` on the
921        // socket2 socket, it gives up ownership of the fd and will not close
922        // it, so this is safe.
923        use std::os::unix::io::{FromRawFd, IntoRawFd};
924        TcpSocket::from_raw_fd(socket.into_raw_fd())
925    };
926    #[cfg(windows)]
927    let socket = unsafe {
928        // Safety: `from_raw_socket` is only safe to call if ownership of the raw
929        // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
930        // socket2 socket, it gives up ownership of the SOCKET and will not close
931        // it, so this is safe.
932        use std::os::windows::io::{FromRawSocket, IntoRawSocket};
933        TcpSocket::from_raw_socket(socket.into_raw_socket())
934    };
935
936    if config.reuse_address {
937        if let Err(e) = socket.set_reuseaddr(true) {
938            warn!("tcp set_reuse_address error: {}", e);
939        }
940    }
941
942    if let Some(size) = config.send_buffer_size {
943        if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
944            warn!("tcp set_buffer_size error: {}", e);
945        }
946    }
947
948    if let Some(size) = config.recv_buffer_size {
949        if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
950            warn!("tcp set_recv_buffer_size error: {}", e);
951        }
952    }
953
954    let connect = socket.connect(*addr);
955    Ok(async move {
956        match connect_timeout {
957            Some(dur) => match tokio::time::timeout(dur, connect).await {
958                Ok(Ok(s)) => Ok(s),
959                Ok(Err(e)) => Err(e),
960                Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
961            },
962            None => connect.await,
963        }
964        .map_err(ConnectError::m("tcp connect error"))
965    })
966}
967
968impl ConnectingTcp<'_> {
969    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
970        match self.fallback {
971            None => self.preferred.connect(self.config).await,
972            Some(mut fallback) => {
973                let preferred_fut = self.preferred.connect(self.config);
974                futures_util::pin_mut!(preferred_fut);
975
976                let fallback_fut = fallback.remote.connect(self.config);
977                futures_util::pin_mut!(fallback_fut);
978
979                let fallback_delay = fallback.delay;
980                futures_util::pin_mut!(fallback_delay);
981
982                let (result, future) =
983                    match futures_util::future::select(preferred_fut, fallback_delay).await {
984                        Either::Left((result, _fallback_delay)) => {
985                            (result, Either::Right(fallback_fut))
986                        }
987                        Either::Right(((), preferred_fut)) => {
988                            // Delay is done, start polling both the preferred and the fallback
989                            futures_util::future::select(preferred_fut, fallback_fut)
990                                .await
991                                .factor_first()
992                        }
993                    };
994
995                if result.is_err() {
996                    // Fallback to the remaining future (could be preferred or fallback)
997                    // if we get an error
998                    future.await
999                } else {
1000                    result
1001                }
1002            }
1003        }
1004    }
1005}
1006
/// Decides which port a resolved socket address should finally use.
///
/// * When the URI carried an explicit port (`explicit == true`), `host_port`
///   always overwrites whatever the resolver returned.
/// * Otherwise a non-zero port produced by the resolver (for example by a
///   custom DNS resolver) is kept as-is.
/// * Only a resolved port of `0` is replaced by `host_port`, which in that
///   case is the scheme's default port.
fn set_port(addr: &mut SocketAddr, host_port: u16, explicit: bool) {
    let resolver_port_wins = !explicit && addr.port() != 0;
    if !resolver_port_wins {
        addr.set_port(host_port);
    }
}
1015
#[cfg(test)]
mod tests {
    use std::io;
    use std::net::SocketAddr;
    use std::time::Duration;

    use ::http::Uri;

    use crate::client::legacy::connect::http::TcpKeepaliveConfig;

    use super::super::sealed::{Connect, ConnectSvc};
    use super::{Config, ConnectError, HttpConnector};

    use super::set_port;

    /// Drives `connector` through the sealed `Connect` trait and returns the
    /// resulting connection (or error) for `dst`.
    async fn connect<C>(
        connector: C,
        dst: Uri,
    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
    where
        C: Connect,
    {
        connector.connect(super::super::sealed::Internal, dst).await
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_errors_enforce_http() {
        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
        let connector = HttpConnector::new();

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
    }

    /// Returns an IPv4 and an IPv6 address of this host that a TCP listener
    /// can bind to, or `None` for a family with no bindable address.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
        use std::net::{IpAddr, TcpListener};

        let mut ip_v4 = None;
        let mut ip_v6 = None;

        let ips = pnet_datalink::interfaces()
            .into_iter()
            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));

        for ip in ips {
            match ip {
                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
                _ => (),
            }

            if ip_v4.is_some() && ip_v6.is_some() {
                break;
            }
        }

        (ip_v4, ip_v6)
    }

    /// Name of the first interface that is up, is not loopback, and has at
    /// least one IP address assigned.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    fn default_interface() -> Option<String> {
        pnet_datalink::interfaces()
            .iter()
            .find(|e| e.is_up() && !e.is_loopback() && !e.ips.is_empty())
            .map(|e| e.name.clone())
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_errors_missing_scheme() {
        let dst = "example.domain".parse().unwrap();
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
    }

    // NOTE: pnet crate that we use in this test doesn't compile on Windows
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn local_address() {
        use std::net::{IpAddr, TcpListener};

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        let server6 = TcpListener::bind(format!("[::1]:{port}")).unwrap();

        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{port}"), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{port}"), server6, ip.into()).await;
        }
    }

    // `set_interface` uses `SO_BINDTODEVICE`, which only exists on these targets;
    // the pnet crate used by `default_interface` also doesn't compile on Windows.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    #[tokio::test]
    #[ignore = "setting `SO_BINDTODEVICE` requires the `CAP_NET_RAW` capability (works when running as root)"]
    async fn interface() {
        use std::net::TcpListener;

        let interface: Option<String> = default_interface();

        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();

        let server6 = TcpListener::bind(format!("[::1]:{port}")).unwrap();

        let assert_interface_name =
            |dst: String,
             server: TcpListener,
             bind_iface: Option<String>,
             expected_interface: Option<String>| async move {
                let mut connector = HttpConnector::new();
                if let Some(iface) = bind_iface {
                    connector.set_interface(iface);
                }

                let conn = connect(connector, dst.parse().unwrap()).await.unwrap();
                let stream = conn.inner();

                // The connection must have reached this test's listener.
                assert_eq!(stream.peer_addr().unwrap(), server.local_addr().unwrap());

                // Read the bound device from the connector's own socket. A freshly
                // created socket is never bound to a device, so inspecting one
                // would not exercise `set_interface` at all.
                let socket = socket2::SockRef::from(stream);
                assert_eq!(
                    socket.device().unwrap().as_deref(),
                    expected_interface.as_deref().map(|val| val.as_bytes())
                );
            };

        assert_interface_name(
            format!("http://127.0.0.1:{port}"),
            server4,
            interface.clone(),
            interface.clone(),
        )
        .await;
        assert_interface_name(
            format!("http://[::1]:{port}"),
            server6,
            interface.clone(),
            interface.clone(),
        )
        .await;
    }

    #[test]
    #[ignore] // TODO
    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
    fn client_happy_eyeballs() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
        use std::time::{Duration, Instant};

        use super::dns;
        use super::ConnectingTcp;

        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = server4.local_addr().unwrap();
        let _server6 = TcpListener::bind(format!("[::1]:{}", addr.port())).unwrap();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        let local_timeout = Duration::default();
        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
            + Duration::from_millis(250);

        let scenarios = &[
            // Fast primary, without fallback.
            (&[local_ipv4_addr()][..], 4, local_timeout, false),
            (&[local_ipv6_addr()][..], 6, local_timeout, false),
            // Fast primary, with (unused) fallback.
            (
                &[local_ipv4_addr(), local_ipv6_addr()][..],
                4,
                local_timeout,
                false,
            ),
            (
                &[local_ipv6_addr(), local_ipv4_addr()][..],
                6,
                local_timeout,
                false,
            ),
            // Unreachable + fast primary, without fallback.
            (
                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                unreachable_v6_timeout,
                false,
            ),
            // Unreachable + fast primary, with (unused) fallback.
            (
                &[
                    unreachable_ipv4_addr(),
                    local_ipv4_addr(),
                    local_ipv6_addr(),
                ][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv6_addr(),
                    local_ipv6_addr(),
                    local_ipv4_addr(),
                ][..],
                6,
                unreachable_v6_timeout,
                true,
            ),
            // Slow primary, with (used) fallback.
            (
                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout,
                true,
            ),
            // Slow primary, with (used) unreachable + fast fallback.
            (
                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout + unreachable_v6_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout + unreachable_v4_timeout,
                true,
            ),
        ];

        // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
        // Otherwise, connection to "slow" IPv6 address will error-out immediately.
        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;

        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
            if needs_ipv6_access && !ipv6_accessible {
                continue;
            }

            let (start, stream) = rt
                .block_on(async move {
                    let addrs = hosts
                        .iter()
                        .map(|host| (*host, addr.port()).into())
                        .collect();
                    let cfg = Config {
                        local_address_ipv4: None,
                        local_address_ipv6: None,
                        connect_timeout: None,
                        tcp_keepalive_config: TcpKeepaliveConfig::default(),
                        happy_eyeballs_timeout: Some(fallback_timeout),
                        nodelay: false,
                        reuse_address: false,
                        enforce_http: false,
                        send_buffer_size: None,
                        recv_buffer_size: None,
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        interface: None,
                        #[cfg(any(
                            target_os = "illumos",
                            target_os = "ios",
                            target_os = "macos",
                            target_os = "solaris",
                            target_os = "tvos",
                            target_os = "visionos",
                            target_os = "watchos",
                        ))]
                        interface: None,
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        tcp_user_timeout: None,
                    };
                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
                    let start = Instant::now();
                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
                })
                .unwrap();
            let res = if stream.peer_addr().unwrap().is_ipv4() {
                4
            } else {
                6
            };
            let duration = start.elapsed();

            // Allow actual duration to be +/- 150ms off.
            let min_duration = if timeout >= Duration::from_millis(150) {
                timeout - Duration::from_millis(150)
            } else {
                Duration::default()
            };
            let max_duration = timeout + Duration::from_millis(150);

            assert_eq!(res, family);
            assert!(duration >= min_duration);
            assert!(duration <= max_duration);
        }

        fn local_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 1).into()
        }

        fn local_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
        }

        fn unreachable_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 2).into()
        }

        fn unreachable_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
        }

        fn slow_ipv4_addr() -> IpAddr {
            // RFC 6890 reserved IPv4 address.
            Ipv4Addr::new(198, 18, 0, 25).into()
        }

        fn slow_ipv6_addr() -> IpAddr {
            // RFC 6890 reserved IPv6 address.
            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
        }

        /// Attempts a blocking connect to `addr:80` with a 1s timeout and
        /// returns whether the address counts as reachable plus how long the
        /// attempt took.
        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
            let start = Instant::now();
            let result =
                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));
            let duration = start.elapsed();

            // Success or a timeout counts as reachable; any other error (which
            // typically arrives immediately) means it does not.
            let reachable = match result {
                Ok(_) => true,
                Err(e) => e.kind() == io::ErrorKind::TimedOut,
            };
            (reachable, duration)
        }
    }

    #[test]
    fn no_tcp_keepalive_config() {
        assert!(TcpKeepaliveConfig::default().into_tcpkeepalive().is_none());
    }

    #[test]
    fn tcp_keepalive_time_config() {
        let kac = TcpKeepaliveConfig {
            time: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
        } else {
            panic!("test failed");
        }
    }

    #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))]
    #[test]
    fn tcp_keepalive_interval_config() {
        let kac = TcpKeepaliveConfig {
            interval: Some(Duration::from_secs(1)),
            ..Default::default()
        };
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
        } else {
            panic!("test failed");
        }
    }

    #[cfg(not(any(
        target_os = "openbsd",
        target_os = "redox",
        target_os = "solaris",
        target_os = "windows"
    )))]
    #[test]
    fn tcp_keepalive_retries_config() {
        let kac = TcpKeepaliveConfig {
            retries: Some(3),
            ..Default::default()
        };
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
        } else {
            panic!("test failed");
        }
    }

    #[test]
    fn test_set_port() {
        // Respect explicit ports no matter what the resolved port is.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
        set_port(&mut addr, 42, true);
        assert_eq!(addr.port(), 42);

        // Ignore the default host port, and use the socket port instead.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
        set_port(&mut addr, 443, false);
        assert_eq!(addr.port(), 6881);

        // Use the default port if the resolved port is `0`.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 0));
        set_port(&mut addr, 443, false);
        assert_eq!(addr.port(), 443);
    }
}