1use crate::cmd::Cmd;
3use crate::connection::{
4 AuthResult, ConnectionSetupComponents, RedisConnectionInfo, check_connection_setup,
5 connection_setup_pipeline,
6};
7use crate::io::AsyncDNSResolver;
8use crate::types::{RedisFuture, RedisResult, Value};
9use crate::{ErrorKind, PushInfo, RedisError, errors::closed_connection_error};
10use ::tokio::io::{AsyncRead, AsyncWrite};
11use futures_util::{
12 future::{Future, FutureExt},
13 sink::{Sink, SinkExt},
14 stream::{Stream, StreamExt},
15};
16pub use monitor::Monitor;
17use std::net::SocketAddr;
18#[cfg(unix)]
19use std::path::Path;
20use std::pin::Pin;
21
22mod monitor;
23
24#[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))]
25use crate::connection::TlsConnParams;
26
27#[cfg(feature = "smol-comp")]
29#[cfg_attr(docsrs, doc(cfg(feature = "smol-comp")))]
30pub mod smol;
31#[cfg(feature = "tokio-comp")]
33#[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
34pub mod tokio;
35
36mod pubsub;
37pub use pubsub::{PubSub, PubSubSink, PubSubStream};
38
/// Abstraction over the async runtime (tokio or smol) that the async
/// connections are built on.
///
/// An implementor is itself the connected stream type: the trait requires
/// `AsyncStream`, and every `connect_*` constructor returns `Self`.
pub(crate) trait RedisRuntime: AsyncStream + Send + Sync + Sized + 'static {
    /// Expected to establish a plain TCP connection to `socket_addr`,
    /// configured by `tcp_settings`.
    async fn connect_tcp(
        socket_addr: SocketAddr,
        tcp_settings: &crate::io::tcp::TcpSettings,
    ) -> RedisResult<Self>;

    /// Expected to establish a TLS-over-TCP connection to `socket_addr`.
    ///
    /// NOTE(review): `hostname` is presumably used for certificate
    /// verification / SNI and `insecure` presumably relaxes verification —
    /// confirm against the runtime implementations.
    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
    async fn connect_tcp_tls(
        hostname: &str,
        socket_addr: SocketAddr,
        insecure: bool,
        tls_params: &Option<TlsConnParams>,
        tcp_settings: &crate::io::tcp::TcpSettings,
    ) -> RedisResult<Self>;

    /// Expected to connect to the Unix domain socket at `path`
    /// (only available on Unix targets).
    #[cfg(unix)]
    async fn connect_unix(path: &Path) -> RedisResult<Self>;

    /// Presumably spawns `f` on this runtime's executor; returns a
    /// `TaskHandle` for the spawned task.
    fn spawn(f: impl Future<Output = ()> + Send + 'static) -> TaskHandle;

    /// Type-erases the stream into a pinned, boxed `AsyncStream` trait object.
    fn boxed(self) -> Pin<Box<dyn AsyncStream + Send + Sync>> {
        Box::pin(self)
    }
}
67
/// A byte stream that is both `AsyncRead` and `AsyncWrite` (the tokio I/O
/// traits imported by this module).
pub trait AsyncStream: AsyncRead + AsyncWrite {}

// Blanket impl: any `AsyncRead + AsyncWrite` type is an `AsyncStream`, so the
// trait only gives the combined bound a name usable in `dyn` positions.
impl<S> AsyncStream for S where S: AsyncRead + AsyncWrite {}
71
/// An async abstraction over a Redis connection that can execute commands.
pub trait ConnectionLike {
    /// Sends `cmd` over the connection and resolves to its single reply.
    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value>;

    #[doc(hidden)]
    /// Sends the commands of the pipeline `cmd` and resolves to a vector of
    /// replies.
    ///
    /// NOTE(review): presumably the first `offset` replies are skipped and the
    /// following `count` replies are returned — confirm against implementations.
    fn req_packed_commands<'a>(
        &'a mut self,
        cmd: &'a crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisFuture<'a, Vec<Value>>;

