redis/
connection.rs

1use std::borrow::Cow;
2use std::collections::VecDeque;
3use std::fmt;
4use std::io::{self, Write};
5use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
6use std::ops::DerefMut;
7use std::path::PathBuf;
8use std::str::{from_utf8, FromStr};
9use std::time::{Duration, Instant};
10
11use crate::cmd::{cmd, pipe, Cmd};
12use crate::io::tcp::{stream_with_settings, TcpSettings};
13use crate::parser::Parser;
14use crate::pipeline::Pipeline;
15use crate::types::{
16    from_redis_value, ErrorKind, FromRedisValue, HashMap, PushKind, RedisError, RedisResult,
17    ServerError, ServerErrorKind, SyncPushSender, ToRedisArgs, Value,
18};
19use crate::{from_owned_redis_value, ProtocolVersion};
20
21#[cfg(unix)]
22use std::os::unix::net::UnixStream;
23
24use crate::commands::resp3_hello;
25#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
26use native_tls::{TlsConnector, TlsStream};
27
28#[cfg(feature = "tls-rustls")]
29use rustls::{RootCertStore, StreamOwned};
30#[cfg(feature = "tls-rustls")]
31use std::sync::Arc;
32
33use crate::PushInfo;
34
35#[cfg(all(
36    feature = "tls-rustls",
37    not(feature = "tls-native-tls"),
38    not(feature = "tls-rustls-webpki-roots")
39))]
40use rustls_native_certs::load_native_certs;
41
42#[cfg(feature = "tls-rustls")]
43use crate::tls::ClientTlsParams;
44
/// Extra TLS configuration attached to a [`ConnectionAddr::TcpTls`] address.
///
/// All fields are crate-private and the type is `#[non_exhaustive]`, so values can only be
/// built inside this crate.
#[derive(Clone, Debug)]
// Non-exhaustive to prevent construction outside this crate
#[non_exhaustive]
pub struct TlsConnParams {
    /// rustls client-side TLS parameters, if any were supplied.
    // NOTE(review): presumably the client certificate/key for mutual TLS (see `crate::tls`) — confirm.
    #[cfg(feature = "tls-rustls")]
    pub(crate) client_tls_params: Option<ClientTlsParams>,
    /// Root certificate store for rustls, if one was supplied.
    // NOTE(review): presumably `None` means "use the default roots" — confirm in `create_rustls_config`.
    #[cfg(feature = "tls-rustls")]
    pub(crate) root_cert_store: Option<RootCertStore>,
    /// Accept server certificates whose hostname does not match the host being contacted.
    ///
    /// The native-tls connector passes this to `TlsConnector::builder().danger_accept_invalid_hostnames`.
    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
    pub(crate) danger_accept_invalid_hostnames: bool,
}
56
/// Port used when a redis URL does not specify one (the standard redis port).
///
/// A `const` rather than a `static`: the value needs neither a fixed address nor interior
/// mutability, so it can be inlined at each use and used in const contexts.
const DEFAULT_PORT: u16 = 6379;
58
59#[inline(always)]
60fn connect_tcp(addr: (&str, u16)) -> io::Result<TcpStream> {
61    let socket = TcpStream::connect(addr)?;
62    stream_with_settings(socket, &TcpSettings::default())
63}
64
65#[inline(always)]
66fn connect_tcp_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
67    let socket = TcpStream::connect_timeout(addr, timeout)?;
68    stream_with_settings(socket, &TcpSettings::default())
69}
70
71/// This function takes a redis URL string and parses it into a URL
72/// as used by rust-url.
73///
74/// This is necessary as the default parser does not understand how redis URLs function.
75pub fn parse_redis_url(input: &str) -> Option<url::Url> {
76    match url::Url::parse(input) {
77        Ok(result) => match result.scheme() {
78            "redis" | "rediss" | "valkey" | "valkeys" | "redis+unix" | "valkey+unix" | "unix" => {
79                Some(result)
80            }
81            _ => None,
82        },
83        Err(_) => None,
84    }
85}
86
/// Whether a TLS connection verifies the server's certificate.
///
/// See the `insecure` field of [`ConnectionAddr::TcpTls`] for what insecure mode disables.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TlsMode {
    /// Verify the server certificate.
    Secure,
    /// Do not verify the server certificate.
    Insecure,
}
97
/// Defines the connection address.
///
/// Not all connection addresses are supported on all platforms or builds. For instance,
/// connecting to a unix socket requires an operating system that supports them, and TLS
/// requires a TLS feature (see [`ConnectionAddr::is_supported`]).
#[derive(Clone, Debug)]
pub enum ConnectionAddr {
    /// Plain TCP connection. Format for this is `(host, port)`.
    Tcp(String, u16),
    /// TCP connection secured with TLS.
    TcpTls {
        /// Hostname
        host: String,
        /// Port
        port: u16,
        /// Skip TLS certificate verification when connecting.
        ///
        /// Set by the `#insecure` fragment of a `rediss://` / `valkeys://` URL. With the
        /// `tls-native-tls` backend this accepts invalid certificates and invalid hostnames,
        /// and disables SNI.
        ///
        /// # Warning
        ///
        /// You should think very carefully before you use this option. With verification
        /// disabled, a man-in-the-middle can present any certificate and it will be
        /// accepted, which exposes the connection to interception.
        // NOTE(review): with rustls the effect is decided by `create_rustls_config`, which is
        // not in this section — confirm it matches the description above.
        insecure: bool,

        /// TLS certificates and client key.
        tls_params: Option<TlsConnParams>,
    },
    /// Format for this is the path to the unix socket.
    Unix(PathBuf),
}
129
130impl PartialEq for ConnectionAddr {
131    fn eq(&self, other: &Self) -> bool {
132        match (self, other) {
133            (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => {
134                host1 == host2 && port1 == port2
135            }
136            (
137                ConnectionAddr::TcpTls {
138                    host: host1,
139                    port: port1,
140                    insecure: insecure1,
141                    tls_params: _,
142                },
143                ConnectionAddr::TcpTls {
144                    host: host2,
145                    port: port2,
146                    insecure: insecure2,
147                    tls_params: _,
148                },
149            ) => port1 == port2 && host1 == host2 && insecure1 == insecure2,
150            (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2,
151            _ => false,
152        }
153    }
154}
155
156impl Eq for ConnectionAddr {}
157
158impl ConnectionAddr {
159    /// Checks if this address is supported.
160    ///
161    /// Because not all platforms support all connection addresses this is a
162    /// quick way to figure out if a connection method is supported. Currently
163    /// this affects:
164    ///
165    /// - Unix socket addresses, which are supported only on Unix
166    ///
167    /// - TLS addresses, which are supported only if a TLS feature is enabled
168    ///   (either `tls-native-tls` or `tls-rustls`).
169    pub fn is_supported(&self) -> bool {
170        match *self {
171            ConnectionAddr::Tcp(_, _) => true,
172            ConnectionAddr::TcpTls { .. } => {
173                cfg!(any(feature = "tls-native-tls", feature = "tls-rustls"))
174            }
175            ConnectionAddr::Unix(_) => cfg!(unix),
176        }
177    }
178
179    /// Configure this address to connect without checking certificate hostnames.
180    ///
181    /// # Warning
182    ///
183    /// You should think very carefully before you use this method. If hostname
184    /// verification is not used, any valid certificate for any site will be
185    /// trusted for use from any other. This introduces a significant
186    /// vulnerability to man-in-the-middle attacks.
187    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
188    pub fn set_danger_accept_invalid_hostnames(&mut self, insecure: bool) {
189        if let ConnectionAddr::TcpTls { tls_params, .. } = self {
190            if let Some(ref mut params) = tls_params {
191                params.danger_accept_invalid_hostnames = insecure;
192            } else if insecure {
193                *tls_params = Some(TlsConnParams {
194                    #[cfg(feature = "tls-rustls")]
195                    client_tls_params: None,
196                    #[cfg(feature = "tls-rustls")]
197                    root_cert_store: None,
198                    danger_accept_invalid_hostnames: insecure,
199                });
200            }
201        }
202    }
203
204    #[cfg(feature = "cluster")]
205    pub(crate) fn tls_mode(&self) -> Option<TlsMode> {
206        match self {
207            ConnectionAddr::TcpTls { insecure, .. } => {
208                if *insecure {
209                    Some(TlsMode::Insecure)
210                } else {
211                    Some(TlsMode::Secure)
212                }
213            }
214            _ => None,
215        }
216    }
217}
218
219impl fmt::Display for ConnectionAddr {
220    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
221        // Cluster::get_connection_info depends on the return value from this function
222        match *self {
223            ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
224            ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
225            ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()),
226        }
227    }
228}
229
/// Everything needed to open a redis connection: where to connect, and the redis-level
/// settings for the handshake once connected.
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
    /// Transport address to connect to: TCP, TCP with TLS, or a Unix socket.
    pub addr: ConnectionAddr,

    /// Redis-level settings (database, credentials, protocol version) used to handshake
    /// with the server.
    pub redis: RedisConnectionInfo,
}
239
/// Redis-specific, transport-independent information used to establish a connection to redis.
#[derive(Clone, Debug, Default)]
pub struct RedisConnectionInfo {
    /// The database number to use. This is usually `0`; the URL parsers use `0` when the
    /// URL names no database.
    pub db: i64,
    /// Optional username for authentication.
    ///
    /// TCP URLs take it (percent-decoded) from the user-info part. Unix-socket URLs take it
    /// from the `user` query parameter.
    pub username: Option<String>,
    /// Optional password for authentication.
    ///
    /// TCP URLs take it (percent-decoded) from the user-info part. Unix-socket URLs take it
    /// from the `pass` query parameter.
    pub password: Option<String>,
    /// RESP protocol version to use.
    ///
    /// The URL parsers read it from the `protocol` query parameter (`2`/`resp2` or
    /// `3`/`resp3`) and fall back to RESP2 when it is absent.
    pub protocol: ProtocolVersion,
}
252
253impl FromStr for ConnectionInfo {
254    type Err = RedisError;
255
256    fn from_str(s: &str) -> Result<Self, Self::Err> {
257        s.into_connection_info()
258    }
259}
260
/// Converts an object into a [`ConnectionInfo`].
///
/// This lets the client constructor accept connection information in several formats.
/// Implementations in this module cover:
///
/// - an existing [`ConnectionInfo`] (returned unchanged);
/// - redis URLs given as `&str`, `String` or `url::Url`;
/// - `(host, port)` tuples, which produce a plain TCP address with default settings.
pub trait IntoConnectionInfo {
    /// Converts the object into a connection info object.
    ///
    /// # Errors
    ///
    /// The URL-based implementations in this module return an `InvalidClientConfig` error
    /// when the URL cannot be parsed or has an unsupported scheme or component.
    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
}
268
269impl IntoConnectionInfo for ConnectionInfo {
270    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
271        Ok(self)
272    }
273}
274
275/// URL format: `{redis|rediss|valkey|valkeys}://[<username>][:<password>@]<hostname>[:port][/<db>]`
276///
277/// - Basic: `redis://127.0.0.1:6379`
278/// - Username & Password: `redis://user:password@127.0.0.1:6379`
279/// - Password only: `redis://:password@127.0.0.1:6379`
280/// - Specifying DB: `redis://127.0.0.1:6379/0`
281/// - Enabling TLS: `rediss://127.0.0.1:6379`
282/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
283/// - Enabling RESP3: `redis://127.0.0.1:6379/?protocol=resp3`
284impl IntoConnectionInfo for &str {
285    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
286        match parse_redis_url(self) {
287            Some(u) => u.into_connection_info(),
288            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
289        }
290    }
291}
292
293impl<T> IntoConnectionInfo for (T, u16)
294where
295    T: Into<String>,
296{
297    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
298        Ok(ConnectionInfo {
299            addr: ConnectionAddr::Tcp(self.0.into(), self.1),
300            redis: RedisConnectionInfo::default(),
301        })
302    }
303}
304
305/// URL format: `{redis|rediss|valkey|valkeys}://[<username>][:<password>@]<hostname>[:port][/<db>]`
306///
307/// - Basic: `redis://127.0.0.1:6379`
308/// - Username & Password: `redis://user:password@127.0.0.1:6379`
309/// - Password only: `redis://:password@127.0.0.1:6379`
310/// - Specifying DB: `redis://127.0.0.1:6379/0`
311/// - Enabling TLS: `rediss://127.0.0.1:6379`
312/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
313/// - Enabling RESP3: `redis://127.0.0.1:6379/?protocol=resp3`
314impl IntoConnectionInfo for String {
315    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
316        match parse_redis_url(&self) {
317            Some(u) => u.into_connection_info(),
318            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
319        }
320    }
321}
322
323fn parse_protocol(query: &HashMap<Cow<str>, Cow<str>>) -> RedisResult<ProtocolVersion> {
324    Ok(match query.get("protocol") {
325        Some(protocol) => {
326            if protocol == "2" || protocol == "resp2" {
327                ProtocolVersion::RESP2
328            } else if protocol == "3" || protocol == "resp3" {
329                ProtocolVersion::RESP3
330            } else {
331                fail!((
332                    ErrorKind::InvalidClientConfig,
333                    "Invalid protocol version",
334                    protocol.to_string()
335                ))
336            }
337        }
338        None => ProtocolVersion::RESP2,
339    })
340}
341
342fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
343    let host = match url.host() {
344        Some(host) => {
345            // Here we manually match host's enum arms and call their to_string().
346            // Because url.host().to_string() will add `[` and `]` for ipv6:
347            // https://docs.rs/url/latest/src/url/host.rs.html#170
348            // And these brackets will break host.parse::<Ipv6Addr>() when
349            // `client.open()` - `ActualConnection::new()` - `addr.to_socket_addrs()`:
350            // https://doc.rust-lang.org/src/std/net/addr.rs.html#963
351            // https://doc.rust-lang.org/src/std/net/parser.rs.html#158
352            // IpAddr string with brackets can ONLY parse to SocketAddrV6:
353            // https://doc.rust-lang.org/src/std/net/parser.rs.html#255
354            // But if we call Ipv6Addr.to_string directly, it follows rfc5952 without brackets:
355            // https://doc.rust-lang.org/src/std/net/ip.rs.html#1755
356            match host {
357                url::Host::Domain(path) => path.to_string(),
358                url::Host::Ipv4(v4) => v4.to_string(),
359                url::Host::Ipv6(v6) => v6.to_string(),
360            }
361        }
362        None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
363    };
364    let port = url.port().unwrap_or(DEFAULT_PORT);
365    let addr = if url.scheme() == "rediss" || url.scheme() == "valkeys" {
366        #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
367        {
368            match url.fragment() {
369                Some("insecure") => ConnectionAddr::TcpTls {
370                    host,
371                    port,
372                    insecure: true,
373                    tls_params: None,
374                },
375                Some(_) => fail!((
376                    ErrorKind::InvalidClientConfig,
377                    "only #insecure is supported as URL fragment"
378                )),
379                _ => ConnectionAddr::TcpTls {
380                    host,
381                    port,
382                    insecure: false,
383                    tls_params: None,
384                },
385            }
386        }
387
388        #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
389        fail!((
390            ErrorKind::InvalidClientConfig,
391            "can't connect with TLS, the feature is not enabled"
392        ));
393    } else {
394        ConnectionAddr::Tcp(host, port)
395    };
396    let query: HashMap<_, _> = url.query_pairs().collect();
397    Ok(ConnectionInfo {
398        addr,
399        redis: RedisConnectionInfo {
400            db: match url.path().trim_matches('/') {
401                "" => 0,
402                path => path.parse::<i64>().map_err(|_| -> RedisError {
403                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
404                })?,
405            },
406            username: if url.username().is_empty() {
407                None
408            } else {
409                match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
410                    Ok(decoded) => Some(decoded.into_owned()),
411                    Err(_) => fail!((
412                        ErrorKind::InvalidClientConfig,
413                        "Username is not valid UTF-8 string"
414                    )),
415                }
416            },
417            password: match url.password() {
418                Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
419                    Ok(decoded) => Some(decoded.into_owned()),
420                    Err(_) => fail!((
421                        ErrorKind::InvalidClientConfig,
422                        "Password is not valid UTF-8 string"
423                    )),
424                },
425                None => None,
426            },
427            protocol: parse_protocol(&query)?,
428        },
429    })
430}
431
432#[cfg(unix)]
433fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
434    let query: HashMap<_, _> = url.query_pairs().collect();
435    Ok(ConnectionInfo {
436        addr: ConnectionAddr::Unix(url.to_file_path().map_err(|_| -> RedisError {
437            (ErrorKind::InvalidClientConfig, "Missing path").into()
438        })?),
439        redis: RedisConnectionInfo {
440            db: match query.get("db") {
441                Some(db) => db.parse::<i64>().map_err(|_| -> RedisError {
442                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
443                })?,
444
445                None => 0,
446            },
447            username: query.get("user").map(|username| username.to_string()),
448            password: query.get("pass").map(|password| password.to_string()),
449            protocol: parse_protocol(&query)?,
450        },
451    })
452}
453
/// Unix-socket URLs cannot be honored on non-Unix targets; this always returns an
/// `InvalidClientConfig` error.
#[cfg(not(unix))]
fn url_to_unix_connection_info(_url: url::Url) -> RedisResult<ConnectionInfo> {
    fail!((
        ErrorKind::InvalidClientConfig,
        "Unix sockets are not available on this platform."
    ));
}
461
462impl IntoConnectionInfo for url::Url {
463    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
464        match self.scheme() {
465            "redis" | "rediss" | "valkey" | "valkeys" => url_to_tcp_connection_info(self),
466            "unix" | "redis+unix" | "valkey+unix" => url_to_unix_connection_info(self),
467            _ => fail!((
468                ErrorKind::InvalidClientConfig,
469                "URL provided is not a redis URL"
470            )),
471        }
472    }
473}
474
/// Plain TCP transport.
struct TcpConnection {
    /// The socket. Despite the name, `send_bytes` also writes through it.
    reader: TcpStream,
    /// Cleared by `send_bytes` when a write fails with an unrecoverable error.
    open: bool,
}

