1use std::error::Error as StdError;
2use std::fmt;
3use std::future::Future;
4use std::io;
5use std::marker::PhantomData;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{self, Poll};
10use std::time::Duration;
11
12use futures_core::ready;
13use futures_util::future::Either;
14use http::uri::{Scheme, Uri};
15use pin_project_lite::pin_project;
16use socket2::TcpKeepalive;
17use tokio::net::{TcpSocket, TcpStream};
18use tokio::time::Sleep;
19use tracing::{debug, trace, warn};
20
21use super::dns::{self, resolve, GaiResolver, Resolve};
22use super::{Connected, Connection};
23use crate::rt::TokioIo;
24
/// A connector for the `http` scheme.
///
/// Resolves the destination host with the resolver `R` (by default
/// [`GaiResolver`]) and then opens a TCP connection to the resulting addresses.
///
/// The configuration is stored behind an `Arc`, so cloning a connector is cheap;
/// clones share the same settings until one of them is modified through a setter.
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    // Shared settings; mutated copy-on-write via `config_mut`.
    config: Arc<Config>,
    // DNS resolver used when the host is not an IP literal.
    resolver: R,
}
38
/// Extra information about the transport when an `HttpConnector` is used.
///
/// Attached to `Connected` by the `Connection` implementation for `TcpStream`
/// when both the peer and local addresses of the stream can be read.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Address of the peer the stream is connected to.
    remote_addr: SocketAddr,
    // Local address the stream is bound to.
    local_addr: SocketAddr,
}
67
/// Settings shared by all clones of an `HttpConnector` (copy-on-write via `Arc::make_mut`).
#[derive(Clone)]
struct Config {
    // Overall connect timeout for a group of addresses; `None` means no timeout.
    connect_timeout: Option<Duration>,
    // When true, only URIs with the `http` scheme are accepted.
    enforce_http: bool,
    // Delay before the fallback address group starts connecting; `None` disables the race.
    happy_eyeballs_timeout: Option<Duration>,
    // Requested TCP keepalive parameters.
    tcp_keepalive_config: TcpKeepaliveConfig,
    // Local IPv4 address to bind before connecting to IPv4 destinations.
    local_address_ipv4: Option<Ipv4Addr>,
    // Local IPv6 address to bind before connecting to IPv6 destinations.
    local_address_ipv6: Option<Ipv6Addr>,
    // Value passed to `set_nodelay` on the connected stream.
    nodelay: bool,
    // Whether `SO_REUSEADDR` should be enabled on new sockets.
    reuse_address: bool,
    // Requested send buffer size, if any.
    send_buffer_size: Option<usize>,
    // Requested receive buffer size, if any.
    recv_buffer_size: Option<usize>,
    // Interface name, bound with `bind_device` on Linux-like systems.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    interface: Option<String>,
    // Interface name as a C string, resolved with `if_nametoindex` and bound by index.
    #[cfg(any(
        target_os = "illumos",
        target_os = "ios",
        target_os = "macos",
        target_os = "solaris",
        target_os = "tvos",
        target_os = "visionos",
        target_os = "watchos",
    ))]
    interface: Option<std::ffi::CString>,
    // Value for `set_tcp_user_timeout`, if any.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    tcp_user_timeout: Option<Duration>,
}
95
/// User-requested TCP keepalive parameters; every `None` field is left unset.
#[derive(Default, Debug, Clone, Copy)]
struct TcpKeepaliveConfig {
    // Idle time before keepalive probes start (`TcpKeepalive::with_time`).
    time: Option<Duration>,
    // Interval between probes; ignored on platforms without `with_interval` support.
    interval: Option<Duration>,
    // Probe retry count; ignored on platforms without `with_retries` support.
    retries: Option<u32>,
}
102
103impl TcpKeepaliveConfig {
104 fn into_tcpkeepalive(self) -> Option<TcpKeepalive> {
106 let mut dirty = false;
107 let mut ka = TcpKeepalive::new();
108 if let Some(time) = self.time {
109 ka = ka.with_time(time);
110 dirty = true
111 }
112 if let Some(interval) = self.interval {
113 ka = Self::ka_with_interval(ka, interval, &mut dirty)
114 };
115 if let Some(retries) = self.retries {
116 ka = Self::ka_with_retries(ka, retries, &mut dirty)
117 };
118 if dirty {
119 Some(ka)
120 } else {
121 None
122 }
123 }
124
125 #[cfg(
126 any(
128 target_os = "android",
129 target_os = "dragonfly",
130 target_os = "freebsd",
131 target_os = "fuchsia",
132 target_os = "illumos",
133 target_os = "ios",
134 target_os = "visionos",
135 target_os = "linux",
136 target_os = "macos",
137 target_os = "netbsd",
138 target_os = "tvos",
139 target_os = "watchos",
140 target_os = "windows",
141 )
142 )]
143 fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
144 *dirty = true;
145 ka.with_interval(interval)
146 }
147
148 #[cfg(not(
149 any(
151 target_os = "android",
152 target_os = "dragonfly",
153 target_os = "freebsd",
154 target_os = "fuchsia",
155 target_os = "illumos",
156 target_os = "ios",
157 target_os = "visionos",
158 target_os = "linux",
159 target_os = "macos",
160 target_os = "netbsd",
161 target_os = "tvos",
162 target_os = "watchos",
163 target_os = "windows",
164 )
165 ))]
166 fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
167 ka }
169
170 #[cfg(
171 any(
173 target_os = "android",
174 target_os = "dragonfly",
175 target_os = "freebsd",
176 target_os = "fuchsia",
177 target_os = "illumos",
178 target_os = "ios",
179 target_os = "visionos",
180 target_os = "linux",
181 target_os = "macos",
182 target_os = "netbsd",
183 target_os = "tvos",
184 target_os = "watchos",
185 )
186 )]
187 fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
188 *dirty = true;
189 ka.with_retries(retries)
190 }
191
192 #[cfg(not(
193 any(
195 target_os = "android",
196 target_os = "dragonfly",
197 target_os = "freebsd",
198 target_os = "fuchsia",
199 target_os = "illumos",
200 target_os = "ios",
201 target_os = "visionos",
202 target_os = "linux",
203 target_os = "macos",
204 target_os = "netbsd",
205 target_os = "tvos",
206 target_os = "watchos",
207 )
208 ))]
209 fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
210 ka }
212}
213
214impl HttpConnector {
217 pub fn new() -> HttpConnector {
219 HttpConnector::new_with_resolver(GaiResolver::new())
220 }
221}
222
223impl<R> HttpConnector<R> {
224 pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
228 HttpConnector {
229 config: Arc::new(Config {
230 connect_timeout: None,
231 enforce_http: true,
232 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
233 tcp_keepalive_config: TcpKeepaliveConfig::default(),
234 local_address_ipv4: None,
235 local_address_ipv6: None,
236 nodelay: false,
237 reuse_address: false,
238 send_buffer_size: None,
239 recv_buffer_size: None,
240 #[cfg(any(
241 target_os = "android",
242 target_os = "fuchsia",
243 target_os = "illumos",
244 target_os = "ios",
245 target_os = "linux",
246 target_os = "macos",
247 target_os = "solaris",
248 target_os = "tvos",
249 target_os = "visionos",
250 target_os = "watchos",
251 ))]
252 interface: None,
253 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
254 tcp_user_timeout: None,
255 }),
256 resolver,
257 }
258 }
259
260 #[inline]
264 pub fn enforce_http(&mut self, is_enforced: bool) {
265 self.config_mut().enforce_http = is_enforced;
266 }
267
268 #[inline]
275 pub fn set_keepalive(&mut self, time: Option<Duration>) {
276 self.config_mut().tcp_keepalive_config.time = time;
277 }
278
279 #[inline]
282 pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) {
283 self.config_mut().tcp_keepalive_config.interval = interval;
284 }
285
286 #[inline]
288 pub fn set_keepalive_retries(&mut self, retries: Option<u32>) {
289 self.config_mut().tcp_keepalive_config.retries = retries;
290 }
291
292 #[inline]
296 pub fn set_nodelay(&mut self, nodelay: bool) {
297 self.config_mut().nodelay = nodelay;
298 }
299
300 #[inline]
302 pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
303 self.config_mut().send_buffer_size = size;
304 }
305
306 #[inline]
308 pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
309 self.config_mut().recv_buffer_size = size;
310 }
311
312 #[inline]
318 pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
319 let (v4, v6) = match addr {
320 Some(IpAddr::V4(a)) => (Some(a), None),
321 Some(IpAddr::V6(a)) => (None, Some(a)),
322 _ => (None, None),
323 };
324
325 let cfg = self.config_mut();
326
327 cfg.local_address_ipv4 = v4;
328 cfg.local_address_ipv6 = v6;
329 }
330
331 #[inline]
334 pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
335 let cfg = self.config_mut();
336
337 cfg.local_address_ipv4 = Some(addr_ipv4);
338 cfg.local_address_ipv6 = Some(addr_ipv6);
339 }
340
341 #[inline]
348 pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
349 self.config_mut().connect_timeout = dur;
350 }
351
352 #[inline]
365 pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
366 self.config_mut().happy_eyeballs_timeout = dur;
367 }
368
369 #[inline]
373 pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
374 self.config_mut().reuse_address = reuse_address;
375 self
376 }
377
378 #[cfg(any(
403 target_os = "android",
404 target_os = "fuchsia",
405 target_os = "illumos",
406 target_os = "ios",
407 target_os = "linux",
408 target_os = "macos",
409 target_os = "solaris",
410 target_os = "tvos",
411 target_os = "visionos",
412 target_os = "watchos",
413 ))]
414 #[inline]
415 pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
416 let interface = interface.into();
417 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
418 {
419 self.config_mut().interface = Some(interface);
420 }
421 #[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))]
422 {
423 let interface = std::ffi::CString::new(interface)
424 .expect("interface name should not have nulls in it");
425 self.config_mut().interface = Some(interface);
426 }
427 self
428 }
429
430 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
432 #[inline]
433 pub fn set_tcp_user_timeout(&mut self, time: Option<Duration>) {
434 self.config_mut().tcp_user_timeout = time;
435 }
436
437 fn config_mut(&mut self) -> &mut Config {
440 Arc::make_mut(&mut self.config)
444 }
445}
446
// Messages for `ConnectError`s produced by `get_host_port` when a destination URI
// is rejected before any DNS or socket work happens.
// Used when `enforce_http` is on and the scheme is not `http`.
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
// Used when `enforce_http` is off and the URI has no scheme.
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
// Used when the URI has no host.
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
450
451impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
453 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
454 f.debug_struct("HttpConnector").finish()
455 }
456}
457
458impl<R> tower_service::Service<Uri> for HttpConnector<R>
459where
460 R: Resolve + Clone + Send + Sync + 'static,
461 R::Future: Send,
462{
463 type Response = TokioIo<TcpStream>;
464 type Error = ConnectError;
465 type Future = HttpConnecting<R>;
466
467 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
468 ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
469 Poll::Ready(Ok(()))
470 }
471
472 fn call(&mut self, dst: Uri) -> Self::Future {
473 let mut self_ = self.clone();
474 HttpConnecting {
475 fut: Box::pin(async move { self_.call_async(dst).await }),
476 _marker: PhantomData,
477 }
478 }
479}
480
481fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
482 trace!(
483 "Http::connect; scheme={:?}, host={:?}, port={:?}",
484 dst.scheme(),
485 dst.host(),
486 dst.port(),
487 );
488
489 if config.enforce_http {
490 if dst.scheme() != Some(&Scheme::HTTP) {
491 return Err(ConnectError {
492 msg: INVALID_NOT_HTTP,
493 addr: None,
494 cause: None,
495 });
496 }
497 } else if dst.scheme().is_none() {
498 return Err(ConnectError {
499 msg: INVALID_MISSING_SCHEME,
500 addr: None,
501 cause: None,
502 });
503 }
504
505 let host = match dst.host() {
506 Some(s) => s,
507 None => {
508 return Err(ConnectError {
509 msg: INVALID_MISSING_HOST,
510 addr: None,
511 cause: None,
512 })
513 }
514 };
515 let port = match dst.port() {
516 Some(port) => port.as_u16(),
517 None => {
518 if dst.scheme() == Some(&Scheme::HTTPS) {
519 443
520 } else {
521 80
522 }
523 }
524 };
525
526 Ok((host, port))
527}
528
529impl<R> HttpConnector<R>
530where
531 R: Resolve,
532{
533 async fn call_async(&mut self, dst: Uri) -> Result<TokioIo<TcpStream>, ConnectError> {
534 let config = &self.config;
535
536 let (host, port) = get_host_port(config, &dst)?;
537 let host = host.trim_start_matches('[').trim_end_matches(']');
538
539 let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
542 addrs
543 } else {
544 let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
545 .await
546 .map_err(ConnectError::dns)?;
547 let addrs = addrs
548 .map(|mut addr| {
549 set_port(&mut addr, port, dst.port().is_some());
550
551 addr
552 })
553 .collect();
554 dns::SocketAddrs::new(addrs)
555 };
556
557 let c = ConnectingTcp::new(addrs, config);
558
559 let sock = c.connect().await?;
560
561 if let Err(e) = sock.set_nodelay(config.nodelay) {
562 warn!("tcp set_nodelay error: {}", e);
563 }
564
565 Ok(TokioIo::new(sock))
566 }
567}
568
569impl Connection for TcpStream {
570 fn connected(&self) -> Connected {
571 let connected = Connected::new();
572 if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) {
573 connected.extra(HttpInfo {
574 remote_addr,
575 local_addr,
576 })
577 } else {
578 connected
579 }
580 }
581}
582
583#[cfg(unix)]
584impl Connection for tokio::net::UnixStream {
585 fn connected(&self) -> Connected {
586 Connected::new()
587 }
588}
589
590#[cfg(windows)]
591impl Connection for tokio::net::windows::named_pipe::NamedPipeClient {
592 fn connected(&self) -> Connected {
593 Connected::new()
594 }
595}
596
597impl<T> Connection for TokioIo<T>
600where
601 T: Connection,
602{
603 fn connected(&self) -> Connected {
604 self.inner().connected()
605 }
606}
607
608impl HttpInfo {
609 pub fn remote_addr(&self) -> SocketAddr {
611 self.remote_addr
612 }
613
614 pub fn local_addr(&self) -> SocketAddr {
616 self.local_addr
617 }
618}
619
pin_project! {
    /// Future returned by `HttpConnector::call`, resolving to the connected stream
    /// or a `ConnectError`.
    ///
    /// The actual work is a boxed future; `R` is only recorded as a type marker.
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        // Boxed `call_async` future doing resolution and connection.
        #[pin]
        fut: BoxConnecting,
        // Ties the future to the connector's resolver type without storing one.
        _marker: PhantomData<R>,
    }
}
634
635type ConnectResult = Result<TokioIo<TcpStream>, ConnectError>;
636type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
637
638impl<R: Resolve> Future for HttpConnecting<R> {
639 type Output = ConnectResult;
640
641 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
642 self.project().fut.poll(cx)
643 }
644}
645
/// Error produced while validating a URI, resolving it, or connecting to it.
///
/// `Display` shows only the static message; the remote address (if known) and
/// the underlying cause are visible through `Debug` and `Error::source`.
pub struct ConnectError {
    msg: &'static str,
    addr: Option<SocketAddr>,
    cause: Option<Box<dyn StdError + Send + Sync>>,
}

