1use super::{AsyncStream, RedisResult, RedisRuntime, SocketAddr, TaskHandle};
2use std::{
3 future::Future,
4 io,
5 pin::Pin,
6 task::{self, Poll},
7};
8#[cfg(unix)]
9use tokio::net::UnixStream as UnixStreamTokio;
10use tokio::{
11 io::{AsyncRead, AsyncWrite, ReadBuf},
12 net::TcpStream as TcpStreamTokio,
13};
14
15#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))]
16use native_tls::TlsConnector;
17
18#[cfg(feature = "tokio-rustls-comp")]
19use crate::connection::create_rustls_config;
20#[cfg(feature = "tokio-rustls-comp")]
21use std::sync::Arc;
22#[cfg(feature = "tokio-rustls-comp")]
23use tokio_rustls::{client::TlsStream, TlsConnector};
24
25#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))]
26use tokio_native_tls::TlsStream;
27
28#[cfg(any(feature = "tokio-rustls-comp", feature = "tokio-native-tls-comp"))]
29use crate::connection::TlsConnParams;
30
31#[cfg(unix)]
32use super::Path;
33
34#[inline(always)]
35async fn connect_tcp(
36 addr: &SocketAddr,
37 tcp_settings: &crate::io::tcp::TcpSettings,
38) -> io::Result<TcpStreamTokio> {
39 let socket = TcpStreamTokio::connect(addr).await?;
40 let std_socket = socket.into_std()?;
41 let std_socket = crate::io::tcp::stream_with_settings(std_socket, tcp_settings)?;
42
43 TcpStreamTokio::from_std(std_socket)
44}
45
46pub(crate) enum Tokio {
47 Tcp(TcpStreamTokio),
49 #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
51 TcpTls(Box<TlsStream<TcpStreamTokio>>),
52 #[cfg(unix)]
54 Unix(UnixStreamTokio),
55}
56
57impl AsyncWrite for Tokio {
58 fn poll_write(
59 mut self: Pin<&mut Self>,
60 cx: &mut task::Context,
61 buf: &[u8],
62 ) -> Poll<io::Result<usize>> {
63 match &mut *self {
64 Tokio::Tcp(r) => Pin::new(r).poll_write(cx, buf),
65 #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
66 Tokio::TcpTls(r) => Pin::new(r).poll_write(cx, buf),
67 #[cfg(unix)]
68 Tokio::Unix(r) => Pin::new(r).poll_write(cx, buf),
69 }
70 }
71
72 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<()>> {
73 match &mut *self {
74 Tokio::Tcp(r) => Pin::new(r).poll_flush(cx),
75 #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
76 Tokio::TcpTls(r) => Pin::new(r).poll_flush(cx),
77 #[cfg(unix)]
78 Tokio::Unix(r) => Pin::new(r).poll_flush(cx),
79 }
80 }
81
82 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<()>> {
83 match &mut *self {
84 Tokio::Tcp(r) => Pin::new(r).poll_shutdown(cx),
85 #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
86 Tokio::TcpTls(r) => Pin::new(r).poll_shutdown(cx),
87 #[cfg(unix)]
88 Tokio::Unix(r) => Pin::new(r).poll_shutdown(cx),
89 }
90 }
91}
92
93impl AsyncRead for Tokio {
94 fn poll_read(
95 mut self: Pin<&mut Self>,
96 cx: &mut task::Context,
97 buf: &mut ReadBuf<'_>,
98 ) -> Poll<io::Result<()>> {
99 match &mut *self {
100 Tokio::Tcp(r) => Pin::new(r).poll_read(cx, buf),
101 #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
102 Tokio::TcpTls(r) => Pin::new(r).poll_read(cx, buf),
103 #[cfg(unix)]
104 Tokio::Unix(r) => Pin::new(r).poll_read(cx, buf),
105 }
106 }
107}
108
109impl RedisRuntime for Tokio {
110 async fn connect_tcp(
111 socket_addr: SocketAddr,
112 tcp_settings: &crate::io::tcp::TcpSettings,
113 ) -> RedisResult<Self> {
114 Ok(connect_tcp(&socket_addr, tcp_settings)
115 .await
116 .map(Tokio::Tcp)?)
117 }
118
119 #[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))]
120 async fn connect_tcp_tls(
121 hostname: &str,
122 socket_addr: SocketAddr,
123 insecure: bool,
124 params: &Option<TlsConnParams>,
125 tcp_settings: &crate::io::tcp::TcpSettings,
126 ) -> RedisResult<Self> {
127 let tls_connector: tokio_native_tls::TlsConnector = if insecure {
128 TlsConnector::builder()
129 .danger_accept_invalid_certs(true)
130 .danger_accept_invalid_hostnames(true)
131 .use_sni(false)
132 .build()?
133 } else if let Some(params) = params {
134 TlsConnector::builder()
135 .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames)
136 .build()?
137 } else {
138 TlsConnector::new()?
139 }
140 .into();
141 Ok(tls_connector
142 .connect(hostname, connect_tcp(&socket_addr, tcp_settings).await?)
143 .await
144 .map(|con| Tokio::TcpTls(Box::new(con)))?)
145 }
146
147 #[cfg(feature = "tokio-rustls-comp")]
148 async fn connect_tcp_tls(
149 hostname: &str,
150 socket_addr: SocketAddr,
151 insecure: bool,
152 tls_params: &Option<TlsConnParams>,
153 tcp_settings: &crate::io::tcp::TcpSettings,
154 ) -> RedisResult<Self> {
155 let config = create_rustls_config(insecure, tls_params.clone())?;
156 let tls_connector = TlsConnector::from(Arc::new(config));
157
158 Ok(tls_connector
159 .connect(
160 rustls::pki_types::ServerName::try_from(hostname)?.to_owned(),
161 connect_tcp(&socket_addr, tcp_settings).await?,
162 )
163 .await
164 .map(|con| Tokio::TcpTls(Box::new(con)))?)
165 }
166
167 #[cfg(unix)]
168 async fn connect_unix(path: &Path) -> RedisResult<Self> {
169 Ok(UnixStreamTokio::connect(path).await.map(Tokio::Unix)?)
170 }
171
172 #[cfg(feature = "tokio-comp")]
173 fn spawn(f: impl Future<Output = ()> + Send + 'static) -> TaskHandle {
174 TaskHandle::Tokio(tokio::spawn(f))
175 }
176
177 #[cfg(not(feature = "tokio-comp"))]
178 fn spawn(_: impl Future<Output = ()> + Send + 'static) -> TokioTaskHandle {
179 unreachable!()
180 }
181
182 fn boxed(self) -> Pin<Box<dyn AsyncStream + Send + Sync>> {
183 match self {
184 Tokio::Tcp(x) => Box::pin(x),
185 #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
186 Tokio::TcpTls(x) => Box::pin(x),
187 #[cfg(unix)]
188 Tokio::Unix(x) => Box::pin(x),
189 }
190 }
191}