/// TLS-over-TCP transport using native-tls (only when `tls-rustls` is not enabled).
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
struct TcpNativeTlsConnection {
    /// The TLS stream. Despite the name, `send_bytes` also writes through it.
    reader: TlsStream<TcpStream>,
    /// Whether the connection is considered usable.
    // NOTE(review): presumably cleared on unrecoverable errors like `TcpConnection::open` — confirm in `send_bytes`.
    open: bool,
}

/// TLS-over-TCP transport using rustls.
#[cfg(feature = "tls-rustls")]
struct TcpRustlsConnection {
    /// The rustls client session paired with its TCP socket.
    reader: StreamOwned<rustls::ClientConnection, TcpStream>,
    /// Whether the connection is considered usable.
    // NOTE(review): presumably cleared on unrecoverable errors like `TcpConnection::open` — confirm in `send_bytes`.
    open: bool,
}

/// Unix domain socket transport.
#[cfg(unix)]
struct UnixConnection {
    /// The connected Unix socket.
    sock: UnixStream,
    /// Whether the connection is considered usable.
    // NOTE(review): presumably cleared on unrecoverable errors like `TcpConnection::open` — confirm in `send_bytes`.
    open: bool,
}
497
/// The transport underneath a [`Connection`].
///
/// The available variants depend on the build:
///
/// - `TcpRustls` exists with `tls-rustls`.
/// - `TcpNativeTls` exists only with `tls-native-tls` *and without* `tls-rustls`, so rustls
///   takes precedence when both features are enabled.
/// - `Unix` exists only on Unix targets.
// NOTE(review): the TLS variants are boxed, presumably to keep the enum small — confirm.
enum ActualConnection {
    Tcp(TcpConnection),
    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
    TcpNativeTls(Box<TcpNativeTlsConnection>),
    #[cfg(feature = "tls-rustls")]
    TcpRustls(Box<TcpRustlsConnection>),
    #[cfg(unix)]
    Unix(UnixConnection),
}
507
/// rustls certificate verifier that performs **no** verification at all.
///
/// Every server certificate and every TLS 1.2/1.3 handshake signature is accepted
/// unconditionally; only the advertised list of signature schemes is real.
// NOTE(review): presumably installed by `create_rustls_config` for `insecure` addresses — that
// wiring is not in this section; confirm.
#[cfg(feature = "tls-rustls-insecure")]
struct NoCertificateVerification {
    /// Algorithms whose signature schemes are reported by `supported_verify_schemes`.
    supported: rustls::crypto::WebPkiSupportedAlgorithms,
}