impl ConnectError {
    /// Builds an error with a message and an underlying cause, without an address.
    fn new<E>(msg: &'static str, cause: E) -> ConnectError
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        let cause = Some(cause.into());
        ConnectError {
            msg,
            addr: None,
            cause,
        }
    }

    /// Wraps a resolver failure as a "dns error".
    fn dns<E>(cause: E) -> ConnectError
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        Self::new("dns error", cause)
    }

    /// Returns a `map_err`-ready closure that wraps its argument under `msg`.
    fn m<E>(msg: &'static str) -> impl FnOnce(E) -> ConnectError
    where
        E: Into<Box<dyn StdError + Send + Sync>>,
    {
        move |cause| Self::new(msg, cause)
    }
}

impl fmt::Debug for ConnectError {
    /// Formats as `ConnectError(msg[, addr][, cause])`, omitting absent parts.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("ConnectError");
        tuple.field(&self.msg);
        if let Some(addr) = &self.addr {
            tuple.field(addr);
        }
        if let Some(cause) = &self.cause {
            tuple.field(cause);
        }
        tuple.finish()
    }
}

impl fmt::Display for ConnectError {
    /// Writes only the static message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.msg)
    }
}

impl StdError for ConnectError {
    /// Exposes the underlying cause, if any.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        self.cause
            .as_deref()
            .map(|cause| cause as &(dyn StdError + 'static))
    }
}
705
/// One connection attempt: a preferred group of addresses and, when happy
/// eyeballs applies, a delayed fallback group.
struct ConnectingTcp<'a> {
    // Addresses tried first.
    preferred: ConnectingTcpRemote,
    // Addresses of the other family plus the delay gating them; `None` when absent.
    fallback: Option<ConnectingTcpFallback>,
    // Socket options applied to every attempt.
    config: &'a Config,
}
711
712impl<'a> ConnectingTcp<'a> {
713 fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
714 if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
715 let (preferred_addrs, fallback_addrs) = remote_addrs
716 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
717 if fallback_addrs.is_empty() {
718 return ConnectingTcp {
719 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
720 fallback: None,
721 config,
722 };
723 }
724
725 ConnectingTcp {
726 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
727 fallback: Some(ConnectingTcpFallback {
728 delay: tokio::time::sleep(fallback_timeout),
729 remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
730 }),
731 config,
732 }
733 } else {
734 ConnectingTcp {
735 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
736 fallback: None,
737 config,
738 }
739 }
740 }
741}
742
/// The fallback address group together with the timer that gates its start.
struct ConnectingTcpFallback {
    // Created from `happy_eyeballs_timeout` when the attempt is planned.
    delay: Sleep,
    // Fallback addresses, tried sequentially.
    remote: ConnectingTcpRemote,
}