    /// Returns the database index associated with this connection.
    fn get_db(&self) -> i64;
}
99
100async fn execute_connection_pipeline<T>(
101 codec: &mut T,
102 (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
103) -> RedisResult<AuthResult>
104where
105 T: Sink<Vec<u8>, Error = RedisError>,
106 T: Stream<Item = RedisResult<Value>>,
107 T: Unpin + Send + 'static,
108{
109 let count = pipeline.len();
110 if count == 0 {
111 return Ok(AuthResult::Succeeded);
112 }
113 codec.send(pipeline.get_packed_pipeline()).await?;
114
115 let mut results = Vec::with_capacity(count);
116 for _ in 0..count {
117 let value = codec.next().await.ok_or_else(closed_connection_error)??;
118 results.push(value);
119 }
120
121 check_connection_setup(results, instructions)
122}
123
124pub(super) async fn setup_connection<T>(
125 codec: &mut T,
126 connection_info: &RedisConnectionInfo,
127 #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
128) -> RedisResult<()>
129where
130 T: Sink<Vec<u8>, Error = RedisError>,
131 T: Stream<Item = RedisResult<Value>>,
132 T: Unpin + Send + 'static,
133{
134 if execute_connection_pipeline(
135 codec,
136 connection_setup_pipeline(
137 connection_info,
138 true,
139 #[cfg(feature = "cache-aio")]
140 cache_config,
141 ),
142 )
143 .await?
144 == AuthResult::ShouldRetryWithoutUsername
145 {
146 execute_connection_pipeline(
147 codec,
148 connection_setup_pipeline(
149 connection_info,
150 false,
151 #[cfg(feature = "cache-aio")]
152 cache_config,
153 ),
154 )
155 .await?;
156 }
157
158 Ok(())
159}
160
161mod connection;
162pub(crate) use connection::connect_simple;
163pub use connection::transaction_async;
164mod multiplexed_connection;
165pub use multiplexed_connection::*;
166#[cfg(feature = "connection-manager")]
167mod connection_manager;
168#[cfg(feature = "connection-manager")]
169#[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))]
170pub use connection_manager::*;
171mod runtime;
172#[cfg(all(feature = "smol-comp", feature = "tokio-comp"))]
173pub use runtime::prefer_smol;
174#[cfg(all(feature = "tokio-comp", feature = "smol-comp"))]
175pub use runtime::prefer_tokio;
176pub(super) use runtime::*;
177
/// Error returned by [`AsyncPushSender::send`] when a push message could not
/// be delivered.
///
/// The `AsyncPushSender` impls in this module collapse any failure of the
/// underlying channel or callback into this unit value. The original error
/// (and any rejected message it carried) is discarded.
pub struct SendError;
180
/// A destination for `PushInfo` messages.
///
/// NOTE(review): presumably these are server push messages (e.g. RESP3 push
/// frames) forwarded by the connection — confirm at the call sites.
pub trait AsyncPushSender: Send + Sync + 'static {
    /// Delivers `info`; returns [`SendError`] if it could not be delivered.
    fn send(&self, info: PushInfo) -> Result<(), SendError>;
}
187
188impl AsyncPushSender for ::tokio::sync::mpsc::UnboundedSender<PushInfo> {
189 fn send(&self, info: PushInfo) -> Result<(), SendError> {
190 match self.send(info) {
191 Ok(_) => Ok(()),
192 Err(_) => Err(SendError),
193 }
194 }
195}
196
197impl AsyncPushSender for ::tokio::sync::broadcast::Sender<PushInfo> {
198 fn send(&self, info: PushInfo) -> Result<(), SendError> {
199 match self.send(info) {
200 Ok(_) => Ok(()),
201 Err(_) => Err(SendError),
202 }
203 }
204}
205
206impl<T, Func: Fn(PushInfo) -> Result<(), T> + Send + Sync + 'static> AsyncPushSender for Func {
207 fn send(&self, info: PushInfo) -> Result<(), SendError> {
208 match self(info) {
209 Ok(_) => Ok(()),
210 Err(_) => Err(SendError),
211 }
212 }
213}
214
215impl AsyncPushSender for std::sync::mpsc::Sender<PushInfo> {
216 fn send(&self, info: PushInfo) -> Result<(), SendError> {
217 match self.send(info) {
218 Ok(_) => Ok(()),
219 Err(_) => Err(SendError),
220 }
221 }
222}
223
224impl<T> AsyncPushSender for std::sync::Arc<T>
225where
226 T: AsyncPushSender,
227{
228 fn send(&self, info: PushInfo) -> Result<(), SendError> {
229 self.as_ref().send(info)
230 }
231}
232
/// The built-in [`AsyncDNSResolver`]. It resolves host names with
/// `get_socket_addrs`, which dispatches on the runtime reported by
/// `Runtime::locate()`.
#[derive(Clone)]
pub(crate) struct DefaultAsyncDNSResolver;
236
237impl AsyncDNSResolver for DefaultAsyncDNSResolver {
238 fn resolve<'a, 'b: 'a>(
239 &'a self,
240 host: &'b str,
241 port: u16,
242 ) -> RedisFuture<'a, Box<dyn Iterator<Item = SocketAddr> + Send + 'a>> {
243 Box::pin(get_socket_addrs(host, port).map(|vec| {
244 Ok(Box::new(vec?.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
245 }))
246 }
247}
248
249async fn get_socket_addrs(host: &str, port: u16) -> RedisResult<Vec<SocketAddr>> {
250 let socket_addrs: Vec<_> = match Runtime::locate() {
251 #[cfg(feature = "tokio-comp")]
252 Runtime::Tokio => ::tokio::net::lookup_host((host, port))
253 .await
254 .map_err(RedisError::from)
255 .map(|iter| iter.collect()),
256
257 #[cfg(feature = "smol-comp")]
258 Runtime::Smol => ::smol::net::resolve((host, port))
259 .await
260 .map_err(RedisError::from),
261 }?;
262
263 if socket_addrs.is_empty() {
264 Err(RedisError::from((
265 ErrorKind::InvalidClientConfig,
266 "No address found for host",
267 )))
268 } else {
269 Ok(socket_addrs)
270 }
271}