#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
    /// Accepts any certificate chain for any server name without inspecting it.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts the TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts the TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Reports the signature schemes of `self.supported`.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.supported.supported_schemes()
    }
}
548
/// Prints just the type name; the algorithm table is not interesting for debugging.
#[cfg(feature = "tls-rustls-insecure")]
impl fmt::Debug for NoCertificateVerification {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("NoCertificateVerification")
    }
}
555
/// Insecure `ServerCertVerifier` for rustls that implements `danger_accept_invalid_hostnames`.
///
/// Delegates to a WebPKI verifier, but accepts a server certificate when verification fails
/// with a hostname-mismatch error (see `is_hostname_error`).
#[cfg(feature = "tls-rustls-insecure")]
#[derive(Debug)]
struct AcceptInvalidHostnamesCertVerifier {
    /// The real verifier; its verdicts are honored except for hostname mismatches.
    inner: Arc<rustls::client::WebPkiServerVerifier>,
}
562
/// Returns `true` if `err` reports that a certificate is not valid for the requested name,
/// i.e. `NotValidForName` or `NotValidForNameContext`.
#[cfg(feature = "tls-rustls-insecure")]
fn is_hostname_error(err: &rustls::Error) -> bool {
    if let rustls::Error::InvalidCertificate(cert_err) = err {
        matches!(
            cert_err,
            rustls::CertificateError::NotValidForName
                | rustls::CertificateError::NotValidForNameContext { .. }
        )
    } else {
        false
    }
}
573
#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for AcceptInvalidHostnamesCertVerifier {
    /// Runs the wrapped WebPKI verification, but accepts the certificate when the reported
    /// failure is a hostname mismatch.
    fn verify_server_cert(
        &self,
        end_entity: &rustls::pki_types::CertificateDer<'_>,
        intermediates: &[rustls::pki_types::CertificateDer<'_>],
        server_name: &rustls::pki_types::ServerName<'_>,
        ocsp_response: &[u8],
        now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        // NOTE(review): waiving the name error is only sound if the inner verifier checks the
        // chain (trust anchor, validity period, revocation) *before* matching the name, so a
        // name error implies everything else passed. That is believed to be the order in
        // rustls' WebPkiServerVerifier — re-check on rustls upgrades.
        self.inner
            .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)
            .or_else(|err| {
                if is_hostname_error(&err) {
                    Ok(rustls::client::danger::ServerCertVerified::assertion())
                } else {
                    Err(err)
                }
            })
    }

    /// Delegated unchanged. Handshake-signature checks involve no hostname, so a failure here
    /// is never a hostname problem and must not be waived.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        self.inner.verify_tls12_signature(message, cert, dss)
    }

    /// Delegated unchanged, for the same reason as `verify_tls12_signature`.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        self.inner.verify_tls13_signature(message, cert, dss)
    }

    /// Advertises exactly the schemes the wrapped verifier supports.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.inner.supported_verify_schemes()
    }
}
633
/// Represents a stateful redis connection.
///
/// The transport may be plain TCP, TCP with TLS, or a Unix socket (see [`ConnectionAddr`]).
pub struct Connection {
    /// The underlying transport.
    con: ActualConnection,
    /// Reply parser (`crate::parser::Parser`) owned by this connection.
    parser: Parser,
    /// Database index associated with this connection.
    // NOTE(review): presumably the database selected during connection setup — confirm.
    db: i64,

    /// Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`.
    ///
    /// This flag is checked when attempting to send a command, and if it's raised, we attempt to
    /// exit the pubsub state before executing the new request.
    pubsub: bool,

    /// Field indicating which protocol to use for server communications.
    protocol: ProtocolVersion,

    /// This is used to manage Push messages in RESP3 mode.
    push_sender: Option<SyncPushSender>,

