redis/aio/tokio.rs

1use super::{AsyncStream, RedisResult, RedisRuntime, SocketAddr, TaskHandle};
2use std::{
3    future::Future,
4    io,
5    pin::Pin,
6    task::{self, Poll},
7};
8#[cfg(unix)]
9use tokio::net::UnixStream as UnixStreamTokio;
10use tokio::{
11    io::{AsyncRead, AsyncWrite, ReadBuf},
12    net::TcpStream as TcpStreamTokio,
13};
14
15#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))]
16use native_tls::TlsConnector;
17
18#[cfg(feature = "tokio-rustls-comp")]
19use crate::connection::create_rustls_config;
20#[cfg(feature = "tokio-rustls-comp")]
21use std::sync::Arc;
22#[cfg(feature = "tokio-rustls-comp")]
23use tokio_rustls::{client::TlsStream, TlsConnector};
24
25#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))]
26use tokio_native_tls::TlsStream;
27
28#[cfg(any(feature = "tokio-rustls-comp", feature = "tokio-native-tls-comp"))]
29use crate::connection::TlsConnParams;
30
31#[cfg(unix)]
32use super::Path;
33
34#[inline(always)]
35async fn connect_tcp(
36    addr: &SocketAddr,
37    tcp_settings: &crate::io::tcp::TcpSettings,
38) -> io::Result<TcpStreamTokio> {
39    let socket = TcpStreamTokio::connect(addr).await?;
40    let std_socket = socket.into_std()?;
41    let std_socket = crate::io::tcp::stream_with_settings(std_socket, tcp_settings)?;
42
43    TcpStreamTokio::from_std(std_socket)
44}
45
46pub(crate) enum Tokio {
47    /// Represents a Tokio TCP connection.
48    Tcp(TcpStreamTokio),
49    /// Represents a Tokio TLS encrypted TCP connection
50    #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
51    TcpTls(Box<TlsStream<TcpStreamTokio>>),
52    /// Represents a Tokio Unix connection.
53    #[cfg(unix)]
54    Unix(UnixStreamTokio),
55}
56
57impl AsyncWrite for Tokio {
58    fn poll_write(
59        mut self: Pin<&mut Self>,
60        cx: &mut task::Context,
61        buf: &[u8],
62    ) -> Poll<io::Result<usize>> {
63        match &mut *self {
64            Tokio::Tcp(r) => Pin::new(r).poll_write(cx, buf),
65            #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
66            Tokio::TcpTls(r) => Pin::new(r).poll_write(cx, buf),
67            #[cfg(unix)]
68            Tokio::Unix(r) => Pin::new(r).poll_write(cx, buf),
69        }
70    }
71
72    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<()>> {
73        match &mut *self {
74            Tokio::Tcp(r) => Pin::new(r).poll_flush(cx),
75            #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
76            Tokio::TcpTls(r) => Pin::new(r).poll_flush(cx),
77            #[cfg(unix)]
78            Tokio::Unix(r) => Pin::new(r).poll_flush(cx),
79        }
80    }
81
82    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<()>> {
83        match &mut *self {
84            Tokio::Tcp(r) => Pin::new(r).poll_shutdown(cx),
85            #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
86            Tokio::TcpTls(r) => Pin::new(r).poll_shutdown(cx),
87            #[cfg(unix)]
88            Tokio::Unix(r) => Pin::new(r).poll_shutdown(cx),
89        }
90    }
91}
92
93impl AsyncRead for Tokio {
94    fn poll_read(
95        mut self: Pin<&mut Self>,
96        cx: &mut task::Context,
97        buf: &mut ReadBuf<'_>,
98    ) -> Poll<io::Result<()>> {
99        match &mut *self {
100            Tokio::Tcp(r) => Pin::new(r).poll_read(cx, buf),
101            #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
102            Tokio::TcpTls(r) => Pin::new(r).poll_read(cx, buf),
103            #[cfg(unix)]
104            Tokio::Unix(r) => Pin::new(r).poll_read(cx, buf),
105        }
106    }
107}
108
109impl RedisRuntime for Tokio {
110    async fn connect_tcp(
111        socket_addr: SocketAddr,
112        tcp_settings: &crate::io::tcp::TcpSettings,
113    ) -> RedisResult<Self> {
114        Ok(connect_tcp(&socket_addr, tcp_settings)
115            .await
116            .map(Tokio::Tcp)?)
117    }
118
119    #[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))]
120    async fn connect_tcp_tls(
121        hostname: &str,
122        socket_addr: SocketAddr,
123        insecure: bool,
124        params: &Option<TlsConnParams>,
125        tcp_settings: &crate::io::tcp::TcpSettings,
126    ) -> RedisResult<Self> {
127        let tls_connector: tokio_native_tls::TlsConnector = if insecure {
128            TlsConnector::builder()
129                .danger_accept_invalid_certs(true)
130                .danger_accept_invalid_hostnames(true)
131                .use_sni(false)
132                .build()?
133        } else if let Some(params) = params {
134            TlsConnector::builder()
135                .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames)
136                .build()?
137        } else {
138            TlsConnector::new()?
139        }
140        .into();
141        Ok(tls_connector
142            .connect(hostname, connect_tcp(&socket_addr, tcp_settings).await?)
143            .await
144            .map(|con| Tokio::TcpTls(Box::new(con)))?)
145    }
146
147    #[cfg(feature = "tokio-rustls-comp")]
148    async fn connect_tcp_tls(
149        hostname: &str,
150        socket_addr: SocketAddr,
151        insecure: bool,
152        tls_params: &Option<TlsConnParams>,
153        tcp_settings: &crate::io::tcp::TcpSettings,
154    ) -> RedisResult<Self> {
155        let config = create_rustls_config(insecure, tls_params.clone())?;
156        let tls_connector = TlsConnector::from(Arc::new(config));
157
158        Ok(tls_connector
159            .connect(
160                rustls::pki_types::ServerName::try_from(hostname)?.to_owned(),
161                connect_tcp(&socket_addr, tcp_settings).await?,
162            )
163            .await
164            .map(|con| Tokio::TcpTls(Box::new(con)))?)
165    }
166
167    #[cfg(unix)]
168    async fn connect_unix(path: &Path) -> RedisResult<Self> {
169        Ok(UnixStreamTokio::connect(path).await.map(Tokio::Unix)?)
170    }
171
172    #[cfg(feature = "tokio-comp")]
173    fn spawn(f: impl Future<Output = ()> + Send + 'static) -> TaskHandle {
174        TaskHandle::Tokio(tokio::spawn(f))
175    }
176
177    #[cfg(not(feature = "tokio-comp"))]
178    fn spawn(_: impl Future<Output = ()> + Send + 'static) -> TokioTaskHandle {
179        unreachable!()
180    }
181
182    fn boxed(self) -> Pin<Box<dyn AsyncStream + Send + Sync>> {
183        match self {
184            Tokio::Tcp(x) => Box::pin(x),
185            #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))]
186            Tokio::TcpTls(x) => Box::pin(x),
187            #[cfg(unix)]
188            Tokio::Unix(x) => Box::pin(x),
189        }
190    }
191}