/// A group of addresses tried one after another.
struct ConnectingTcpRemote {
    // Remaining addresses; consumed as attempts are made.
    addrs: dns::SocketAddrs,
    // Timeout applied to each individual address attempt.
    connect_timeout: Option<Duration>,
}
752
753impl ConnectingTcpRemote {
754 fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
755 let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32));
756
757 Self {
758 addrs,
759 connect_timeout,
760 }
761 }
762}
763
764impl ConnectingTcpRemote {
765 async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
766 let mut err = None;
767 for addr in &mut self.addrs {
768 debug!("connecting to {}", addr);
769 match connect(&addr, config, self.connect_timeout)?.await {
770 Ok(tcp) => {
771 debug!("connected to {}", addr);
772 return Ok(tcp);
773 }
774 Err(mut e) => {
775 trace!("connect error for {}: {:?}", addr, e);
776 e.addr = Some(addr);
777 if err.is_none() {
779 err = Some(e);
780 }
781 }
782 }
783 }
784
785 match err {
786 Some(e) => Err(e),
787 None => Err(ConnectError::new(
788 "tcp connect error",
789 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
790 )),
791 }
792 }
793}
794
795fn bind_local_address(
796 socket: &socket2::Socket,
797 dst_addr: &SocketAddr,
798 local_addr_ipv4: &Option<Ipv4Addr>,
799 local_addr_ipv6: &Option<Ipv6Addr>,
800) -> io::Result<()> {
801 match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
802 (SocketAddr::V4(_), Some(addr), _) => {
803 socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
804 }
805 (SocketAddr::V6(_), _, Some(addr)) => {
806 socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
807 }
808 _ => {
809 if cfg!(windows) {
810 let any: SocketAddr = match *dst_addr {
812 SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
813 SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
814 };
815 socket.bind(&any.into())?;
816 }
817 }
818 }
819
820 Ok(())
821}
822
823fn connect(
824 addr: &SocketAddr,
825 config: &Config,
826 connect_timeout: Option<Duration>,
827) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
828 use socket2::{Domain, Protocol, Socket, Type};
832
833 let domain = Domain::for_address(*addr);
834 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
835 .map_err(ConnectError::m("tcp open error"))?;
836
837 socket
840 .set_nonblocking(true)
841 .map_err(ConnectError::m("tcp set_nonblocking error"))?;
842
843 if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() {
844 if let Err(e) = socket.set_tcp_keepalive(tcp_keepalive) {
845 warn!("tcp set_keepalive error: {}", e);
846 }
847 }
848
849 #[cfg(any(
851 target_os = "android",
852 target_os = "fuchsia",
853 target_os = "illumos",
854 target_os = "ios",
855 target_os = "linux",
856 target_os = "macos",
857 target_os = "solaris",
858 target_os = "tvos",
859 target_os = "visionos",
860 target_os = "watchos",
861 ))]
862 if let Some(interface) = &config.interface {
863 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
866 socket
867 .bind_device(Some(interface.as_bytes()))
868 .map_err(ConnectError::m("tcp bind interface error"))?;
869
870 #[cfg(any(
875 target_os = "illumos",
876 target_os = "ios",
877 target_os = "macos",
878 target_os = "solaris",
879 target_os = "tvos",
880 target_os = "visionos",
881 target_os = "watchos",
882 ))]
883 {
884 let idx = unsafe { libc::if_nametoindex(interface.as_ptr()) };
885 let idx = std::num::NonZeroU32::new(idx).ok_or_else(|| {
886 ConnectError::new(
888 "error converting interface name to index",
889 io::Error::last_os_error(),
890 )
891 })?;
892 match addr {
895 SocketAddr::V4(_) => socket.bind_device_by_index_v4(Some(idx)),
896 SocketAddr::V6(_) => socket.bind_device_by_index_v6(Some(idx)),
897 }
898 .map_err(ConnectError::m("tcp bind interface error"))?;
899 }
900 }
901
902 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
903 if let Some(tcp_user_timeout) = &config.tcp_user_timeout {
904 if let Err(e) = socket.set_tcp_user_timeout(Some(*tcp_user_timeout)) {
905 warn!("tcp set_tcp_user_timeout error: {}", e);
906 }
907 }
908
909 bind_local_address(
910 &socket,
911 addr,
912 &config.local_address_ipv4,
913 &config.local_address_ipv6,
914 )
915 .map_err(ConnectError::m("tcp bind local error"))?;
916
917 #[cfg(unix)]
918 let socket = unsafe {
919 use std::os::unix::io::{FromRawFd, IntoRawFd};
924 TcpSocket::from_raw_fd(socket.into_raw_fd())
925 };
926 #[cfg(windows)]
927 let socket = unsafe {
928 use std::os::windows::io::{FromRawSocket, IntoRawSocket};
933 TcpSocket::from_raw_socket(socket.into_raw_socket())
934 };
935
936 if config.reuse_address {
937 if let Err(e) = socket.set_reuseaddr(true) {
938 warn!("tcp set_reuse_address error: {}", e);
939 }
940 }
941
942 if let Some(size) = config.send_buffer_size {
943 if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
944 warn!("tcp set_buffer_size error: {}", e);
945 }
946 }
947
948 if let Some(size) = config.recv_buffer_size {
949 if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
950 warn!("tcp set_recv_buffer_size error: {}", e);
951 }
952 }
953
954 let connect = socket.connect(*addr);
955 Ok(async move {
956 match connect_timeout {
957 Some(dur) => match tokio::time::timeout(dur, connect).await {
958 Ok(Ok(s)) => Ok(s),
959 Ok(Err(e)) => Err(e),
960 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
961 },
962 None => connect.await,
963 }
964 .map_err(ConnectError::m("tcp connect error"))
965 })
966}
967
impl ConnectingTcp<'_> {
    /// Runs the connection attempt, racing the preferred and fallback address
    /// groups when a fallback group exists.
    ///
    /// The preferred group starts immediately. The fallback group is first polled
    /// either once the fallback delay has elapsed, or right after the preferred
    /// group finishes with an error. The first group to finish wins if it
    /// succeeded. Otherwise the other group's result (success or error) is returned.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            // No fallback group: just try the preferred addresses in order.
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                // Phase 1: preferred attempts race the fallback delay.
                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        // Preferred finished before the delay elapsed; the
                        // not-yet-polled fallback is what remains if it failed.
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        // Delay elapsed: race both groups. `factor_first` yields the
                        // first result plus whichever future is still pending.
                        Either::Right(((), preferred_fut)) => {
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                // Only fall through to the remaining future when the first one failed.
                if result.is_err() {
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
1006
/// Applies the port chosen by `get_host_port` to an address produced by the resolver.
///
/// An explicit URI port always wins. Otherwise a non-zero port the resolver
/// already filled in is kept, and a zero port is replaced by `host_port` (the
/// scheme default in that case).
fn set_port(addr: &mut SocketAddr, host_port: u16, explicit: bool) {
    let resolver_chose_port = addr.port() != 0;
    if explicit || !resolver_chose_port {
        addr.set_port(host_port);
    }
}
1015
#[cfg(test)]
mod tests {
    use std::io;
    use std::net::SocketAddr;

    use ::http::Uri;

    use crate::client::legacy::connect::http::TcpKeepaliveConfig;

    use super::super::sealed::{Connect, ConnectSvc};
    use super::{Config, ConnectError, HttpConnector};

    use super::set_port;

    /// Drives `connector` through the sealed `Connect` trait.
    async fn connect<C>(
        connector: C,
        dst: Uri,
    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
    where
        C: Connect,
    {
        connector.connect(super::super::sealed::Internal, dst).await
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_errors_enforce_http() {
        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
        let connector = HttpConnector::new();

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
    }

    /// Finds one IPv4 and one IPv6 interface address that a listener can bind to.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
        use std::net::{IpAddr, TcpListener};

        let mut ip_v4 = None;
        let mut ip_v6 = None;

        let ips = pnet_datalink::interfaces()
            .into_iter()
            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));

        for ip in ips {
            match ip {
                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
                _ => (),
            }

            if ip_v4.is_some() && ip_v6.is_some() {
                break;
            }
        }

        (ip_v4, ip_v6)
    }

    /// Name of the first up, non-loopback interface that has addresses.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    fn default_interface() -> Option<String> {
        pnet_datalink::interfaces()
            .iter()
            .find(|e| e.is_up() && !e.is_loopback() && !e.ips.is_empty())
            .map(|e| e.name.clone())
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_errors_missing_scheme() {
        let dst = "example.domain".parse().unwrap();
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
    }

    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn local_address() {
        use std::net::{IpAddr, TcpListener};

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        let server6 = TcpListener::bind(format!("[::1]:{port}")).unwrap();

        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{port}"), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{port}"), server6, ip.into()).await;
        }
    }

    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    #[tokio::test]
    #[ignore = "setting `SO_BINDTODEVICE` requires the `CAP_NET_RAW` capability (works when running as root)"]
    async fn interface() {
        use socket2::{Domain, Protocol, Socket, Type};
        use std::net::TcpListener;

        let interface: Option<String> = default_interface();

        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();

        let server6 = TcpListener::bind(format!("[::1]:{port}")).unwrap();

        let assert_interface_name =
            |dst: String,
             server: TcpListener,
             bind_iface: Option<String>,
             expected_interface: Option<String>| async move {
                let mut connector = HttpConnector::new();
                if let Some(iface) = bind_iface {
                    connector.set_interface(iface);
                }

                connect(connector, dst.parse().unwrap()).await.unwrap();
                // NOTE(review): this inspects a freshly created socket rather than
                // the one the connector used, so it does not observe the
                // `SO_BINDTODEVICE` binding -- TODO: check the returned stream instead.
                let domain = Domain::for_address(server.local_addr().unwrap());
                let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).unwrap();

                assert_eq!(
                    socket.device().unwrap().as_deref(),
                    expected_interface.as_deref().map(|val| val.as_bytes())
                );
            };

        assert_interface_name(
            format!("http://127.0.0.1:{port}"),
            server4,
            interface.clone(),
            interface.clone(),
        )
        .await;
        assert_interface_name(
            format!("http://[::1]:{port}"),
            server6,
            interface.clone(),
            interface.clone(),
        )
        .await;
    }

    // Only ignored when the internal feature is off; an unconditional `#[ignore]`
    // here would make the feature gate meaningless.
    #[test]
    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
    fn client_happy_eyeballs() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
        use std::time::{Duration, Instant};

        use super::dns;
        use super::ConnectingTcp;

        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = server4.local_addr().unwrap();
        let _server6 = TcpListener::bind(format!("[::1]:{}", addr.port())).unwrap();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        let local_timeout = Duration::default();
        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
            + Duration::from_millis(250);

        // (hosts, expected family, expected duration, needs IPv6 access)
        let scenarios = &[
            (&[local_ipv4_addr()][..], 4, local_timeout, false),
            (&[local_ipv6_addr()][..], 6, local_timeout, false),
            (
                &[local_ipv4_addr(), local_ipv6_addr()][..],
                4,
                local_timeout,
                false,
            ),
            (
                &[local_ipv6_addr(), local_ipv4_addr()][..],
                6,
                local_timeout,
                false,
            ),
            (
                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                unreachable_v6_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv4_addr(),
                    local_ipv4_addr(),
                    local_ipv6_addr(),
                ][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv6_addr(),
                    local_ipv6_addr(),
                    local_ipv4_addr(),
                ][..],
                6,
                unreachable_v6_timeout,
                true,
            ),
            (
                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout,
                true,
            ),
            (
                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout + unreachable_v6_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout + unreachable_v4_timeout,
                true,
            ),
        ];

        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;

        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
            if needs_ipv6_access && !ipv6_accessible {
                continue;
            }

            let (start, stream) = rt
                .block_on(async move {
                    let addrs = hosts
                        .iter()
                        .map(|host| (*host, addr.port()).into())
                        .collect();
                    let cfg = Config {
                        local_address_ipv4: None,
                        local_address_ipv6: None,
                        connect_timeout: None,
                        tcp_keepalive_config: TcpKeepaliveConfig::default(),
                        happy_eyeballs_timeout: Some(fallback_timeout),
                        nodelay: false,
                        reuse_address: false,
                        enforce_http: false,
                        send_buffer_size: None,
                        recv_buffer_size: None,
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        interface: None,
                        #[cfg(any(
                            target_os = "illumos",
                            target_os = "ios",
                            target_os = "macos",
                            target_os = "solaris",
                            target_os = "tvos",
                            target_os = "visionos",
                            target_os = "watchos",
                        ))]
                        interface: None,
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        tcp_user_timeout: None,
                    };
                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
                    let start = Instant::now();
                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
                })
                .unwrap();
            let res = if stream.peer_addr().unwrap().is_ipv4() {
                4
            } else {
                6
            };
            let duration = start.elapsed();

            // Allow 150ms of slack on either side of the expected duration.
            let min_duration = if timeout >= Duration::from_millis(150) {
                timeout - Duration::from_millis(150)
            } else {
                Duration::default()
            };
            let max_duration = timeout + Duration::from_millis(150);

            assert_eq!(res, family);
            assert!(duration >= min_duration);
            assert!(duration <= max_duration);
        }

        fn local_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 1).into()
        }

        fn local_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
        }

        fn unreachable_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 2).into()
        }

        fn unreachable_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
        }

        fn slow_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(198, 18, 0, 25).into()
        }

        fn slow_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
        }

        /// Attempts a blocking connect to port 80 with a 1s timeout.
        ///
        /// Returns whether the address counts as reachable (connected, or timed
        /// out rather than being refused) and how long the attempt took.
        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
            let start = Instant::now();
            let result =
                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));

            let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
            let duration = start.elapsed();
            (reachable, duration)
        }
    }

    use std::time::Duration;

    #[test]
    fn no_tcp_keepalive_config() {
        assert!(TcpKeepaliveConfig::default().into_tcpkeepalive().is_none());
    }

    #[test]
    fn tcp_keepalive_time_config() {
        let kac = TcpKeepaliveConfig {
            time: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
        } else {
            panic!("test failed");
        }
    }

    #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))]
    #[test]
    fn tcp_keepalive_interval_config() {
        let kac = TcpKeepaliveConfig {
            interval: Some(Duration::from_secs(1)),
            ..Default::default()
        };
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
        } else {
            panic!("test failed");
        }
    }

    #[cfg(not(any(
        target_os = "openbsd",
        target_os = "redox",
        target_os = "solaris",
        target_os = "windows"
    )))]
    #[test]
    fn tcp_keepalive_retries_config() {
        let kac = TcpKeepaliveConfig {
            retries: Some(3),
            ..Default::default()
        };
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
        } else {
            panic!("test failed");
        }
    }

    #[test]
    fn test_set_port() {
        // Explicit port overrides the resolved one.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
        set_port(&mut addr, 42, true);
        assert_eq!(addr.port(), 42);

        // Resolver-provided non-zero port is kept when the URI has no port.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
        set_port(&mut addr, 443, false);
        assert_eq!(addr.port(), 6881);

        // Zero port falls back to the scheme default.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 0));
        set_port(&mut addr, 443, false);
        assert_eq!(addr.port(), 443);
    }
}