    /// The number of messages that are expected to be returned from the server,
    /// but the user no longer waits for - answers for requests that already returned a transient error.
    messages_to_skip: usize,
}
656
/// Represents a pubsub connection.
///
/// Holds an exclusive (`&mut`) borrow of a [`Connection`] for the lifetime `'a`, so the
/// connection cannot be used directly while this value exists.
pub struct PubSub<'a> {
    /// The borrowed connection.
    con: &'a mut Connection,
    /// Queue of pending [`Msg`]s.
    // NOTE(review): presumably messages received while waiting for another reply — confirm.
    waiting_messages: VecDeque<Msg>,
}
662
/// Represents a pubsub message.
#[derive(Debug, Clone)]
pub struct Msg {
    /// The message payload, as a raw redis value.
    payload: Value,
    /// The channel, as a raw redis value.
    channel: Value,
    /// The subscription pattern, if any.
    // NOTE(review): presumably `Some` only for messages delivered via a pattern subscription — confirm.
    pattern: Option<Value>,
}
670
671impl ActualConnection {
672    pub fn new(addr: &ConnectionAddr, timeout: Option<Duration>) -> RedisResult<ActualConnection> {
673        Ok(match *addr {
674            ConnectionAddr::Tcp(ref host, ref port) => {
675                let addr = (host.as_str(), *port);
676                let tcp = match timeout {
677                    None => connect_tcp(addr)?,
678                    Some(timeout) => {
679                        let mut tcp = None;
680                        let mut last_error = None;
681                        for addr in addr.to_socket_addrs()? {
682                            match connect_tcp_timeout(&addr, timeout) {
683                                Ok(l) => {
684                                    tcp = Some(l);
685                                    break;
686                                }
687                                Err(e) => {
688                                    last_error = Some(e);
689                                }
690                            };
691                        }
692                        match (tcp, last_error) {
693                            (Some(tcp), _) => tcp,
694                            (None, Some(e)) => {
695                                fail!(e);
696                            }
697                            (None, None) => {
698                                fail!((
699                                    ErrorKind::InvalidClientConfig,
700                                    "could not resolve to any addresses"
701                                ));
702                            }
703                        }
704                    }
705                };
706                ActualConnection::Tcp(TcpConnection {
707                    reader: tcp,
708                    open: true,
709                })
710            }
711            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
712            ConnectionAddr::TcpTls {
713                ref host,
714                port,
715                insecure,
716                ref tls_params,
717            } => {
718                let tls_connector = if insecure {
719                    TlsConnector::builder()
720                        .danger_accept_invalid_certs(true)
721                        .danger_accept_invalid_hostnames(true)
722                        .use_sni(false)
723                        .build()?
724                } else if let Some(params) = tls_params {
725                    TlsConnector::builder()
726                        .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames)
727                        .build()?
728                } else {
729                    TlsConnector::new()?
730                };
731                let addr = (host.as_str(), port);
732                let tls = match timeout {
733                    None => {
734                        let tcp = connect_tcp(addr)?;
735                        match tls_connector.connect(host, tcp) {
736                            Ok(res) => res,
737                            Err(e) => {
738                                fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string()));
739                            }
740                        }
741                    }
742                    Some(timeout) => {
743                        let mut tcp = None;
744                        let mut last_error = None;
745                        for addr in (host.as_str(), port).to_socket_addrs()? {
746                            match connect_tcp_timeout(&addr, timeout) {
747                                Ok(l) => {
748                                    tcp = Some(l);
749                                    break;
750                                }
751                                Err(e) => {
752                                    last_error = Some(e);
753                                }
754                            };
755                        }
756                        match (tcp, last_error) {
757                            (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(),
758                            (None, Some(e)) => {
759                                fail!(e);
760                            }
761                            (None, None) => {
762                                fail!((
763                                    ErrorKind::InvalidClientConfig,
764                                    "could not resolve to any addresses"
765                                ));
766                            }
767                        }
768                    }
769                };
770                ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection {
771                    reader: tls,
772                    open: true,
773                }))
774            }
775            #[cfg(feature = "tls-rustls")]
776            ConnectionAddr::TcpTls {
777                ref host,
778                port,
779                insecure,
780                ref tls_params,
781            } => {
782                let host: &str = host;
783                let config = create_rustls_config(insecure, tls_params.clone())?;
784                let conn = rustls::ClientConnection::new(
785                    Arc::new(config),
786                    rustls::pki_types::ServerName::try_from(host)?.to_owned(),
787                )?;
788                let reader = match timeout {
789                    None => {
790                        let tcp = connect_tcp((host, port))?;
791                        StreamOwned::new(conn, tcp)
792                    }
793                    Some(timeout) => {
794                        let mut tcp = None;
795                        let mut last_error = None;
796                        for addr in (host, port).to_socket_addrs()? {
797                            match connect_tcp_timeout(&addr, timeout) {
798                                Ok(l) => {
799                                    tcp = Some(l);
800                                    break;
801                                }
802                                Err(e) => {
803                                    last_error = Some(e);
804                                }
805                            };
806                        }
807                        match (tcp, last_error) {
808                            (Some(tcp), _) => StreamOwned::new(conn, tcp),
809                            (None, Some(e)) => {
810                                fail!(e);
811                            }
812                            (None, None) => {
813                                fail!((
814                                    ErrorKind::InvalidClientConfig,
815                                    "could not resolve to any addresses"
816                                ));
817                            }
818                        }
819                    }
820                };
821
822                ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true }))
823            }
824            #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
825            ConnectionAddr::TcpTls { .. } => {
826                fail!((
827                    ErrorKind::InvalidClientConfig,
828                    "Cannot connect to TCP with TLS without the tls feature"
829                ));
830            }
831            #[cfg(unix)]
832            ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection {
833                sock: UnixStream::connect(path)?,
834                open: true,
835            }),
836            #[cfg(not(unix))]
837            ConnectionAddr::Unix(ref _path) => {
838                fail!((
839                    ErrorKind::InvalidClientConfig,
840                    "Cannot connect to unix sockets \
841                     on this platform"
842                ));
843            }
844        })
845    }
846
847    pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
848        match *self {
849            ActualConnection::Tcp(ref mut connection) => {
850                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
851                match res {
852                    Err(e) => {
853                        if e.is_unrecoverable_error() {
854                            connection.open = false;
855                        }
856                        Err(e)
857                    }
858                    Ok(_) => Ok(Value::Okay),
859                }
860            }
861            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
862            ActualConnection::TcpNativeTls(ref mut connection) => {
863                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
864                match res {
865                    Err(e) => {
866                        if e.is_unrecoverable_error() {
867                            connection.open = false;
868                        }
869                        Err(e)
870                    }
871                    Ok(_) => Ok(Value::Okay),
872                }
873            }
874            #[cfg(feature = "tls-rustls")]
875            ActualConnection::TcpRustls(ref mut connection) => {
876                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
877                match res {
878                    Err(e) => {
879                        if e.is_unrecoverable_error() {
880                            connection.open = false;
881                        }
882                        Err(e)
883                    }
884                    Ok(_) => Ok(Value::Okay),
885                }
886            }
887            #[cfg(unix)]
888            ActualConnection::Unix(ref mut connection) => {
889                let result = connection.sock.write_all(bytes).map_err(RedisError::from);
890                match result {
891                    Err(e) => {
892                        if e.is_unrecoverable_error() {
893                            connection.open = false;
894                        }
895                        Err(e)
896                    }
897                    Ok(_) => Ok(Value::Okay),
898                }
899            }
900        }
901    }
902
903    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
904        match *self {
905            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
906                reader.set_write_timeout(dur)?;
907            }
908            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
909            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
910                let reader = &(boxed_tls_connection.reader);
911                reader.get_ref().set_write_timeout(dur)?;
912            }
913            #[cfg(feature = "tls-rustls")]
914            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
915                let reader = &(boxed_tls_connection.reader);
916                reader.get_ref().set_write_timeout(dur)?;
917            }
918            #[cfg(unix)]
919            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
920                sock.set_write_timeout(dur)?;
921            }
922        }
923        Ok(())
924    }
925
926    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
927        match *self {
928            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
929                reader.set_read_timeout(dur)?;
930            }
931            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
932            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
933                let reader = &(boxed_tls_connection.reader);
934                reader.get_ref().set_read_timeout(dur)?;
935            }
936            #[cfg(feature = "tls-rustls")]
937            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
938                let reader = &(boxed_tls_connection.reader);
939                reader.get_ref().set_read_timeout(dur)?;
940            }
941            #[cfg(unix)]
942            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
943                sock.set_read_timeout(dur)?;
944            }
945        }
946        Ok(())
947    }
948
949    pub fn is_open(&self) -> bool {
950        match *self {
951            ActualConnection::Tcp(TcpConnection { open, .. }) => open,
952            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
953            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open,
954            #[cfg(feature = "tls-rustls")]
955            ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open,
956            #[cfg(unix)]
957            ActualConnection::Unix(UnixConnection { open, .. }) => open,
958        }
959    }
960}
961
/// Builds the rustls `ClientConfig` used for `ConnectionAddr::TcpTls` connections.
///
/// Root certificates: the default store is filled from `webpki_roots` when the
/// `tls-rustls-webpki-roots` feature is enabled, or from the platform's native
/// certificates when neither `tls-native-tls` nor `tls-rustls-webpki-roots` is enabled.
/// If `tls_params` carries its own `root_cert_store`, that store is used instead.
///
/// Client authentication: when `tls_params` carries `client_tls_params`, its certificate
/// chain and key are configured via `with_client_auth_cert`; otherwise
/// `with_no_client_auth` is used.
///
/// Insecure modes (both need the `tls-rustls-insecure` feature; without it an
/// `InvalidClientConfig` error is returned):
/// - `insecure == true` sets `enable_sni = false` and installs the
///   `NoCertificateVerification` verifier.
/// - `insecure == false` with `tls_params.danger_accept_invalid_hostnames` installs
///   `AcceptInvalidHostnamesCertVerifier`, wrapping a WebPKI verifier built from the
///   same root store.
///
/// # Errors
///
/// Fails if loading native certificates reports any error, if a certificate cannot be
/// added to the root store, if the client certificate/key is rejected, if the WebPKI
/// verifier cannot be built, if an insecure mode is requested without the
/// `tls-rustls-insecure` feature, or if `insecure` is set and no default rustls
/// `CryptoProvider` is installed.
#[cfg(feature = "tls-rustls")]
pub(crate) fn create_rustls_config(
    insecure: bool,
    tls_params: Option<TlsConnParams>,
) -> RedisResult<rustls::ClientConfig> {
    // `mut` is unused when no enabled feature adds certificates to the store.
    #[allow(unused_mut)]
    let mut root_store = RootCertStore::empty();
    #[cfg(feature = "tls-rustls-webpki-roots")]
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    // NOTE(review): if `tls-native-tls` and `tls-rustls` are both enabled without
    // `tls-rustls-webpki-roots`, neither source populates `root_store`, so the default
    // store is empty unless `tls_params` supplies one — confirm this is intended.
    #[cfg(all(
        feature = "tls-rustls",
        not(feature = "tls-native-tls"),
        not(feature = "tls-rustls-webpki-roots")
    ))]
    {
        // Any error reported while loading native certificates aborts (the last one
        // reported is returned), even if some certificates were loaded successfully.
        let mut certificate_result = load_native_certs();
        if let Some(error) = certificate_result.errors.pop() {
            return Err(error.into());
        }
        for cert in certificate_result.certs {
            root_store.add(cert)?;
        }
    }

    let config = rustls::ClientConfig::builder();
    let config = if let Some(tls_params) = tls_params {
        // A caller-supplied root store replaces the default one built above.
        let root_cert_store = tls_params.root_cert_store.unwrap_or(root_store);
        // Cloned because, with `tls-rustls-insecure`, `root_cert_store` is also used
        // below to build the WebPKI verifier.
        let config_builder = config.with_root_certificates(root_cert_store.clone());

        let config_builder = if let Some(ClientTlsParams {
            client_cert_chain: client_cert,
            client_key,
        }) = tls_params.client_tls_params
        {
            config_builder
                .with_client_auth_cert(client_cert, client_key)
                .map_err(|err| {
                    RedisError::from((
                        ErrorKind::InvalidClientConfig,
                        "Unable to build client with TLS parameters provided.",
                        err.to_string(),
                    ))
                })?
        } else {
            config_builder.with_no_client_auth()
        };

        // Implement `danger_accept_invalid_hostnames`.
        //
        // The strange cfg here is to handle a specific unusual combination of features: if
        // `tls-native-tls` and `tls-rustls` are enabled, but `tls-rustls-insecure` is not, and the
        // application tries to use the danger flag.
        #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
        let config_builder = if !insecure && tls_params.danger_accept_invalid_hostnames {
            #[cfg(not(feature = "tls-rustls-insecure"))]
            {
                // This code should not enable an insecure mode if the `insecure` feature is not
                // set, but it shouldn't silently ignore the flag either. So return an error.
                fail!((
                    ErrorKind::InvalidClientConfig,
                    "Cannot create insecure client via danger_accept_invalid_hostnames without tls-rustls-insecure feature"
                ));
            }

            #[cfg(feature = "tls-rustls-insecure")]
            {
                // Wrap a regular WebPKI verifier (built from the same roots) in the
                // hostname-lenient verifier.
                let mut config = config_builder;
                config.dangerous().set_certificate_verifier(Arc::new(
                    AcceptInvalidHostnamesCertVerifier {
                        inner: rustls::client::WebPkiServerVerifier::builder(Arc::new(
                            root_cert_store,
                        ))
                        .build()
                        .map_err(|err| rustls::Error::from(rustls::OtherError(Arc::new(err))))?,
                    },
                ));
                config
            }
        } else {
            config_builder
        };

        config_builder
    } else {
        config
            .with_root_certificates(root_store)
            .with_no_client_auth()
    };

    // `insecure` is honoured only when built with `tls-rustls-insecure`; otherwise it is
    // rejected with an error rather than silently ignored.
    match (insecure, cfg!(feature = "tls-rustls-insecure")) {
        #[cfg(feature = "tls-rustls-insecure")]
        (true, true) => {
            // Disable SNI and replace certificate verification with
            // `NoCertificateVerification`, which advertises the default provider's
            // signature verification algorithms.
            let mut config = config;
            config.enable_sni = false;
            let Some(crypto_provider) = rustls::crypto::CryptoProvider::get_default() else {
                return Err(RedisError::from((
                    ErrorKind::InvalidClientConfig,
                    "No crypto provider available for rustls",
                )));
            };
            config
                .dangerous()
                .set_certificate_verifier(Arc::new(NoCertificateVerification {
                    supported: crypto_provider.signature_verification_algorithms,
                }));

            Ok(config)
        }
        (true, false) => {
            fail!((
                ErrorKind::InvalidClientConfig,
                "Cannot create insecure client without tls-rustls-insecure feature"
            ));
        }
        _ => Ok(config),
    }
}
1079
1080fn authenticate_cmd(
1081    connection_info: &RedisConnectionInfo,
1082    check_username: bool,
1083    password: &str,
1084) -> Cmd {
1085    let mut command = cmd("AUTH");
1086    if check_username {
1087        if let Some(username) = &connection_info.username {
1088            command.arg(username);
1089        }
1090    }
1091    command.arg(password);
1092    command
1093}
1094
1095pub fn connect(
1096    connection_info: &ConnectionInfo,
1097    timeout: Option<Duration>,
1098) -> RedisResult<Connection> {
1099    let start = Instant::now();
1100    let con: ActualConnection = ActualConnection::new(&connection_info.addr, timeout)?;
1101
1102    // we temporarily set the timeout, and will remove it after finishing setup.
1103    let remaining_timeout = timeout.and_then(|timeout| timeout.checked_sub(start.elapsed()));
1104    // TLS could run logic that doesn't contain a timeout, and should fail if it takes too long.
1105    if timeout.is_some() && remaining_timeout.is_none() {
1106        return Err(RedisError::from(std::io::Error::new(
1107            std::io::ErrorKind::TimedOut,
1108            "Connection timed out",
1109        )));
1110    }
1111    con.set_read_timeout(remaining_timeout)?;
1112    con.set_write_timeout(remaining_timeout)?;
1113
1114    let con = setup_connection(
1115        con,
1116        &connection_info.redis,
1117        #[cfg(feature = "cache-aio")]
1118        None,
1119    )?;
1120
1121    // remove the temporary timeout.
1122    con.set_read_timeout(None)?;
1123    con.set_write_timeout(None)?;
1124
1125    Ok(con)
1126}
1127
/// Positions, within the pipeline built by `connection_setup_pipeline`, of the setup
/// commands whose replies `check_connection_setup` validates. `None` means the
/// corresponding command was not added to the pipeline.
pub(crate) struct ConnectionSetupComponents {
    /// Index of the `HELLO` command (added when the protocol is not RESP2).
    resp3_auth_cmd_idx: Option<usize>,
    /// Index of the RESP2 `AUTH` command (added when a password is configured).
    resp2_auth_cmd_idx: Option<usize>,
    /// Index of the `SELECT` command (added when `db != 0`).
    select_cmd_idx: Option<usize>,
    /// Index of the `CLIENT TRACKING ON` command (added when a cache config is given).
    #[cfg(feature = "cache-aio")]
    cache_cmd_idx: Option<usize>,
}
1135
1136pub(crate) fn connection_setup_pipeline(
1137    connection_info: &RedisConnectionInfo,
1138    check_username: bool,
1139    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1140) -> (crate::Pipeline, ConnectionSetupComponents) {
1141    let mut pipeline = pipe();
1142    let (authenticate_with_resp3_cmd_index, authenticate_with_resp2_cmd_index) =
1143        if connection_info.protocol != ProtocolVersion::RESP2 {
1144            pipeline.add_command(resp3_hello(connection_info));
1145            (Some(0), None)
1146        } else if connection_info.password.is_some() {
1147            pipeline.add_command(authenticate_cmd(
1148                connection_info,
1149                check_username,
1150                connection_info.password.as_ref().unwrap(),
1151            ));
1152            (None, Some(0))
1153        } else {
1154            (None, None)
1155        };
1156
1157    let select_db_cmd_index = (connection_info.db != 0)
1158        .then(|| pipeline.len())
1159        .inspect(|_| {
1160            pipeline.cmd("SELECT").arg(connection_info.db);
1161        });
1162
1163    #[cfg(feature = "cache-aio")]
1164    let cache_cmd_index = cache_config.map(|cache_config| {
1165        pipeline.cmd("CLIENT").arg("TRACKING").arg("ON");
1166        match cache_config.mode {
1167            crate::caching::CacheMode::All => {}
1168            crate::caching::CacheMode::OptIn => {
1169                pipeline.arg("OPTIN");
1170            }
1171        }
1172        pipeline.len() - 1
1173    });
1174
1175    // result is ignored, as per the command's instructions.
1176    // https://redis.io/commands/client-setinfo/
1177    #[cfg(not(feature = "disable-client-setinfo"))]
1178    pipeline
1179        .cmd("CLIENT")
1180        .arg("SETINFO")
1181        .arg("LIB-NAME")
1182        .arg("redis-rs")
1183        .ignore();
1184    #[cfg(not(feature = "disable-client-setinfo"))]
1185    pipeline
1186        .cmd("CLIENT")
1187        .arg("SETINFO")
1188        .arg("LIB-VER")
1189        .arg(env!("CARGO_PKG_VERSION"))
1190        .ignore();
1191
1192    (
1193        pipeline,
1194        ConnectionSetupComponents {
1195            resp3_auth_cmd_idx: authenticate_with_resp3_cmd_index,
1196            resp2_auth_cmd_idx: authenticate_with_resp2_cmd_index,
1197            select_cmd_idx: select_db_cmd_index,
1198            #[cfg(feature = "cache-aio")]
1199            cache_cmd_idx: cache_cmd_index,
1200        },
1201    )
1202}
1203
1204fn check_resp3_auth(result: &Value) -> RedisResult<()> {
1205    if let Value::ServerError(err) = result {
1206        return Err(get_resp3_hello_command_error(err.clone().into()));
1207    }
1208    Ok(())
1209}
1210
/// Outcome of the authentication step of the connection-setup pipeline.
///
/// Derives `Debug` (usable in `assert_eq!` and diagnostics) and `Copy`/`Eq`, since it
/// is a fieldless enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum AuthResult {
    /// Authentication succeeded, or none was required.
    Succeeded,
    /// The server answered RESP2 `AUTH` with a "wrong number of arguments" error;
    /// setup should be retried without sending the username.
    ShouldRetryWithoutUsername,
}
1216
1217fn check_resp2_auth(result: &Value) -> RedisResult<AuthResult> {
1218    let err = match result {
1219        Value::Okay => {
1220            return Ok(AuthResult::Succeeded);
1221        }
1222        Value::ServerError(err) => err,
1223        _ => {
1224            return Err((
1225                ErrorKind::ResponseError,
1226                "Redis server refused to authenticate, returns Ok() != Value::Okay",
1227            )
1228                .into());
1229        }
1230    };
1231
1232    let err_msg = err.details().ok_or((
1233        ErrorKind::AuthenticationFailed,
1234        "Password authentication failed",
1235    ))?;
1236    if !err_msg.contains("wrong number of arguments for 'auth' command") {
1237        return Err((
1238            ErrorKind::AuthenticationFailed,
1239            "Password authentication failed",
1240        )
1241            .into());
1242    }
1243    Ok(AuthResult::ShouldRetryWithoutUsername)
1244}
1245
1246fn check_db_select(value: &Value) -> RedisResult<()> {
1247    let Value::ServerError(err) = value else {
1248        return Ok(());
1249    };
1250
1251    match err.details() {
1252        Some(err_msg) => Err((
1253            ErrorKind::ResponseError,
1254            "Redis server refused to switch database",
1255            err_msg.to_string(),
1256        )
1257            .into()),
1258        None => Err((
1259            ErrorKind::ResponseError,
1260            "Redis server refused to switch database",
1261        )
1262            .into()),
1263    }
1264}
1265
/// Validates the reply to `CLIENT TRACKING ON`. Only `Okay` is accepted.
#[cfg(feature = "cache-aio")]
fn check_caching(result: &Value) -> RedisResult<()> {
    if let Value::Okay = result {
        return Ok(());
    }
    Err((
        ErrorKind::ResponseError,
        "Client-side caching returned unknown response",
    )
        .into())
}
1277
1278pub(crate) fn check_connection_setup(
1279    results: Vec<Value>,
1280    ConnectionSetupComponents {
1281        resp3_auth_cmd_idx,
1282        resp2_auth_cmd_idx,
1283        select_cmd_idx,
1284        #[cfg(feature = "cache-aio")]
1285        cache_cmd_idx,
1286    }: ConnectionSetupComponents,
1287) -> RedisResult<AuthResult> {
1288    // can't have both values set
1289    assert!(!(resp2_auth_cmd_idx.is_some() && resp3_auth_cmd_idx.is_some()));
1290
1291    if let Some(index) = resp3_auth_cmd_idx {
1292        let Some(value) = results.get(index) else {
1293            return Err((ErrorKind::ClientError, "Missing RESP3 auth response").into());
1294        };
1295        check_resp3_auth(value)?;
1296    } else if let Some(index) = resp2_auth_cmd_idx {
1297        let Some(value) = results.get(index) else {
1298            return Err((ErrorKind::ClientError, "Missing RESP2 auth response").into());
1299        };
1300        if check_resp2_auth(value)? == AuthResult::ShouldRetryWithoutUsername {
1301            return Ok(AuthResult::ShouldRetryWithoutUsername);
1302        }
1303    }
1304
1305    if let Some(index) = select_cmd_idx {
1306        let Some(value) = results.get(index) else {
1307            return Err((ErrorKind::ClientError, "Missing SELECT DB response").into());
1308        };
1309        check_db_select(value)?;
1310    }
1311
1312    #[cfg(feature = "cache-aio")]
1313    if let Some(index) = cache_cmd_idx {
1314        let Some(value) = results.get(index) else {
1315            return Err((ErrorKind::ClientError, "Missing Caching response").into());
1316        };
1317        check_caching(value)?;
1318    }
1319
1320    Ok(AuthResult::Succeeded)
1321}
1322
1323fn execute_connection_pipeline(
1324    rv: &mut Connection,
1325    (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
1326) -> RedisResult<AuthResult> {
1327    if pipeline.is_empty() {
1328        return Ok(AuthResult::Succeeded);
1329    }
1330    let results = rv.req_packed_commands(&pipeline.get_packed_pipeline(), 0, pipeline.len())?;
1331
1332    check_connection_setup(results, instructions)
1333}
1334
1335fn setup_connection(
1336    con: ActualConnection,
1337    connection_info: &RedisConnectionInfo,
1338    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1339) -> RedisResult<Connection> {
1340    let mut rv = Connection {
1341        con,
1342        parser: Parser::new(),
1343        db: connection_info.db,
1344        pubsub: false,
1345        protocol: connection_info.protocol,
1346        push_sender: None,
1347        messages_to_skip: 0,
1348    };
1349
1350    if execute_connection_pipeline(
1351        &mut rv,
1352        connection_setup_pipeline(
1353            connection_info,
1354            true,
1355            #[cfg(feature = "cache-aio")]
1356            cache_config,
1357        ),
1358    )? == AuthResult::ShouldRetryWithoutUsername
1359    {
1360        execute_connection_pipeline(
1361            &mut rv,
1362            connection_setup_pipeline(
1363                connection_info,
1364                false,
1365                #[cfg(feature = "cache-aio")]
1366                cache_config,
1367            ),
1368        )?;
1369    }
1370
1371    Ok(rv)
1372}
1373
1374/// Implements the "stateless" part of the connection interface that is used by the
1375/// different objects in redis-rs.
1376///
1377/// Primarily it obviously applies to `Connection` object but also some other objects
1378///  implement the interface (for instance whole clients or certain redis results).
1379///
1380/// Generally clients and connections (as well as redis results of those) implement
1381/// this trait.  Actual connections provide more functionality which can be used
1382/// to implement things like `PubSub` but they also can modify the intrinsic
1383/// state of the TCP connection.  This is not possible with `ConnectionLike`
1384/// implementors because that functionality is not exposed.
1385pub trait ConnectionLike {
1386    /// Sends an already encoded (packed) command into the TCP socket and
1387    /// reads the single response from it.
1388    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value>;
1389
1390    /// Sends multiple already encoded (packed) command into the TCP socket
1391    /// and reads `count` responses from it.  This is used to implement
1392    /// pipelining.
1393    /// Important - this function is meant for internal usage, since it's
1394    /// easy to pass incorrect `offset` & `count` parameters, which might
1395    /// cause the connection to enter an erroneous state. Users shouldn't
1396    /// call it, instead using the Pipeline::query function.
1397    #[doc(hidden)]
1398    fn req_packed_commands(
1399        &mut self,
1400        cmd: &[u8],
1401        offset: usize,
1402        count: usize,
1403    ) -> RedisResult<Vec<Value>>;
1404
1405    /// Sends a [Cmd] into the TCP socket and reads a single response from it.
1406    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1407        let pcmd = cmd.get_packed_command();
1408        self.req_packed_command(&pcmd)
1409    }
1410
1411    /// Returns the database this connection is bound to.  Note that this
1412    /// information might be unreliable because it's initially cached and
1413    /// also might be incorrect if the connection like object is not
1414    /// actually connected.
1415    fn get_db(&self) -> i64;
1416
1417    /// Does this connection support pipelining?
1418    #[doc(hidden)]
1419    fn supports_pipelining(&self) -> bool {
1420        true
1421    }
1422
1423    /// Check that all connections it has are available (`PING` internally).
1424    fn check_connection(&mut self) -> bool;
1425
1426    /// Returns the connection status.
1427    ///
1428    /// The connection is open until any `read` call received an
1429    /// invalid response from the server (most likely a closed or dropped
1430    /// connection, otherwise a Redis protocol error). When using unix
1431    /// sockets the connection is open until writing a command failed with a
1432    /// `BrokenPipe` error.
1433    fn is_open(&self) -> bool;
1434}
1435
1436/// A connection is an object that represents a single redis connection.  It
1437/// provides basic support for sending encoded commands into a redis connection
1438/// and to read a response from it.  It's bound to a single database and can
1439/// only be created from the client.
1440///
/// You generally do not do much with this object other than passing it to
1442/// `Cmd` objects.
1443impl Connection {
1444    /// Sends an already encoded (packed) command into the TCP socket and
1445    /// does not read a response.  This is useful for commands like
1446    /// `MONITOR` which yield multiple items.  This needs to be used with
1447    /// care because it changes the state of the connection.
1448    pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> {
1449        self.send_bytes(cmd)?;
1450        Ok(())
1451    }
1452
1453    /// Fetches a single response from the connection.  This is useful
1454    /// if used in combination with `send_packed_command`.
1455    pub fn recv_response(&mut self) -> RedisResult<Value> {
1456        self.read(true)
1457    }
1458
1459    /// Sets the write timeout for the connection.
1460    ///
1461    /// If the provided value is `None`, then `send_packed_command` call will
1462    /// block indefinitely. It is an error to pass the zero `Duration` to this
1463    /// method.
1464    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1465        self.con.set_write_timeout(dur)
1466    }
1467
1468    /// Sets the read timeout for the connection.
1469    ///
1470    /// If the provided value is `None`, then `recv_response` call will
1471    /// block indefinitely. It is an error to pass the zero `Duration` to this
1472    /// method.
1473    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1474        self.con.set_read_timeout(dur)
1475    }
1476
1477    /// Creates a [`PubSub`] instance for this connection.
1478    pub fn as_pubsub(&mut self) -> PubSub<'_> {
1479        // NOTE: The pubsub flag is intentionally not raised at this time since
1480        // running commands within the pubsub state should not try and exit from
1481        // the pubsub state.
1482        PubSub::new(self)
1483    }
1484
1485    fn exit_pubsub(&mut self) -> RedisResult<()> {
1486        let res = self.clear_active_subscriptions();
1487        if res.is_ok() {
1488            self.pubsub = false;
1489        } else {
1490            // Raise the pubsub flag to indicate the connection is "stuck" in that state.
1491            self.pubsub = true;
1492        }
1493
1494        res
1495    }
1496
    /// Unsubscribes from every channel and pattern, then drains the resulting replies.
    ///
    /// Sends `UNSUBSCRIBE` and `PUNSUBSCRIBE` (without arguments). It then reads
    /// responses until one of the `*_is_pub_sub_state_cleared` helpers reports that
    /// the pub/sub state is cleared. It handles RESP3 push replies, RESP2 array
    /// replies, and the `NoSub` server error introduced in Valkey 8.
    ///
    /// # Errors
    ///
    /// Returns send/receive errors, any server error other than `NoSub`, a conversion
    /// error if a RESP2 array does not decode as `(bytes, _, integer)`, and a
    /// `ClientError` for any other reply type. This method does not close the
    /// connection itself; `exit_pubsub` leaves the `pubsub` flag raised on error.
    fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
        // Responses to unsubscribe commands return in a 3-tuple with values
        // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs).
        // The "count of remaining subs" includes both pattern subscriptions and non pattern
        // subscriptions. Thus, to accurately drain all unsubscribe messages received from the
        // server, both commands need to be executed at once.
        {
            // Prepare both unsubscribe commands
            let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
            let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();

            // Execute commands
            self.send_bytes(&unsubscribe)?;
            self.send_bytes(&punsubscribe)?;
        }

        // Receive responses
        //
        // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
        // commands. There may be more responses if there are active subscriptions. In this case,
        // messages are received until the _subscription count_ in the responses reach zero.
        let mut received_unsub = false;
        let mut received_punsub = false;

        loop {
            let resp = self.recv_response()?;

            match resp {
                // RESP3: pushes with fewer than two elements, or a non-integer second
                // element, are skipped without affecting the drain state.
                Value::Push { kind, data } => {
                    if data.len() >= 2 {
                        if let Value::Int(num) = data[1] {
                            if resp3_is_pub_sub_state_cleared(
                                &mut received_unsub,
                                &mut received_punsub,
                                &kind,
                                num as isize,
                            ) {
                                break;
                            }
                        }
                    }
                }
                Value::ServerError(err) => {
                    // a new error behavior, introduced in valkey 8.
                    // https://github.com/valkey-io/valkey/pull/759
                    if err.kind() == Some(ServerErrorKind::NoSub) {
                        if no_sub_err_is_pub_sub_state_cleared(
                            &mut received_unsub,
                            &mut received_punsub,
                            &err,
                        ) {
                            break;
                        } else {
                            continue;
                        }
                    }

                    return Err(err.into());
                }
                Value::Array(vec) => {
                    // NOTE(review): any RESP2 array that does not decode as
                    // `(bytes, _, integer)` makes the `?` below return an error. That
                    // would include, e.g., a pending `message` with a non-numeric
                    // payload or a 4-element `pmessage` arriving before the
                    // confirmations. Confirm whether such messages can reach this point.
                    let res: (Vec<u8>, (), isize) = from_owned_redis_value(Value::Array(vec))?;
                    if resp2_is_pub_sub_state_cleared(
                        &mut received_unsub,
                        &mut received_punsub,
                        &res.0,
                        res.2,
                    ) {
                        break;
                    }
                }
                _ => {
                    return Err((
                        ErrorKind::ClientError,
                        "Unexpected unsubscribe response",
                        format!("{resp:?}"),
                    )
                        .into())
                }
            }
        }

        // Finally, the connection is back in its normal state since all subscriptions were
        // cancelled *and* all unsubscribe messages were received.
        Ok(())
    }
1586
1587    fn send_push(&self, push: PushInfo) {
1588        if let Some(sender) = &self.push_sender {
1589            let _ = sender.send(push);
1590        }
1591    }
1592
1593    fn try_send(&self, value: &RedisResult<Value>) {
1594        if let Ok(Value::Push { kind, data }) = value {
1595            self.send_push(PushInfo {
1596                kind: kind.clone(),
1597                data: data.clone(),
1598            });
1599        }
1600    }
1601
1602    fn send_disconnect(&self) {
1603        self.send_push(PushInfo::disconnect())
1604    }
1605
1606    fn close_connection(&mut self) {
1607        // Notify the PushManager that the connection was lost
1608        self.send_disconnect();
1609        match self.con {
1610            ActualConnection::Tcp(ref mut connection) => {
1611                let _ = connection.reader.shutdown(net::Shutdown::Both);
1612                connection.open = false;
1613            }
1614            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1615            ActualConnection::TcpNativeTls(ref mut connection) => {
1616                let _ = connection.reader.shutdown();
1617                connection.open = false;
1618            }
1619            #[cfg(feature = "tls-rustls")]
1620            ActualConnection::TcpRustls(ref mut connection) => {
1621                let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both);
1622                connection.open = false;
1623            }
1624            #[cfg(unix)]
1625            ActualConnection::Unix(ref mut connection) => {
1626                let _ = connection.sock.shutdown(net::Shutdown::Both);
1627                connection.open = false;
1628            }
1629        }
1630    }
1631
1632    /// Fetches a single message from the connection. If the message is a response,
1633    /// increment `messages_to_skip` if it wasn't received before a timeout.
1634    fn read(&mut self, is_response: bool) -> RedisResult<Value> {
1635        loop {
1636            let result = match self.con {
1637                ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => {
1638                    self.parser.parse_value(reader)
1639                }
1640                #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1641                ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => {
1642                    let reader = &mut boxed_tls_connection.reader;
1643                    self.parser.parse_value(reader)
1644                }
1645                #[cfg(feature = "tls-rustls")]
1646                ActualConnection::TcpRustls(ref mut boxed_tls_connection) => {
1647                    let reader = &mut boxed_tls_connection.reader;
1648                    self.parser.parse_value(reader)
1649                }
1650                #[cfg(unix)]
1651                ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => {
1652                    self.parser.parse_value(sock)
1653                }
1654            };
1655            self.try_send(&result);
1656
1657            let Err(err) = &result else {
1658                if self.messages_to_skip > 0 {
1659                    self.messages_to_skip -= 1;
1660                    continue;
1661                }
1662                return result;
1663            };
1664            let Some(io_error) = err.as_io_error() else {
1665                if self.messages_to_skip > 0 {
1666                    self.messages_to_skip -= 1;
1667                    continue;
1668                }
1669                return result;
1670            };
1671            // shutdown connection on protocol error
1672            if io_error.kind() == io::ErrorKind::UnexpectedEof {
1673                self.close_connection();
1674            } else if is_response {
1675                self.messages_to_skip += 1;
1676            }
1677
1678            return result;
1679        }
1680    }
1681
1682    /// Sets sender channel for push values.
1683    pub fn set_push_sender(&mut self, sender: SyncPushSender) {
1684        self.push_sender = Some(sender);
1685    }
1686
1687    fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
1688        let result = self.con.send_bytes(bytes);
1689        if self.protocol != ProtocolVersion::RESP2 {
1690            if let Err(e) = &result {
1691                if e.is_connection_dropped() {
1692                    self.send_disconnect();
1693                }
1694            }
1695        }
1696        result
1697    }
1698}
1699
1700impl ConnectionLike for Connection {
1701    /// Sends a [Cmd] into the TCP socket and reads a single response from it.
1702    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1703        let pcmd = cmd.get_packed_command();
1704        if self.pubsub {
1705            self.exit_pubsub()?;
1706        }
1707
1708        self.send_bytes(&pcmd)?;
1709        if cmd.is_no_response() {
1710            return Ok(Value::Nil);
1711        }
1712        loop {
1713            match self.read(true)? {
1714                Value::Push {
1715                    kind: _kind,
1716                    data: _data,
1717                } => continue,
1718                val => return Ok(val),
1719            }
1720        }
1721    }
1722    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1723        if self.pubsub {
1724            self.exit_pubsub()?;
1725        }
1726
1727        self.send_bytes(cmd)?;
1728        loop {
1729            match self.read(true)? {
1730                Value::Push {
1731                    kind: _kind,
1732                    data: _data,
1733                } => continue,
1734                val => return Ok(val),
1735            }
1736        }
1737    }
1738
1739    fn req_packed_commands(
1740        &mut self,
1741        cmd: &[u8],
1742        offset: usize,
1743        count: usize,
1744    ) -> RedisResult<Vec<Value>> {
1745        if self.pubsub {
1746            self.exit_pubsub()?;
1747        }
1748        self.send_bytes(cmd)?;
1749        let mut rv = vec![];
1750        let mut first_err = None;
1751        let mut count = count;
1752        let mut idx = 0;
1753        while idx < (offset + count) {
1754            // When processing a transaction, some responses may be errors.
1755            // We need to keep processing the rest of the responses in that case,
1756            // so bailing early with `?` would not be correct.
1757            // See: https://github.com/redis-rs/redis-rs/issues/436
1758            let response = self.read(true);
1759            match response {
1760                Ok(Value::ServerError(err)) => {
1761                    if idx < offset {
1762                        if first_err.is_none() {
1763                            first_err = Some(err.into());
1764                        }
1765                    } else {
1766                        rv.push(Value::ServerError(err));
1767                    }
1768                }
1769                Ok(item) => {
1770                    // RESP3 can insert push data between command replies
1771                    if let Value::Push {
1772                        kind: _kind,
1773                        data: _data,
1774                    } = item
1775                    {
1776                        // if that is the case we have to extend the loop and handle push data
1777                        count += 1;
1778                    } else if idx >= offset {
1779                        rv.push(item);
1780                    }
1781                }
1782                Err(err) => {
1783                    if first_err.is_none() {
1784                        first_err = Some(err);
1785                    }
1786                }
1787            }
1788            idx += 1;
1789        }
1790
1791        first_err.map_or(Ok(rv), Err)
1792    }
1793
1794    fn get_db(&self) -> i64 {
1795        self.db
1796    }
1797
1798    fn check_connection(&mut self) -> bool {
1799        cmd("PING").query::<String>(self).is_ok()
1800    }
1801
1802    fn is_open(&self) -> bool {
1803        self.con.is_open()
1804    }
1805}
1806
1807impl<C, T> ConnectionLike for T
1808where
1809    C: ConnectionLike,
1810    T: DerefMut<Target = C>,
1811{
1812    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1813        self.deref_mut().req_packed_command(cmd)
1814    }
1815
1816    fn req_packed_commands(
1817        &mut self,
1818        cmd: &[u8],
1819        offset: usize,
1820        count: usize,
1821    ) -> RedisResult<Vec<Value>> {
1822        self.deref_mut().req_packed_commands(cmd, offset, count)
1823    }
1824
1825    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1826        self.deref_mut().req_command(cmd)
1827    }
1828
1829    fn get_db(&self) -> i64 {
1830        self.deref().get_db()
1831    }
1832
1833    fn supports_pipelining(&self) -> bool {
1834        self.deref().supports_pipelining()
1835    }
1836
1837    fn check_connection(&mut self) -> bool {
1838        self.deref_mut().check_connection()
1839    }
1840
1841    fn is_open(&self) -> bool {
1842        self.deref().is_open()
1843    }
1844}
1845
1846/// The pubsub object provides convenient access to the redis pubsub
1847/// system.  Once created you can subscribe and unsubscribe from channels
1848/// and listen in on messages.
1849///
1850/// Example:
1851///
1852/// ```rust,no_run
1853/// # fn do_something() -> redis::RedisResult<()> {
1854/// let client = redis::Client::open("redis://127.0.0.1/")?;
1855/// let mut con = client.get_connection()?;
1856/// let mut pubsub = con.as_pubsub();
1857/// pubsub.subscribe("channel_1")?;
1858/// pubsub.subscribe("channel_2")?;
1859///
1860/// loop {
1861///     let msg = pubsub.get_message()?;
1862///     let payload : String = msg.get_payload()?;
1863///     println!("channel '{}': {}", msg.get_channel_name(), payload);
1864/// }
1865/// # }
1866/// ```
1867impl<'a> PubSub<'a> {
1868    fn new(con: &'a mut Connection) -> Self {
1869        Self {
1870            con,
1871            waiting_messages: VecDeque::new(),
1872        }
1873    }
1874
1875    fn cache_messages_until_received_response(
1876        &mut self,
1877        cmd: &mut Cmd,
1878        is_sub_unsub: bool,
1879    ) -> RedisResult<Value> {
1880        let ignore_response = self.con.protocol != ProtocolVersion::RESP2 && is_sub_unsub;
1881        cmd.set_no_response(ignore_response);
1882
1883        self.con.send_packed_command(&cmd.get_packed_command())?;
1884
1885        loop {
1886            let response = self.con.recv_response()?;
1887            if let Some(msg) = Msg::from_value(&response) {
1888                self.waiting_messages.push_back(msg);
1889            } else {
1890                return Ok(response);
1891            }
1892        }
1893    }
1894
1895    /// Subscribes to a new channel(s).    
1896    pub fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1897        self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel), true)?;
1898        Ok(())
1899    }
1900
1901    /// Subscribes to new channel(s) with pattern(s).
1902    pub fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1903        self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel), true)?;
1904        Ok(())
1905    }
1906
1907    /// Unsubscribes from a channel(s).
1908    pub fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1909        self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel), true)?;
1910        Ok(())
1911    }
1912
1913    /// Unsubscribes from channel pattern(s).
1914    pub fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1915        self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel), true)?;
1916        Ok(())
1917    }
1918
1919    /// Sends a ping with a message to the server
1920    pub fn ping_message<T: FromRedisValue>(&mut self, message: impl ToRedisArgs) -> RedisResult<T> {
1921        from_owned_redis_value(
1922            self.cache_messages_until_received_response(cmd("PING").arg(message), false)?,
1923        )
1924    }
1925    /// Sends a ping to the server
1926    pub fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
1927        from_owned_redis_value(
1928            self.cache_messages_until_received_response(&mut cmd("PING"), false)?,
1929        )
1930    }
1931
1932    /// Fetches the next message from the pubsub connection.  Blocks until
1933    /// a message becomes available.  This currently does not provide a
1934    /// wait not to block :(
1935    ///
1936    /// The message itself is still generic and can be converted into an
1937    /// appropriate type through the helper methods on it.
1938    pub fn get_message(&mut self) -> RedisResult<Msg> {
1939        if let Some(msg) = self.waiting_messages.pop_front() {
1940            return Ok(msg);
1941        }
1942        loop {
1943            if let Some(msg) = Msg::from_owned_value(self.con.read(false)?) {
1944                return Ok(msg);
1945            } else {
1946                continue;
1947            }
1948        }
1949    }
1950
1951    /// Sets the read timeout for the connection.
1952    ///
1953    /// If the provided value is `None`, then `get_message` call will
1954    /// block indefinitely. It is an error to pass the zero `Duration` to this
1955    /// method.
1956    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1957        self.con.set_read_timeout(dur)
1958    }
1959}
1960
1961impl Drop for PubSub<'_> {
1962    fn drop(&mut self) {
1963        let _ = self.con.exit_pubsub();
1964    }
1965}
1966
1967/// This holds the data that comes from listening to a pubsub
1968/// connection.  It only contains actual message data.
1969impl Msg {
1970    /// Tries to convert provided [`Value`] into [`Msg`].
1971    pub fn from_value(value: &Value) -> Option<Self> {
1972        Self::from_owned_value(value.clone())
1973    }
1974
1975    /// Tries to convert provided [`Value`] into [`Msg`].
1976    pub fn from_owned_value(value: Value) -> Option<Self> {
1977        let mut pattern = None;
1978        let payload;
1979        let channel;
1980
1981        if let Value::Push { kind, data } = value {
1982            return Self::from_push_info(PushInfo { kind, data });
1983        } else {
1984            let raw_msg: Vec<Value> = from_owned_redis_value(value).ok()?;
1985            let mut iter = raw_msg.into_iter();
1986            let msg_type: String = from_owned_redis_value(iter.next()?).ok()?;
1987            if msg_type == "message" {
1988                channel = iter.next()?;
1989                payload = iter.next()?;
1990            } else if msg_type == "pmessage" {
1991                pattern = Some(iter.next()?);
1992                channel = iter.next()?;
1993                payload = iter.next()?;
1994            } else {
1995                return None;
1996            }
1997        };
1998        Some(Msg {
1999            payload,
2000            channel,
2001            pattern,
2002        })
2003    }
2004
2005    /// Tries to convert provided [`PushInfo`] into [`Msg`].
2006    pub fn from_push_info(push_info: PushInfo) -> Option<Self> {
2007        let mut pattern = None;
2008        let payload;
2009        let channel;
2010
2011        let mut iter = push_info.data.into_iter();
2012        if push_info.kind == PushKind::Message || push_info.kind == PushKind::SMessage {
2013            channel = iter.next()?;
2014            payload = iter.next()?;
2015        } else if push_info.kind == PushKind::PMessage {
2016            pattern = Some(iter.next()?);
2017            channel = iter.next()?;
2018            payload = iter.next()?;
2019        } else {
2020            return None;
2021        }
2022
2023        Some(Msg {
2024            payload,
2025            channel,
2026            pattern,
2027        })
2028    }
2029
2030    /// Returns the channel this message came on.
2031    pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
2032        from_redis_value(&self.channel)
2033    }
2034
2035    /// Convenience method to get a string version of the channel.  Unless
2036    /// your channel contains non utf-8 bytes you can always use this
2037    /// method.  If the channel is not a valid string (which really should
2038    /// not happen) then the return value is `"?"`.
2039    pub fn get_channel_name(&self) -> &str {
2040        match self.channel {
2041            Value::BulkString(ref bytes) => from_utf8(bytes).unwrap_or("?"),
2042            _ => "?",
2043        }
2044    }
2045
2046    /// Returns the message's payload in a specific format.
2047    pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
2048        from_redis_value(&self.payload)
2049    }
2050
2051    /// Returns the bytes that are the message's payload.  This can be used
2052    /// as an alternative to the `get_payload` function if you are interested
2053    /// in the raw bytes in it.
2054    pub fn get_payload_bytes(&self) -> &[u8] {
2055        match self.payload {
2056            Value::BulkString(ref bytes) => bytes,
2057            _ => b"",
2058        }
2059    }
2060
2061    /// Returns true if the message was constructed from a pattern
2062    /// subscription.
2063    #[allow(clippy::wrong_self_convention)]
2064    pub fn from_pattern(&self) -> bool {
2065        self.pattern.is_some()
2066    }
2067
2068    /// If the message was constructed from a message pattern this can be
2069    /// used to find out which one.  It's recommended to match against
2070    /// an `Option<String>` so that you do not need to use `from_pattern`
2071    /// to figure out if a pattern was set.
2072    pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
2073        match self.pattern {
2074            None => from_redis_value(&Value::Nil),
2075            Some(ref x) => from_redis_value(x),
2076        }
2077    }
2078}
2079
2080/// This function simplifies transaction management slightly.  What it
2081/// does is automatically watching keys and then going into a transaction
2082/// loop util it succeeds.  Once it goes through the results are
2083/// returned.
2084///
2085/// To use the transaction two pieces of information are needed: a list
2086/// of all the keys that need to be watched for modifications and a
2087/// closure with the code that should be execute in the context of the
2088/// transaction.  The closure is invoked with a fresh pipeline in atomic
2089/// mode.  To use the transaction the function needs to return the result
2090/// from querying the pipeline with the connection.
2091///
2092/// The end result of the transaction is then available as the return
2093/// value from the function call.
2094///
2095/// Example:
2096///
2097/// ```rust,no_run
2098/// use redis::Commands;
2099/// # fn do_something() -> redis::RedisResult<()> {
2100/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
2101/// # let mut con = client.get_connection().unwrap();
2102/// let key = "the_key";
2103/// let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| {
2104///     let old_val : isize = con.get(key)?;
2105///     pipe
2106///         .set(key, old_val + 1).ignore()
2107///         .get(key).query(con)
2108/// })?;
2109/// println!("The incremented number is: {}", new_val);
2110/// # Ok(()) }
2111/// ```
2112pub fn transaction<
2113    C: ConnectionLike,
2114    K: ToRedisArgs,
2115    T,
2116    F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
2117>(
2118    con: &mut C,
2119    keys: &[K],
2120    func: F,
2121) -> RedisResult<T> {
2122    let mut func = func;
2123    loop {
2124        cmd("WATCH").arg(keys).exec(con)?;
2125        let mut p = pipe();
2126        let response: Option<T> = func(con, p.atomic())?;
2127        match response {
2128            None => {
2129                continue;
2130            }
2131            Some(response) => {
2132                // make sure no watch is left in the connection, even if
2133                // someone forgot to use the pipeline.
2134                cmd("UNWATCH").exec(con)?;
2135                return Ok(response);
2136            }
2137        }
2138    }
2139}
// TODO: make both subscription-clearing helpers below support sharded channels.
2141
/// Common logic for clearing subscriptions in RESP2 async/sync.
///
/// `kind` is the reply kind as raw bytes: a leading `u` marks an unsubscribe reply and a
/// leading `p` a pattern-unsubscribe reply; the matching flag is set (and stays set).
/// Returns `true` once both flags are set and `num` (the remaining count reported with
/// the reply) is zero.
pub fn resp2_is_pub_sub_state_cleared(
    received_unsub: &mut bool,
    received_punsub: &mut bool,
    kind: &[u8],
    num: isize,
) -> bool {
    match kind {
        [b'u', ..] => *received_unsub = true,
        [b'p', ..] => *received_punsub = true,
        _ => {}
    }
    num == 0 && *received_unsub && *received_punsub
}
2156
2157/// Common logic for clearing subscriptions in RESP3 async/sync
2158pub fn resp3_is_pub_sub_state_cleared(
2159    received_unsub: &mut bool,
2160    received_punsub: &mut bool,
2161    kind: &PushKind,
2162    num: isize,
2163) -> bool {
2164    match kind {
2165        PushKind::Unsubscribe => *received_unsub = true,
2166        PushKind::PUnsubscribe => *received_punsub = true,
2167        _ => (),
2168    };
2169    *received_unsub && *received_punsub && num == 0
2170}
2171
2172pub fn no_sub_err_is_pub_sub_state_cleared(
2173    received_unsub: &mut bool,
2174    received_punsub: &mut bool,
2175    err: &ServerError,
2176) -> bool {
2177    let details = err.details();
2178    *received_unsub = *received_unsub
2179        || details
2180            .map(|details| details.starts_with("'unsub"))
2181            .unwrap_or_default();
2182    *received_punsub = *received_punsub
2183        || details
2184            .map(|details| details.starts_with("'punsub"))
2185            .unwrap_or_default();
2186    *received_unsub && *received_punsub
2187}
2188
2189/// Common logic for checking real cause of hello3 command error
2190pub fn get_resp3_hello_command_error(err: RedisError) -> RedisError {
2191    if let Some(detail) = err.detail() {
2192        if detail.starts_with("unknown command `HELLO`") {
2193            return (
2194                ErrorKind::RESP3NotSupported,
2195                "Redis Server doesn't support HELLO command therefore resp3 cannot be used",
2196            )
2197                .into();
2198        }
2199    }
2200    err
2201}
2202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_redis_url() {
        let cases = vec![
            ("redis://127.0.0.1", true),
            ("redis://[::1]", true),
            ("rediss://127.0.0.1", true),
            ("rediss://[::1]", true),
            ("valkey://127.0.0.1", true),
            ("valkey://[::1]", true),
            ("valkeys://127.0.0.1", true),
            ("valkeys://[::1]", true),
            ("redis+unix:///run/redis.sock", true),
            ("valkey+unix:///run/valkey.sock", true),
            ("unix:///run/redis.sock", true),
            ("http://127.0.0.1", false),
            ("tcp://127.0.0.1", false),
        ];
        for (url, expected) in cases.into_iter() {
            let res = parse_redis_url(url);
            assert_eq!(
                res.is_some(),
                expected,
                "Parsed result of `{url}` is not expected",
            );
        }
    }

    #[test]
    fn test_url_to_tcp_connection_info() {
        let cases = vec![
            (
                url::Url::parse("redis://127.0.0.1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://[::1]").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("::1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("example.com".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=resp3").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        protocol: ProtocolVersion::RESP3,
                        ..Default::default()
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url.clone()).unwrap();
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
            // The `protocol=...` cases above are only meaningful if this is checked.
            assert_eq!(
                res.redis.protocol, expected.redis.protocol,
                "protocol of {url} is not expected",
            );
        }
    }

    #[test]
    fn test_url_to_tcp_connection_info_failed() {
        let cases = vec![
            (
                url::Url::parse("redis://").unwrap(),
                "Missing hostname",
                None,
            ),
            (
                url::Url::parse("redis://127.0.0.1/db").unwrap(),
                "Invalid database number",
                None,
            ),
            (
                url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(),
                "Username is not valid UTF-8 string",
                None,
            ),
            (
                url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(),
                "Password is not valid UTF-8 string",
                None,
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=4").unwrap(),
                "Invalid protocol version",
                Some("4"),
            ),
        ];
        for (url, expected, detail) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url).unwrap_err();
            assert_eq!(
                res.kind(),
                crate::ErrorKind::InvalidClientConfig,
                "{}",
                &res,
            );
            #[allow(deprecated)]
            let desc = std::error::Error::description(&res);
            assert_eq!(desc, expected, "{}", &res);
            assert_eq!(res.detail(), detail, "{}", &res);
        }
    }

    #[test]
    #[cfg(unix)]
    fn test_url_to_unix_connection_info() {
        let cases = vec![
            (
                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 0,
                        username: None,
                        password: None,
                        protocol: ProtocolVersion::RESP2,
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 1,
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse(
                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse(
                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("&?= *+".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?protocol=3").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        protocol: ProtocolVersion::RESP3,
                        ..Default::default()
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            assert_eq!(
                ConnectionAddr::Unix(url.to_file_path().unwrap()),
                expected.addr,
                "addr of {url} is not expected",
            );
            let res = url_to_unix_connection_info(url.clone()).unwrap();
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
            // The `protocol=3` case above is only meaningful if this is checked.
            assert_eq!(
                res.redis.protocol, expected.redis.protocol,
                "protocol of {url} is not expected",
            );
        }
    }

    #[test]
    fn test_resp2_is_pub_sub_state_cleared() {
        let mut unsub = false;
        let mut punsub = false;
        assert!(!resp2_is_pub_sub_state_cleared(
            &mut unsub,
            &mut punsub,
            b"unsubscribe",
            0
        ));
        assert!(unsub);
        assert!(!punsub);
        // Both kinds seen, but subscriptions are still reported as active.
        assert!(!resp2_is_pub_sub_state_cleared(
            &mut unsub,
            &mut punsub,
            b"punsubscribe",
            1
        ));
        assert!(punsub);
        assert!(resp2_is_pub_sub_state_cleared(
            &mut unsub,
            &mut punsub,
            b"punsubscribe",
            0
        ));
    }

    #[test]
    fn test_resp3_is_pub_sub_state_cleared() {
        let mut unsub = false;
        let mut punsub = false;
        assert!(!resp3_is_pub_sub_state_cleared(
            &mut unsub,
            &mut punsub,
            &PushKind::Unsubscribe,
            0
        ));
        assert!(unsub);
        assert!(!punsub);
        assert!(!resp3_is_pub_sub_state_cleared(
            &mut unsub,
            &mut punsub,
            &PushKind::Message,
            0
        ));
        assert!(!punsub);
        assert!(resp3_is_pub_sub_state_cleared(
            &mut unsub,
            &mut punsub,
            &PushKind::PUnsubscribe,
            0
        ));
    }
}