// redis/connection.rs

1use std::borrow::Cow;
2use std::collections::VecDeque;
3use std::fmt;
4use std::io::{self, Write};
5use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
6use std::ops::DerefMut;
7use std::path::PathBuf;
8use std::str::{from_utf8, FromStr};
9use std::time::{Duration, Instant};
10
11use crate::cmd::{cmd, pipe, Cmd};
12use crate::io::tcp::{stream_with_settings, TcpSettings};
13use crate::parser::Parser;
14use crate::pipeline::Pipeline;
15use crate::types::{
16    from_redis_value, ErrorKind, FromRedisValue, HashMap, PushKind, RedisError, RedisResult,
17    ServerError, ServerErrorKind, SyncPushSender, ToRedisArgs, Value,
18};
19use crate::{check_resp3, from_owned_redis_value, ProtocolVersion};
20
21#[cfg(unix)]
22use std::os::unix::net::UnixStream;
23
24use crate::commands::resp3_hello;
25#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
26use native_tls::{TlsConnector, TlsStream};
27
28#[cfg(feature = "tls-rustls")]
29use rustls::{RootCertStore, StreamOwned};
30#[cfg(feature = "tls-rustls")]
31use std::sync::Arc;
32
33use crate::PushInfo;
34
35#[cfg(all(
36    feature = "tls-rustls",
37    not(feature = "tls-native-tls"),
38    not(feature = "tls-rustls-webpki-roots")
39))]
40use rustls_native_certs::load_native_certs;
41
42#[cfg(feature = "tls-rustls")]
43use crate::tls::ClientTlsParams;
44
/// Extra TLS settings that can be attached to a `ConnectionAddr::TcpTls` address.
///
/// The struct is `#[non_exhaustive]` so that it cannot be constructed outside this crate.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TlsConnParams {
    /// Client-side TLS parameters (rustls backend only).
    #[cfg(feature = "tls-rustls")]
    pub(crate) client_tls_params: Option<ClientTlsParams>,
    /// Optional root certificate store (rustls backend only).
    #[cfg(feature = "tls-rustls")]
    pub(crate) root_cert_store: Option<RootCertStore>,
    /// Accept server certificates whose hostname does not match the requested host.
    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
    pub(crate) danger_accept_invalid_hostnames: bool,
}
56
/// Port used when a redis URL does not specify one.
const DEFAULT_PORT: u16 = 6379;
58
59#[inline(always)]
60fn connect_tcp(addr: (&str, u16)) -> io::Result<TcpStream> {
61    let socket = TcpStream::connect(addr)?;
62    stream_with_settings(socket, &TcpSettings::default())
63}
64
65#[inline(always)]
66fn connect_tcp_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
67    let socket = TcpStream::connect_timeout(addr, timeout)?;
68    stream_with_settings(socket, &TcpSettings::default())
69}
70
71/// This function takes a redis URL string and parses it into a URL
72/// as used by rust-url.
73///
74/// This is necessary as the default parser does not understand how redis URLs function.
75pub fn parse_redis_url(input: &str) -> Option<url::Url> {
76    match url::Url::parse(input) {
77        Ok(result) => match result.scheme() {
78            "redis" | "rediss" | "valkey" | "valkeys" | "redis+unix" | "valkey+unix" | "unix" => {
79                Some(result)
80            }
81            _ => None,
82        },
83        Err(_) => None,
84    }
85}
86
/// Whether a TLS connection verifies the server's certificate.
///
/// Mirrors the [`insecure`](ConnectionAddr::TcpTls::insecure) flag of
/// `ConnectionAddr::TcpTls`; check it for more details.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TlsMode {
    /// Verify the server certificate.
    Secure,
    /// Do not verify the server certificate.
    Insecure,
}
97
98/// Defines the connection address.
99///
100/// Not all connection addresses are supported on all platforms.  For instance
101/// to connect to a unix socket you need to run this on an operating system
102/// that supports them.
103#[derive(Clone, Debug)]
104pub enum ConnectionAddr {
105    /// Format for this is `(host, port)`.
106    Tcp(String, u16),
107    /// Format for this is `(host, port)`.
108    TcpTls {
109        /// Hostname
110        host: String,
111        /// Port
112        port: u16,
113        /// Disable hostname verification when connecting.
114        ///
115        /// # Warning
116        ///
117        /// You should think very carefully before you use this method. If hostname
118        /// verification is not used, any valid certificate for any site will be
119        /// trusted for use from any other. This introduces a significant
120        /// vulnerability to man-in-the-middle attacks.
121        insecure: bool,
122
123        /// TLS certificates and client key.
124        tls_params: Option<TlsConnParams>,
125    },
126    /// Format for this is the path to the unix socket.
127    Unix(PathBuf),
128}
129
130impl PartialEq for ConnectionAddr {
131    fn eq(&self, other: &Self) -> bool {
132        match (self, other) {
133            (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => {
134                host1 == host2 && port1 == port2
135            }
136            (
137                ConnectionAddr::TcpTls {
138                    host: host1,
139                    port: port1,
140                    insecure: insecure1,
141                    tls_params: _,
142                },
143                ConnectionAddr::TcpTls {
144                    host: host2,
145                    port: port2,
146                    insecure: insecure2,
147                    tls_params: _,
148                },
149            ) => port1 == port2 && host1 == host2 && insecure1 == insecure2,
150            (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2,
151            _ => false,
152        }
153    }
154}
155
156impl Eq for ConnectionAddr {}
157
158impl ConnectionAddr {
159    /// Checks if this address is supported.
160    ///
161    /// Because not all platforms support all connection addresses this is a
162    /// quick way to figure out if a connection method is supported. Currently
163    /// this affects:
164    ///
165    /// - Unix socket addresses, which are supported only on Unix
166    ///
167    /// - TLS addresses, which are supported only if a TLS feature is enabled
168    ///   (either `tls-native-tls` or `tls-rustls`).
169    pub fn is_supported(&self) -> bool {
170        match *self {
171            ConnectionAddr::Tcp(_, _) => true,
172            ConnectionAddr::TcpTls { .. } => {
173                cfg!(any(feature = "tls-native-tls", feature = "tls-rustls"))
174            }
175            ConnectionAddr::Unix(_) => cfg!(unix),
176        }
177    }
178
179    /// Configure this address to connect without checking certificate hostnames.
180    ///
181    /// # Warning
182    ///
183    /// You should think very carefully before you use this method. If hostname
184    /// verification is not used, any valid certificate for any site will be
185    /// trusted for use from any other. This introduces a significant
186    /// vulnerability to man-in-the-middle attacks.
187    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
188    pub fn set_danger_accept_invalid_hostnames(&mut self, insecure: bool) {
189        if let ConnectionAddr::TcpTls { tls_params, .. } = self {
190            if let Some(ref mut params) = tls_params {
191                params.danger_accept_invalid_hostnames = insecure;
192            } else if insecure {
193                *tls_params = Some(TlsConnParams {
194                    #[cfg(feature = "tls-rustls")]
195                    client_tls_params: None,
196                    #[cfg(feature = "tls-rustls")]
197                    root_cert_store: None,
198                    danger_accept_invalid_hostnames: insecure,
199                });
200            }
201        }
202    }
203
204    #[cfg(feature = "cluster")]
205    pub(crate) fn tls_mode(&self) -> Option<TlsMode> {
206        match self {
207            ConnectionAddr::TcpTls { insecure, .. } => {
208                if *insecure {
209                    Some(TlsMode::Insecure)
210                } else {
211                    Some(TlsMode::Secure)
212                }
213            }
214            _ => None,
215        }
216    }
217}
218
219impl fmt::Display for ConnectionAddr {
220    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
221        // Cluster::get_connection_info depends on the return value from this function
222        match *self {
223            ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
224            ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
225            ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()),
226        }
227    }
228}
229
230/// Holds the connection information that redis should use for connecting.
231#[derive(Clone, Debug)]
232pub struct ConnectionInfo {
233    /// A connection address for where to connect to.
234    pub addr: ConnectionAddr,
235
236    /// A redis connection info for how to handshake with redis.
237    pub redis: RedisConnectionInfo,
238}
239
240/// Redis specific/connection independent information used to establish a connection to redis.
241#[derive(Clone, Debug, Default)]
242pub struct RedisConnectionInfo {
243    /// The database number to use.  This is usually `0`.
244    pub db: i64,
245    /// Optionally a username that should be used for connection.
246    pub username: Option<String>,
247    /// Optionally a password that should be used for connection.
248    pub password: Option<String>,
249    /// Version of the protocol to use.
250    pub protocol: ProtocolVersion,
251}
252
253impl FromStr for ConnectionInfo {
254    type Err = RedisError;
255
256    fn from_str(s: &str) -> Result<Self, Self::Err> {
257        s.into_connection_info()
258    }
259}
260
261/// Converts an object into a connection info struct.  This allows the
262/// constructor of the client to accept connection information in a
263/// range of different formats.
264pub trait IntoConnectionInfo {
265    /// Converts the object into a connection info object.
266    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
267}
268
269impl IntoConnectionInfo for ConnectionInfo {
270    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
271        Ok(self)
272    }
273}
274
275/// URL format: `{redis|rediss|valkey|valkeys}://[<username>][:<password>@]<hostname>[:port][/<db>]`
276///
277/// - Basic: `redis://127.0.0.1:6379`
278/// - Username & Password: `redis://user:password@127.0.0.1:6379`
279/// - Password only: `redis://:password@127.0.0.1:6379`
280/// - Specifying DB: `redis://127.0.0.1:6379/0`
281/// - Enabling TLS: `rediss://127.0.0.1:6379`
282/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
283/// - Enabling RESP3: `redis://127.0.0.1:6379/?protocol=resp3`
284impl IntoConnectionInfo for &str {
285    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
286        match parse_redis_url(self) {
287            Some(u) => u.into_connection_info(),
288            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
289        }
290    }
291}
292
293impl<T> IntoConnectionInfo for (T, u16)
294where
295    T: Into<String>,
296{
297    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
298        Ok(ConnectionInfo {
299            addr: ConnectionAddr::Tcp(self.0.into(), self.1),
300            redis: RedisConnectionInfo::default(),
301        })
302    }
303}
304
305/// URL format: `{redis|rediss|valkey|valkeys}://[<username>][:<password>@]<hostname>[:port][/<db>]`
306///
307/// - Basic: `redis://127.0.0.1:6379`
308/// - Username & Password: `redis://user:password@127.0.0.1:6379`
309/// - Password only: `redis://:password@127.0.0.1:6379`
310/// - Specifying DB: `redis://127.0.0.1:6379/0`
311/// - Enabling TLS: `rediss://127.0.0.1:6379`
312/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
313/// - Enabling RESP3: `redis://127.0.0.1:6379/?protocol=resp3`
314impl IntoConnectionInfo for String {
315    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
316        match parse_redis_url(&self) {
317            Some(u) => u.into_connection_info(),
318            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
319        }
320    }
321}
322
323fn parse_protocol(query: &HashMap<Cow<str>, Cow<str>>) -> RedisResult<ProtocolVersion> {
324    Ok(match query.get("protocol") {
325        Some(protocol) => {
326            if protocol == "2" || protocol == "resp2" {
327                ProtocolVersion::RESP2
328            } else if protocol == "3" || protocol == "resp3" {
329                ProtocolVersion::RESP3
330            } else {
331                fail!((
332                    ErrorKind::InvalidClientConfig,
333                    "Invalid protocol version",
334                    protocol.to_string()
335                ))
336            }
337        }
338        None => ProtocolVersion::RESP2,
339    })
340}
341
342fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
343    let host = match url.host() {
344        Some(host) => {
345            // Here we manually match host's enum arms and call their to_string().
346            // Because url.host().to_string() will add `[` and `]` for ipv6:
347            // https://docs.rs/url/latest/src/url/host.rs.html#170
348            // And these brackets will break host.parse::<Ipv6Addr>() when
349            // `client.open()` - `ActualConnection::new()` - `addr.to_socket_addrs()`:
350            // https://doc.rust-lang.org/src/std/net/addr.rs.html#963
351            // https://doc.rust-lang.org/src/std/net/parser.rs.html#158
352            // IpAddr string with brackets can ONLY parse to SocketAddrV6:
353            // https://doc.rust-lang.org/src/std/net/parser.rs.html#255
354            // But if we call Ipv6Addr.to_string directly, it follows rfc5952 without brackets:
355            // https://doc.rust-lang.org/src/std/net/ip.rs.html#1755
356            match host {
357                url::Host::Domain(path) => path.to_string(),
358                url::Host::Ipv4(v4) => v4.to_string(),
359                url::Host::Ipv6(v6) => v6.to_string(),
360            }
361        }
362        None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
363    };
364    let port = url.port().unwrap_or(DEFAULT_PORT);
365    let addr = if url.scheme() == "rediss" || url.scheme() == "valkeys" {
366        #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
367        {
368            match url.fragment() {
369                Some("insecure") => ConnectionAddr::TcpTls {
370                    host,
371                    port,
372                    insecure: true,
373                    tls_params: None,
374                },
375                Some(_) => fail!((
376                    ErrorKind::InvalidClientConfig,
377                    "only #insecure is supported as URL fragment"
378                )),
379                _ => ConnectionAddr::TcpTls {
380                    host,
381                    port,
382                    insecure: false,
383                    tls_params: None,
384                },
385            }
386        }
387
388        #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
389        fail!((
390            ErrorKind::InvalidClientConfig,
391            "can't connect with TLS, the feature is not enabled"
392        ));
393    } else {
394        ConnectionAddr::Tcp(host, port)
395    };
396    let query: HashMap<_, _> = url.query_pairs().collect();
397    Ok(ConnectionInfo {
398        addr,
399        redis: RedisConnectionInfo {
400            db: match url.path().trim_matches('/') {
401                "" => 0,
402                path => path.parse::<i64>().map_err(|_| -> RedisError {
403                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
404                })?,
405            },
406            username: if url.username().is_empty() {
407                None
408            } else {
409                match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
410                    Ok(decoded) => Some(decoded.into_owned()),
411                    Err(_) => fail!((
412                        ErrorKind::InvalidClientConfig,
413                        "Username is not valid UTF-8 string"
414                    )),
415                }
416            },
417            password: match url.password() {
418                Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
419                    Ok(decoded) => Some(decoded.into_owned()),
420                    Err(_) => fail!((
421                        ErrorKind::InvalidClientConfig,
422                        "Password is not valid UTF-8 string"
423                    )),
424                },
425                None => None,
426            },
427            protocol: parse_protocol(&query)?,
428        },
429    })
430}
431
432#[cfg(unix)]
433fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
434    let query: HashMap<_, _> = url.query_pairs().collect();
435    Ok(ConnectionInfo {
436        addr: ConnectionAddr::Unix(url.to_file_path().map_err(|_| -> RedisError {
437            (ErrorKind::InvalidClientConfig, "Missing path").into()
438        })?),
439        redis: RedisConnectionInfo {
440            db: match query.get("db") {
441                Some(db) => db.parse::<i64>().map_err(|_| -> RedisError {
442                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
443                })?,
444
445                None => 0,
446            },
447            username: query.get("user").map(|username| username.to_string()),
448            password: query.get("pass").map(|password| password.to_string()),
449            protocol: parse_protocol(&query)?,
450        },
451    })
452}
453
454#[cfg(not(unix))]
455fn url_to_unix_connection_info(_: url::Url) -> RedisResult<ConnectionInfo> {
456    fail!((
457        ErrorKind::InvalidClientConfig,
458        "Unix sockets are not available on this platform."
459    ));
460}
461
462impl IntoConnectionInfo for url::Url {
463    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
464        match self.scheme() {
465            "redis" | "rediss" | "valkey" | "valkeys" => url_to_tcp_connection_info(self),
466            "unix" | "redis+unix" | "valkey+unix" => url_to_unix_connection_info(self),
467            _ => fail!((
468                ErrorKind::InvalidClientConfig,
469                "URL provided is not a redis URL"
470            )),
471        }
472    }
473}
474
475struct TcpConnection {
476    reader: TcpStream,
477    open: bool,
478}
479
480#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
481struct TcpNativeTlsConnection {
482    reader: TlsStream<TcpStream>,
483    open: bool,
484}
485
486#[cfg(feature = "tls-rustls")]
487struct TcpRustlsConnection {
488    reader: StreamOwned<rustls::ClientConnection, TcpStream>,
489    open: bool,
490}
491
492#[cfg(unix)]
493struct UnixConnection {
494    sock: UnixStream,
495    open: bool,
496}
497
498enum ActualConnection {
499    Tcp(TcpConnection),
500    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
501    TcpNativeTls(Box<TcpNativeTlsConnection>),
502    #[cfg(feature = "tls-rustls")]
503    TcpRustls(Box<TcpRustlsConnection>),
504    #[cfg(unix)]
505    Unix(UnixConnection),
506}
507
/// A rustls certificate verifier that accepts every server certificate and every
/// handshake signature without checking them.
///
/// `supported` is only consulted to advertise the supported signature schemes.
#[cfg(feature = "tls-rustls-insecure")]
struct NoCertificateVerification {
    supported: rustls::crypto::WebPkiSupportedAlgorithms,
}

#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
    /// Unconditionally accepts the server certificate.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let accepted = rustls::client::danger::ServerCertVerified::assertion();
        Ok(accepted)
    }

    /// Unconditionally accepts TLS 1.2 handshake signatures.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let accepted = rustls::client::danger::HandshakeSignatureValid::assertion();
        Ok(accepted)
    }

    /// Unconditionally accepts TLS 1.3 handshake signatures.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let accepted = rustls::client::danger::HandshakeSignatureValid::assertion();
        Ok(accepted)
    }

    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.supported.supported_schemes()
    }
}

/// Prints only the type name; the algorithm table is not included.
#[cfg(feature = "tls-rustls-insecure")]
impl fmt::Debug for NoCertificateVerification {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("NoCertificateVerification").finish()
    }
}
555
/// Insecure `ServerCertVerifier` for rustls that implements `danger_accept_invalid_hostnames`.
///
/// Every check is delegated to `inner`; only failures classified by
/// `is_hostname_error` (certificate not valid for the requested name) are turned
/// into success. All other verification errors are returned unchanged.
#[cfg(feature = "tls-rustls-insecure")]
#[derive(Debug)]
struct AcceptInvalidHostnamesCertVerifier {
    inner: Arc<rustls::client::WebPkiServerVerifier>,
}

/// Returns `true` if `err` reports a certificate that is not valid for the expected name.
#[cfg(feature = "tls-rustls-insecure")]
fn is_hostname_error(err: &rustls::Error) -> bool {
    match err {
        rustls::Error::InvalidCertificate(cert_err) => matches!(
            cert_err,
            rustls::CertificateError::NotValidForName
                | rustls::CertificateError::NotValidForNameContext { .. }
        ),
        _ => false,
    }
}

#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for AcceptInvalidHostnamesCertVerifier {
    fn verify_server_cert(
        &self,
        end_entity: &rustls::pki_types::CertificateDer<'_>,
        intermediates: &[rustls::pki_types::CertificateDer<'_>],
        server_name: &rustls::pki_types::ServerName<'_>,
        ocsp_response: &[u8],
        now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let verdict =
            self.inner
                .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now);
        match verdict {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::ServerCertVerified::assertion())
            }
            other => other,
        }
    }

    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        match self.inner.verify_tls12_signature(message, cert, dss) {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
            }
            other => other,
        }
    }

    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        match self.inner.verify_tls13_signature(message, cert, dss) {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
            }
            other => other,
        }
    }

    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.inner.supported_verify_schemes()
    }
}
633
634/// Represents a stateful redis TCP connection.
635pub struct Connection {
636    con: ActualConnection,
637    parser: Parser,
638    db: i64,
639
640    /// Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`.
641    ///
642    /// This flag is checked when attempting to send a command, and if it's raised, we attempt to
643    /// exit the pubsub state before executing the new request.
644    pubsub: bool,
645
646    // Field indicating which protocol to use for server communications.
647    protocol: ProtocolVersion,
648
649    /// This is used to manage Push messages in RESP3 mode.
650    push_sender: Option<SyncPushSender>,
651
652    /// The number of messages that are expected to be returned from the server,
653    /// but the user no longer waits for - answers for requests that already returned a transient error.
654    messages_to_skip: usize,
655}
656
657/// Represents a RESP2 pubsub connection.
658///
659/// If you're using a DB that supports RESP3, consider using a regular connection and setting a push sender it using [Connection::set_push_sender].
660pub struct PubSub<'a> {
661    con: &'a mut Connection,
662    waiting_messages: VecDeque<Msg>,
663}
664
665/// Represents a pubsub message.
666#[derive(Debug, Clone)]
667pub struct Msg {
668    payload: Value,
669    channel: Value,
670    pattern: Option<Value>,
671}
672
673impl ActualConnection {
674    pub fn new(addr: &ConnectionAddr, timeout: Option<Duration>) -> RedisResult<ActualConnection> {
675        Ok(match *addr {
676            ConnectionAddr::Tcp(ref host, ref port) => {
677                let addr = (host.as_str(), *port);
678                let tcp = match timeout {
679                    None => connect_tcp(addr)?,
680                    Some(timeout) => {
681                        let mut tcp = None;
682                        let mut last_error = None;
683                        for addr in addr.to_socket_addrs()? {
684                            match connect_tcp_timeout(&addr, timeout) {
685                                Ok(l) => {
686                                    tcp = Some(l);
687                                    break;
688                                }
689                                Err(e) => {
690                                    last_error = Some(e);
691                                }
692                            };
693                        }
694                        match (tcp, last_error) {
695                            (Some(tcp), _) => tcp,
696                            (None, Some(e)) => {
697                                fail!(e);
698                            }
699                            (None, None) => {
700                                fail!((
701                                    ErrorKind::InvalidClientConfig,
702                                    "could not resolve to any addresses"
703                                ));
704                            }
705                        }
706                    }
707                };
708                ActualConnection::Tcp(TcpConnection {
709                    reader: tcp,
710                    open: true,
711                })
712            }
713            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
714            ConnectionAddr::TcpTls {
715                ref host,
716                port,
717                insecure,
718                ref tls_params,
719            } => {
720                let tls_connector = if insecure {
721                    TlsConnector::builder()
722                        .danger_accept_invalid_certs(true)
723                        .danger_accept_invalid_hostnames(true)
724                        .use_sni(false)
725                        .build()?
726                } else if let Some(params) = tls_params {
727                    TlsConnector::builder()
728                        .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames)
729                        .build()?
730                } else {
731                    TlsConnector::new()?
732                };
733                let addr = (host.as_str(), port);
734                let tls = match timeout {
735                    None => {
736                        let tcp = connect_tcp(addr)?;
737                        match tls_connector.connect(host, tcp) {
738                            Ok(res) => res,
739                            Err(e) => {
740                                fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string()));
741                            }
742                        }
743                    }
744                    Some(timeout) => {
745                        let mut tcp = None;
746                        let mut last_error = None;
747                        for addr in (host.as_str(), port).to_socket_addrs()? {
748                            match connect_tcp_timeout(&addr, timeout) {
749                                Ok(l) => {
750                                    tcp = Some(l);
751                                    break;
752                                }
753                                Err(e) => {
754                                    last_error = Some(e);
755                                }
756                            };
757                        }
758                        match (tcp, last_error) {
759                            (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(),
760                            (None, Some(e)) => {
761                                fail!(e);
762                            }
763                            (None, None) => {
764                                fail!((
765                                    ErrorKind::InvalidClientConfig,
766                                    "could not resolve to any addresses"
767                                ));
768                            }
769                        }
770                    }
771                };
772                ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection {
773                    reader: tls,
774                    open: true,
775                }))
776            }
777            #[cfg(feature = "tls-rustls")]
778            ConnectionAddr::TcpTls {
779                ref host,
780                port,
781                insecure,
782                ref tls_params,
783            } => {
784                let host: &str = host;
785                let config = create_rustls_config(insecure, tls_params.clone())?;
786                let conn = rustls::ClientConnection::new(
787                    Arc::new(config),
788                    rustls::pki_types::ServerName::try_from(host)?.to_owned(),
789                )?;
790                let reader = match timeout {
791                    None => {
792                        let tcp = connect_tcp((host, port))?;
793                        StreamOwned::new(conn, tcp)
794                    }
795                    Some(timeout) => {
796                        let mut tcp = None;
797                        let mut last_error = None;
798                        for addr in (host, port).to_socket_addrs()? {
799                            match connect_tcp_timeout(&addr, timeout) {
800                                Ok(l) => {
801                                    tcp = Some(l);
802                                    break;
803                                }
804                                Err(e) => {
805                                    last_error = Some(e);
806                                }
807                            };
808                        }
809                        match (tcp, last_error) {
810                            (Some(tcp), _) => StreamOwned::new(conn, tcp),
811                            (None, Some(e)) => {
812                                fail!(e);
813                            }
814                            (None, None) => {
815                                fail!((
816                                    ErrorKind::InvalidClientConfig,
817                                    "could not resolve to any addresses"
818                                ));
819                            }
820                        }
821                    }
822                };
823
824                ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true }))
825            }
826            #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
827            ConnectionAddr::TcpTls { .. } => {
828                fail!((
829                    ErrorKind::InvalidClientConfig,
830                    "Cannot connect to TCP with TLS without the tls feature"
831                ));
832            }
833            #[cfg(unix)]
834            ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection {
835                sock: UnixStream::connect(path)?,
836                open: true,
837            }),
838            #[cfg(not(unix))]
839            ConnectionAddr::Unix(ref _path) => {
840                fail!((
841                    ErrorKind::InvalidClientConfig,
842                    "Cannot connect to unix sockets \
843                     on this platform"
844                ));
845            }
846        })
847    }
848
849    pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
850        match *self {
851            ActualConnection::Tcp(ref mut connection) => {
852                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
853                match res {
854                    Err(e) => {
855                        if e.is_unrecoverable_error() {
856                            connection.open = false;
857                        }
858                        Err(e)
859                    }
860                    Ok(_) => Ok(Value::Okay),
861                }
862            }
863            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
864            ActualConnection::TcpNativeTls(ref mut connection) => {
865                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
866                match res {
867                    Err(e) => {
868                        if e.is_unrecoverable_error() {
869                            connection.open = false;
870                        }
871                        Err(e)
872                    }
873                    Ok(_) => Ok(Value::Okay),
874                }
875            }
876            #[cfg(feature = "tls-rustls")]
877            ActualConnection::TcpRustls(ref mut connection) => {
878                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
879                match res {
880                    Err(e) => {
881                        if e.is_unrecoverable_error() {
882                            connection.open = false;
883                        }
884                        Err(e)
885                    }
886                    Ok(_) => Ok(Value::Okay),
887                }
888            }
889            #[cfg(unix)]
890            ActualConnection::Unix(ref mut connection) => {
891                let result = connection.sock.write_all(bytes).map_err(RedisError::from);
892                match result {
893                    Err(e) => {
894                        if e.is_unrecoverable_error() {
895                            connection.open = false;
896                        }
897                        Err(e)
898                    }
899                    Ok(_) => Ok(Value::Okay),
900                }
901            }
902        }
903    }
904
905    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
906        match *self {
907            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
908                reader.set_write_timeout(dur)?;
909            }
910            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
911            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
912                let reader = &(boxed_tls_connection.reader);
913                reader.get_ref().set_write_timeout(dur)?;
914            }
915            #[cfg(feature = "tls-rustls")]
916            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
917                let reader = &(boxed_tls_connection.reader);
918                reader.get_ref().set_write_timeout(dur)?;
919            }
920            #[cfg(unix)]
921            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
922                sock.set_write_timeout(dur)?;
923            }
924        }
925        Ok(())
926    }
927
928    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
929        match *self {
930            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
931                reader.set_read_timeout(dur)?;
932            }
933            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
934            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
935                let reader = &(boxed_tls_connection.reader);
936                reader.get_ref().set_read_timeout(dur)?;
937            }
938            #[cfg(feature = "tls-rustls")]
939            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
940                let reader = &(boxed_tls_connection.reader);
941                reader.get_ref().set_read_timeout(dur)?;
942            }
943            #[cfg(unix)]
944            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
945                sock.set_read_timeout(dur)?;
946            }
947        }
948        Ok(())
949    }
950
951    pub fn is_open(&self) -> bool {
952        match *self {
953            ActualConnection::Tcp(TcpConnection { open, .. }) => open,
954            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
955            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open,
956            #[cfg(feature = "tls-rustls")]
957            ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open,
958            #[cfg(unix)]
959            ActualConnection::Unix(UnixConnection { open, .. }) => open,
960        }
961    }
962}
963
/// Collects the trust anchors used when the caller did not supply a custom
/// [`RootCertStore`] through [`TlsConnParams`].
///
/// With `tls-rustls-webpki-roots` the roots bundled by the `webpki-roots` crate are added.
/// With `tls-rustls` but neither `tls-native-tls` nor `tls-rustls-webpki-roots`, the
/// platform's native certificates are loaded, and any error reported while loading them is
/// returned. In the remaining feature combinations the store is empty.
#[cfg(feature = "tls-rustls")]
fn default_rustls_root_store() -> RedisResult<RootCertStore> {
    // `mut` is unused when the feature set compiles in no root source at all.
    #[allow(unused_mut)]
    let mut root_store = RootCertStore::empty();
    #[cfg(feature = "tls-rustls-webpki-roots")]
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    #[cfg(all(
        feature = "tls-rustls",
        not(feature = "tls-native-tls"),
        not(feature = "tls-rustls-webpki-roots")
    ))]
    {
        let mut certificate_result = load_native_certs();
        // Any reported loading error aborts, rather than trusting a partial set.
        if let Some(error) = certificate_result.errors.pop() {
            return Err(error.into());
        }
        for cert in certificate_result.certs {
            root_store.add(cert)?;
        }
    }
    Ok(root_store)
}

/// Builds the rustls client configuration for a TLS connection.
///
/// * `insecure`: disable certificate verification and SNI. This is only honored with the
///   `tls-rustls-insecure` feature; otherwise an `InvalidClientConfig` error is returned.
/// * `tls_params`: optional custom root store, client certificate chain and key, and (with
///   `tls-rustls-insecure` or `tls-native-tls`) the `danger_accept_invalid_hostnames` flag.
///
/// The default root store (see `default_rustls_root_store`) is only built when no custom
/// store was supplied. A caller-provided store therefore neither pays for nor fails because
/// of loading the platform's native certificates.
///
/// # Errors
/// Returns `InvalidClientConfig` for unusable client-auth material, for insecure options
/// requested without the `tls-rustls-insecure` feature, or when no crypto provider is
/// available in insecure mode. Errors from building the default root store are propagated.
#[cfg(feature = "tls-rustls")]
pub(crate) fn create_rustls_config(
    insecure: bool,
    tls_params: Option<TlsConnParams>,
) -> RedisResult<rustls::ClientConfig> {
    let config = if let Some(tls_params) = tls_params {
        // Build the default roots lazily: a caller-supplied store makes them unnecessary.
        let root_cert_store = match tls_params.root_cert_store {
            Some(custom_store) => custom_store,
            None => default_rustls_root_store()?,
        };
        let config_builder =
            rustls::ClientConfig::builder().with_root_certificates(root_cert_store.clone());

        let config_builder = if let Some(ClientTlsParams {
            client_cert_chain: client_cert,
            client_key,
        }) = tls_params.client_tls_params
        {
            config_builder
                .with_client_auth_cert(client_cert, client_key)
                .map_err(|err| {
                    RedisError::from((
                        ErrorKind::InvalidClientConfig,
                        "Unable to build client with TLS parameters provided.",
                        err.to_string(),
                    ))
                })?
        } else {
            config_builder.with_no_client_auth()
        };

        // Implement `danger_accept_invalid_hostnames`.
        //
        // The strange cfg here is to handle a specific unusual combination of features: if
        // `tls-native-tls` and `tls-rustls` are enabled, but `tls-rustls-insecure` is not, and the
        // application tries to use the danger flag.
        #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
        let config_builder = if !insecure && tls_params.danger_accept_invalid_hostnames {
            #[cfg(not(feature = "tls-rustls-insecure"))]
            {
                // This code should not enable an insecure mode if the `insecure` feature is not
                // set, but it shouldn't silently ignore the flag either. So return an error.
                fail!((
                    ErrorKind::InvalidClientConfig,
                    "Cannot create insecure client via danger_accept_invalid_hostnames without tls-rustls-insecure feature"
                ));
            }

            #[cfg(feature = "tls-rustls-insecure")]
            {
                let mut config = config_builder;
                config.dangerous().set_certificate_verifier(Arc::new(
                    AcceptInvalidHostnamesCertVerifier {
                        inner: rustls::client::WebPkiServerVerifier::builder(Arc::new(
                            root_cert_store,
                        ))
                        .build()
                        .map_err(|err| rustls::Error::from(rustls::OtherError(Arc::new(err))))?,
                    },
                ));
                config
            }
        } else {
            config_builder
        };

        config_builder
    } else {
        // Load the roots before touching the builder, matching the original evaluation order.
        let root_store = default_rustls_root_store()?;
        rustls::ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth()
    };

    match (insecure, cfg!(feature = "tls-rustls-insecure")) {
        #[cfg(feature = "tls-rustls-insecure")]
        (true, true) => {
            let mut config = config;
            config.enable_sni = false;
            let Some(crypto_provider) = rustls::crypto::CryptoProvider::get_default() else {
                return Err(RedisError::from((
                    ErrorKind::InvalidClientConfig,
                    "No crypto provider available for rustls",
                )));
            };
            config
                .dangerous()
                .set_certificate_verifier(Arc::new(NoCertificateVerification {
                    supported: crypto_provider.signature_verification_algorithms,
                }));

            Ok(config)
        }
        (true, false) => {
            fail!((
                ErrorKind::InvalidClientConfig,
                "Cannot create insecure client without tls-rustls-insecure feature"
            ));
        }
        _ => Ok(config),
    }
}
1081
1082fn authenticate_cmd(
1083    connection_info: &RedisConnectionInfo,
1084    check_username: bool,
1085    password: &str,
1086) -> Cmd {
1087    let mut command = cmd("AUTH");
1088    if check_username {
1089        if let Some(username) = &connection_info.username {
1090            command.arg(username);
1091        }
1092    }
1093    command.arg(password);
1094    command
1095}
1096
1097pub fn connect(
1098    connection_info: &ConnectionInfo,
1099    timeout: Option<Duration>,
1100) -> RedisResult<Connection> {
1101    let start = Instant::now();
1102    let con: ActualConnection = ActualConnection::new(&connection_info.addr, timeout)?;
1103
1104    // we temporarily set the timeout, and will remove it after finishing setup.
1105    let remaining_timeout = timeout.and_then(|timeout| timeout.checked_sub(start.elapsed()));
1106    // TLS could run logic that doesn't contain a timeout, and should fail if it takes too long.
1107    if timeout.is_some() && remaining_timeout.is_none() {
1108        return Err(RedisError::from(std::io::Error::new(
1109            std::io::ErrorKind::TimedOut,
1110            "Connection timed out",
1111        )));
1112    }
1113    con.set_read_timeout(remaining_timeout)?;
1114    con.set_write_timeout(remaining_timeout)?;
1115
1116    let con = setup_connection(
1117        con,
1118        &connection_info.redis,
1119        #[cfg(feature = "cache-aio")]
1120        None,
1121    )?;
1122
1123    // remove the temporary timeout.
1124    con.set_read_timeout(None)?;
1125    con.set_write_timeout(None)?;
1126
1127    Ok(con)
1128}
1129
/// Positions, within the pipeline built by `connection_setup_pipeline`, of the replies
/// that `check_connection_setup` inspects. `None` means the corresponding command was not
/// added to the pipeline.
pub(crate) struct ConnectionSetupComponents {
    /// Index of the `resp3_hello` reply (used for any protocol other than RESP2).
    /// Never set together with `resp2_auth_cmd_idx`.
    resp3_auth_cmd_idx: Option<usize>,
    /// Index of the RESP2 `AUTH` reply (only when a password is configured).
    resp2_auth_cmd_idx: Option<usize>,
    /// Index of the `SELECT` reply (only when the configured db is not 0).
    select_cmd_idx: Option<usize>,
    /// Index of the `CLIENT TRACKING ON` reply (only when a cache config was given).
    #[cfg(feature = "cache-aio")]
    cache_cmd_idx: Option<usize>,
}
1137
1138pub(crate) fn connection_setup_pipeline(
1139    connection_info: &RedisConnectionInfo,
1140    check_username: bool,
1141    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1142) -> (crate::Pipeline, ConnectionSetupComponents) {
1143    let mut pipeline = pipe();
1144    let (authenticate_with_resp3_cmd_index, authenticate_with_resp2_cmd_index) =
1145        if connection_info.protocol != ProtocolVersion::RESP2 {
1146            pipeline.add_command(resp3_hello(connection_info));
1147            (Some(0), None)
1148        } else if connection_info.password.is_some() {
1149            pipeline.add_command(authenticate_cmd(
1150                connection_info,
1151                check_username,
1152                connection_info.password.as_ref().unwrap(),
1153            ));
1154            (None, Some(0))
1155        } else {
1156            (None, None)
1157        };
1158
1159    let select_db_cmd_index = (connection_info.db != 0)
1160        .then(|| pipeline.len())
1161        .inspect(|_| {
1162            pipeline.cmd("SELECT").arg(connection_info.db);
1163        });
1164
1165    #[cfg(feature = "cache-aio")]
1166    let cache_cmd_index = cache_config.map(|cache_config| {
1167        pipeline.cmd("CLIENT").arg("TRACKING").arg("ON");
1168        match cache_config.mode {
1169            crate::caching::CacheMode::All => {}
1170            crate::caching::CacheMode::OptIn => {
1171                pipeline.arg("OPTIN");
1172            }
1173        }
1174        pipeline.len() - 1
1175    });
1176
1177    // result is ignored, as per the command's instructions.
1178    // https://redis.io/commands/client-setinfo/
1179    #[cfg(not(feature = "disable-client-setinfo"))]
1180    pipeline
1181        .cmd("CLIENT")
1182        .arg("SETINFO")
1183        .arg("LIB-NAME")
1184        .arg("redis-rs")
1185        .ignore();
1186    #[cfg(not(feature = "disable-client-setinfo"))]
1187    pipeline
1188        .cmd("CLIENT")
1189        .arg("SETINFO")
1190        .arg("LIB-VER")
1191        .arg(env!("CARGO_PKG_VERSION"))
1192        .ignore();
1193
1194    (
1195        pipeline,
1196        ConnectionSetupComponents {
1197            resp3_auth_cmd_idx: authenticate_with_resp3_cmd_index,
1198            resp2_auth_cmd_idx: authenticate_with_resp2_cmd_index,
1199            select_cmd_idx: select_db_cmd_index,
1200            #[cfg(feature = "cache-aio")]
1201            cache_cmd_idx: cache_cmd_index,
1202        },
1203    )
1204}
1205
1206fn check_resp3_auth(result: &Value) -> RedisResult<()> {
1207    if let Value::ServerError(err) = result {
1208        return Err(get_resp3_hello_command_error(err.clone().into()));
1209    }
1210    Ok(())
1211}
1212
/// Outcome of validating the connection-setup replies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AuthResult {
    /// Every inspected setup reply was accepted, or there was nothing to check.
    Succeeded,
    /// The RESP2 `AUTH` command was rejected for having the wrong number of arguments; the
    /// setup pipeline should be rebuilt and resent without the username.
    ShouldRetryWithoutUsername,
}
1218
1219fn check_resp2_auth(result: &Value) -> RedisResult<AuthResult> {
1220    let err = match result {
1221        Value::Okay => {
1222            return Ok(AuthResult::Succeeded);
1223        }
1224        Value::ServerError(err) => err,
1225        _ => {
1226            return Err((
1227                ErrorKind::ResponseError,
1228                "Redis server refused to authenticate, returns Ok() != Value::Okay",
1229            )
1230                .into());
1231        }
1232    };
1233
1234    let err_msg = err.details().ok_or((
1235        ErrorKind::AuthenticationFailed,
1236        "Password authentication failed",
1237    ))?;
1238    if !err_msg.contains("wrong number of arguments for 'auth' command") {
1239        return Err((
1240            ErrorKind::AuthenticationFailed,
1241            "Password authentication failed",
1242        )
1243            .into());
1244    }
1245    Ok(AuthResult::ShouldRetryWithoutUsername)
1246}
1247
1248fn check_db_select(value: &Value) -> RedisResult<()> {
1249    let Value::ServerError(err) = value else {
1250        return Ok(());
1251    };
1252
1253    match err.details() {
1254        Some(err_msg) => Err((
1255            ErrorKind::ResponseError,
1256            "Redis server refused to switch database",
1257            err_msg.to_string(),
1258        )
1259            .into()),
1260        None => Err((
1261            ErrorKind::ResponseError,
1262            "Redis server refused to switch database",
1263        )
1264            .into()),
1265    }
1266}
1267
/// Validates the reply to `CLIENT TRACKING ON [OPTIN]`; only a plain `OK` is accepted.
#[cfg(feature = "cache-aio")]
fn check_caching(result: &Value) -> RedisResult<()> {
    if matches!(result, Value::Okay) {
        return Ok(());
    }
    Err((
        ErrorKind::ResponseError,
        "Client-side caching returned unknown response",
    )
        .into())
}
1279
1280pub(crate) fn check_connection_setup(
1281    results: Vec<Value>,
1282    ConnectionSetupComponents {
1283        resp3_auth_cmd_idx,
1284        resp2_auth_cmd_idx,
1285        select_cmd_idx,
1286        #[cfg(feature = "cache-aio")]
1287        cache_cmd_idx,
1288    }: ConnectionSetupComponents,
1289) -> RedisResult<AuthResult> {
1290    // can't have both values set
1291    assert!(!(resp2_auth_cmd_idx.is_some() && resp3_auth_cmd_idx.is_some()));
1292
1293    if let Some(index) = resp3_auth_cmd_idx {
1294        let Some(value) = results.get(index) else {
1295            return Err((ErrorKind::ClientError, "Missing RESP3 auth response").into());
1296        };
1297        check_resp3_auth(value)?;
1298    } else if let Some(index) = resp2_auth_cmd_idx {
1299        let Some(value) = results.get(index) else {
1300            return Err((ErrorKind::ClientError, "Missing RESP2 auth response").into());
1301        };
1302        if check_resp2_auth(value)? == AuthResult::ShouldRetryWithoutUsername {
1303            return Ok(AuthResult::ShouldRetryWithoutUsername);
1304        }
1305    }
1306
1307    if let Some(index) = select_cmd_idx {
1308        let Some(value) = results.get(index) else {
1309            return Err((ErrorKind::ClientError, "Missing SELECT DB response").into());
1310        };
1311        check_db_select(value)?;
1312    }
1313
1314    #[cfg(feature = "cache-aio")]
1315    if let Some(index) = cache_cmd_idx {
1316        let Some(value) = results.get(index) else {
1317            return Err((ErrorKind::ClientError, "Missing Caching response").into());
1318        };
1319        check_caching(value)?;
1320    }
1321
1322    Ok(AuthResult::Succeeded)
1323}
1324
1325fn execute_connection_pipeline(
1326    rv: &mut Connection,
1327    (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
1328) -> RedisResult<AuthResult> {
1329    if pipeline.is_empty() {
1330        return Ok(AuthResult::Succeeded);
1331    }
1332    let results = rv.req_packed_commands(&pipeline.get_packed_pipeline(), 0, pipeline.len())?;
1333
1334    check_connection_setup(results, instructions)
1335}
1336
1337fn setup_connection(
1338    con: ActualConnection,
1339    connection_info: &RedisConnectionInfo,
1340    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1341) -> RedisResult<Connection> {
1342    let mut rv = Connection {
1343        con,
1344        parser: Parser::new(),
1345        db: connection_info.db,
1346        pubsub: false,
1347        protocol: connection_info.protocol,
1348        push_sender: None,
1349        messages_to_skip: 0,
1350    };
1351
1352    if execute_connection_pipeline(
1353        &mut rv,
1354        connection_setup_pipeline(
1355            connection_info,
1356            true,
1357            #[cfg(feature = "cache-aio")]
1358            cache_config,
1359        ),
1360    )? == AuthResult::ShouldRetryWithoutUsername
1361    {
1362        execute_connection_pipeline(
1363            &mut rv,
1364            connection_setup_pipeline(
1365                connection_info,
1366                false,
1367                #[cfg(feature = "cache-aio")]
1368                cache_config,
1369            ),
1370        )?;
1371    }
1372
1373    Ok(rv)
1374}
1375
1376/// Implements the "stateless" part of the connection interface that is used by the
1377/// different objects in redis-rs.
1378///
1379/// Primarily it obviously applies to `Connection` object but also some other objects
1380///  implement the interface (for instance whole clients or certain redis results).
1381///
1382/// Generally clients and connections (as well as redis results of those) implement
1383/// this trait.  Actual connections provide more functionality which can be used
1384/// to implement things like `PubSub` but they also can modify the intrinsic
1385/// state of the TCP connection.  This is not possible with `ConnectionLike`
1386/// implementors because that functionality is not exposed.
1387pub trait ConnectionLike {
1388    /// Sends an already encoded (packed) command into the TCP socket and
1389    /// reads the single response from it.
1390    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value>;
1391
1392    /// Sends multiple already encoded (packed) command into the TCP socket
1393    /// and reads `count` responses from it.  This is used to implement
1394    /// pipelining.
1395    /// Important - this function is meant for internal usage, since it's
1396    /// easy to pass incorrect `offset` & `count` parameters, which might
1397    /// cause the connection to enter an erroneous state. Users shouldn't
1398    /// call it, instead using the Pipeline::query function.
1399    #[doc(hidden)]
1400    fn req_packed_commands(
1401        &mut self,
1402        cmd: &[u8],
1403        offset: usize,
1404        count: usize,
1405    ) -> RedisResult<Vec<Value>>;
1406
1407    /// Sends a [Cmd] into the TCP socket and reads a single response from it.
1408    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1409        let pcmd = cmd.get_packed_command();
1410        self.req_packed_command(&pcmd)
1411    }
1412
1413    /// Returns the database this connection is bound to.  Note that this
1414    /// information might be unreliable because it's initially cached and
1415    /// also might be incorrect if the connection like object is not
1416    /// actually connected.
1417    fn get_db(&self) -> i64;
1418
1419    /// Does this connection support pipelining?
1420    #[doc(hidden)]
1421    fn supports_pipelining(&self) -> bool {
1422        true
1423    }
1424
1425    /// Check that all connections it has are available (`PING` internally).
1426    fn check_connection(&mut self) -> bool;
1427
1428    /// Returns the connection status.
1429    ///
1430    /// The connection is open until any `read` call received an
1431    /// invalid response from the server (most likely a closed or dropped
1432    /// connection, otherwise a Redis protocol error). When using unix
1433    /// sockets the connection is open until writing a command failed with a
1434    /// `BrokenPipe` error.
1435    fn is_open(&self) -> bool;
1436}
1437
1438/// A connection is an object that represents a single redis connection.  It
1439/// provides basic support for sending encoded commands into a redis connection
1440/// and to read a response from it.  It's bound to a single database and can
1441/// only be created from the client.
1442///
/// You generally do not do much with this object other than passing it to
/// `Cmd` objects.
1445impl Connection {
1446    /// Sends an already encoded (packed) command into the TCP socket and
1447    /// does not read a response.  This is useful for commands like
1448    /// `MONITOR` which yield multiple items.  This needs to be used with
1449    /// care because it changes the state of the connection.
1450    pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> {
1451        self.send_bytes(cmd)?;
1452        Ok(())
1453    }
1454
1455    /// Fetches a single response from the connection.  This is useful
1456    /// if used in combination with `send_packed_command`.
1457    pub fn recv_response(&mut self) -> RedisResult<Value> {
1458        self.read(true)
1459    }
1460
1461    /// Sets the write timeout for the connection.
1462    ///
1463    /// If the provided value is `None`, then `send_packed_command` call will
1464    /// block indefinitely. It is an error to pass the zero `Duration` to this
1465    /// method.
1466    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1467        self.con.set_write_timeout(dur)
1468    }
1469
1470    /// Sets the read timeout for the connection.
1471    ///
1472    /// If the provided value is `None`, then `recv_response` call will
1473    /// block indefinitely. It is an error to pass the zero `Duration` to this
1474    /// method.
1475    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1476        self.con.set_read_timeout(dur)
1477    }
1478
1479    /// Creates a [`PubSub`] instance for this connection.
1480    pub fn as_pubsub(&mut self) -> PubSub<'_> {
1481        // NOTE: The pubsub flag is intentionally not raised at this time since
1482        // running commands within the pubsub state should not try and exit from
1483        // the pubsub state.
1484        PubSub::new(self)
1485    }
1486
1487    fn exit_pubsub(&mut self) -> RedisResult<()> {
1488        let res = self.clear_active_subscriptions();
1489        if res.is_ok() {
1490            self.pubsub = false;
1491        } else {
1492            // Raise the pubsub flag to indicate the connection is "stuck" in that state.
1493            self.pubsub = true;
1494        }
1495
1496        res
1497    }
1498
    /// Unsubscribes from every channel and pattern, then drains the unsubscribe replies.
    ///
    /// Sends `UNSUBSCRIBE` and `PUNSUBSCRIBE` (both without arguments) and reads responses
    /// until one of the `*_is_pub_sub_state_cleared` helpers reports that the pubsub state is
    /// cleared. Errors (write/read failures, undecodable replies, unexpected reply types, and
    /// server errors other than the `NoSub` kind) are returned to the caller. This method
    /// does not close the connection itself; `exit_pubsub` uses the result to update the
    /// `pubsub` flag.
    fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
        // Responses to unsubscribe commands return in a 3-tuple with values
        // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs).
        // The "count of remaining subs" includes both pattern subscriptions and non pattern
        // subscriptions. Thus, to accurately drain all unsubscribe messages received from the
        // server, both commands need to be executed at once.
        {
            // Prepare both unsubscribe commands
            let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
            let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();

            // Execute commands
            self.send_bytes(&unsubscribe)?;
            self.send_bytes(&punsubscribe)?;
        }

        // Receive responses
        //
        // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
        // commands. There may be more responses if there are active subscriptions. In this case,
        // messages are received until the _subscription count_ in the responses reach zero.
        let mut received_unsub = false;
        let mut received_punsub = false;

        loop {
            let resp = self.recv_response()?;

            match resp {
                // RESP3: confirmations arrive as push messages; `data[1]` is expected to hold
                // the remaining-subscription count. Pushes without that shape are skipped.
                Value::Push { kind, data } => {
                    if data.len() >= 2 {
                        if let Value::Int(num) = data[1] {
                            if resp3_is_pub_sub_state_cleared(
                                &mut received_unsub,
                                &mut received_punsub,
                                &kind,
                                num as isize,
                            ) {
                                break;
                            }
                        }
                    }
                }
                Value::ServerError(err) => {
                    // a new error behavior, introduced in valkey 8.
                    // https://github.com/valkey-io/valkey/pull/759
                    if err.kind() == Some(ServerErrorKind::NoSub) {
                        if no_sub_err_is_pub_sub_state_cleared(
                            &mut received_unsub,
                            &mut received_punsub,
                            &err,
                        ) {
                            break;
                        } else {
                            continue;
                        }
                    }

                    // Any other server error aborts the drain.
                    return Err(err.into());
                }
                // RESP2: replies are plain arrays decoded as
                // (kind, ignored subscription name, remaining count).
                Value::Array(vec) => {
                    let res: (Vec<u8>, (), isize) = from_owned_redis_value(Value::Array(vec))?;
                    if resp2_is_pub_sub_state_cleared(
                        &mut received_unsub,
                        &mut received_punsub,
                        &res.0,
                        res.2,
                    ) {
                        break;
                    }
                }
                _ => {
                    return Err((
                        ErrorKind::ClientError,
                        "Unexpected unsubscribe response",
                        format!("{resp:?}"),
                    )
                        .into())
                }
            }
        }

        // Finally, the connection is back in its normal state since all subscriptions were
        // cancelled *and* all unsubscribe messages were received.
        Ok(())
    }
1588
1589    fn send_push(&self, push: PushInfo) {
1590        if let Some(sender) = &self.push_sender {
1591            let _ = sender.send(push);
1592        }
1593    }
1594
1595    fn try_send(&self, value: &RedisResult<Value>) {
1596        if let Ok(Value::Push { kind, data }) = value {
1597            self.send_push(PushInfo {
1598                kind: kind.clone(),
1599                data: data.clone(),
1600            });
1601        }
1602    }
1603
1604    fn send_disconnect(&self) {
1605        self.send_push(PushInfo::disconnect())
1606    }
1607
    /// Marks the connection as no longer open and shuts down the underlying stream.
    ///
    /// Before touching the stream, a disconnect notification is emitted via
    /// `send_disconnect`, which forwards it to the push sender if one is registered.
    /// Errors returned by the individual `shutdown` calls are ignored (best-effort
    /// cleanup); `open` is set to `false` for every connection variant.
    fn close_connection(&mut self) {
        // Notify the PushManager that the connection was lost
        self.send_disconnect();
        match self.con {
            // Plain TCP: shut down both directions of the socket.
            ActualConnection::Tcp(ref mut connection) => {
                let _ = connection.reader.shutdown(net::Shutdown::Both);
                connection.open = false;
            }
            // native-tls: `shutdown()` is called on the TLS stream itself (no direction arg).
            // NOTE(review): presumably this ends the TLS session rather than shutting down the
            // TCP socket directly — confirm this is sufficient to release the peer.
            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
            ActualConnection::TcpNativeTls(ref mut connection) => {
                let _ = connection.reader.shutdown();
                connection.open = false;
            }
            // rustls: shut down the transport reached through `get_mut()` (the inner stream
            // that accepts `net::Shutdown`), not the TLS layer.
            #[cfg(feature = "tls-rustls")]
            ActualConnection::TcpRustls(ref mut connection) => {
                let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both);
                connection.open = false;
            }
            // Unix domain socket: shut down both directions.
            #[cfg(unix)]
            ActualConnection::Unix(ref mut connection) => {
                let _ = connection.sock.shutdown(net::Shutdown::Both);
                connection.open = false;
            }
        }
    }
1633
    /// Fetches a single message from the connection. If the message is a response,
    /// increment `messages_to_skip` if it wasn't received before a timeout.
    ///
    /// Every parse result is first handed to `try_send`, so RESP3 push frames are forwarded
    /// to the push sender (if any) and are also returned to the caller.
    ///
    /// `messages_to_skip` counts replies that belong to earlier requests whose read failed
    /// with a non-EOF I/O error (presumably a read timeout — confirm with callers). While it
    /// is non-zero, each successfully parsed value and each non-I/O error is discarded
    /// (decrementing the counter) and reading continues with the next message.
    ///
    /// I/O errors are always returned immediately:
    /// * `UnexpectedEof` closes the connection via `close_connection`;
    /// * any other I/O error increments `messages_to_skip` when `is_response` is true, so
    ///   that the late reply is discarded by a subsequent read.
    fn read(&mut self, is_response: bool) -> RedisResult<Value> {
        loop {
            // Parse one value from whichever transport backs this connection.
            let result = match self.con {
                ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => {
                    self.parser.parse_value(reader)
                }
                #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
                ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => {
                    let reader = &mut boxed_tls_connection.reader;
                    self.parser.parse_value(reader)
                }
                #[cfg(feature = "tls-rustls")]
                ActualConnection::TcpRustls(ref mut boxed_tls_connection) => {
                    let reader = &mut boxed_tls_connection.reader;
                    self.parser.parse_value(reader)
                }
                #[cfg(unix)]
                ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => {
                    self.parser.parse_value(sock)
                }
            };
            self.try_send(&result);

            // Successful parse: drop it if it answers an earlier, abandoned request.
            let Err(err) = &result else {
                if self.messages_to_skip > 0 {
                    self.messages_to_skip -= 1;
                    continue;
                }
                return result;
            };
            // Non-I/O error: also consumes one pending skip, if any.
            let Some(io_error) = err.as_io_error() else {
                if self.messages_to_skip > 0 {
                    self.messages_to_skip -= 1;
                    continue;
                }
                return result;
            };
            // An unexpected EOF means the stream is gone: shut the connection down.
            // Any other I/O error on a response leaves its reply in flight, so remember
            // to skip it later.
            if io_error.kind() == io::ErrorKind::UnexpectedEof {
                self.close_connection();
            } else if is_response {
                self.messages_to_skip += 1;
            }

            return result;
        }
    }
1683
1684    /// Sets sender channel for push values.
1685    pub fn set_push_sender(&mut self, sender: SyncPushSender) {
1686        self.push_sender = Some(sender);
1687    }
1688
1689    fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
1690        let result = self.con.send_bytes(bytes);
1691        if self.protocol != ProtocolVersion::RESP2 {
1692            if let Err(e) = &result {
1693                if e.is_connection_dropped() {
1694                    self.send_disconnect();
1695                }
1696            }
1697        }
1698        result
1699    }
1700
1701    /// Subscribes to a new channel(s).
1702    ///
1703    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1704    pub fn subscribe_resp3<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1705        check_resp3!(self.protocol);
1706        cmd("SUBSCRIBE")
1707            .arg(channel)
1708            .set_no_response(true)
1709            .exec(self)
1710    }
1711
1712    /// Subscribes to new channel(s) with pattern(s).
1713    ///
1714    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1715    pub fn psubscribe_resp3<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1716        check_resp3!(self.protocol);
1717        cmd("PSUBSCRIBE")
1718            .arg(pchannel)
1719            .set_no_response(true)
1720            .exec(self)
1721    }
1722
1723    /// Unsubscribes from a channel(s).
1724    ///
1725    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1726    pub fn unsubscribe_resp3<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1727        check_resp3!(self.protocol);
1728        cmd("UNSUBSCRIBE")
1729            .arg(channel)
1730            .set_no_response(true)
1731            .exec(self)
1732    }
1733
1734    /// Unsubscribes from channel pattern(s).
1735    ///
1736    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1737    pub fn punsubscribe_resp3<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1738        check_resp3!(self.protocol);
1739        cmd("PUNSUBSCRIBE")
1740            .arg(pchannel)
1741            .set_no_response(true)
1742            .exec(self)
1743    }
1744}
1745
1746impl ConnectionLike for Connection {
1747    /// Sends a [Cmd] into the TCP socket and reads a single response from it.
1748    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1749        let pcmd = cmd.get_packed_command();
1750        if self.pubsub {
1751            self.exit_pubsub()?;
1752        }
1753
1754        self.send_bytes(&pcmd)?;
1755        if cmd.is_no_response() {
1756            return Ok(Value::Nil);
1757        }
1758        loop {
1759            match self.read(true)? {
1760                Value::Push {
1761                    kind: _kind,
1762                    data: _data,
1763                } => continue,
1764                val => return Ok(val),
1765            }
1766        }
1767    }
1768    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1769        if self.pubsub {
1770            self.exit_pubsub()?;
1771        }
1772
1773        self.send_bytes(cmd)?;
1774        loop {
1775            match self.read(true)? {
1776                Value::Push {
1777                    kind: _kind,
1778                    data: _data,
1779                } => continue,
1780                val => return Ok(val),
1781            }
1782        }
1783    }
1784
1785    fn req_packed_commands(
1786        &mut self,
1787        cmd: &[u8],
1788        offset: usize,
1789        count: usize,
1790    ) -> RedisResult<Vec<Value>> {
1791        if self.pubsub {
1792            self.exit_pubsub()?;
1793        }
1794        self.send_bytes(cmd)?;
1795        let mut rv = vec![];
1796        let mut first_err = None;
1797        let mut count = count;
1798        let mut idx = 0;
1799        while idx < (offset + count) {
1800            // When processing a transaction, some responses may be errors.
1801            // We need to keep processing the rest of the responses in that case,
1802            // so bailing early with `?` would not be correct.
1803            // See: https://github.com/redis-rs/redis-rs/issues/436
1804            let response = self.read(true);
1805            match response {
1806                Ok(Value::ServerError(err)) => {
1807                    if idx < offset {
1808                        if first_err.is_none() {
1809                            first_err = Some(err.into());
1810                        }
1811                    } else {
1812                        rv.push(Value::ServerError(err));
1813                    }
1814                }
1815                Ok(item) => {
1816                    // RESP3 can insert push data between command replies
1817                    if let Value::Push {
1818                        kind: _kind,
1819                        data: _data,
1820                    } = item
1821                    {
1822                        // if that is the case we have to extend the loop and handle push data
1823                        count += 1;
1824                    } else if idx >= offset {
1825                        rv.push(item);
1826                    }
1827                }
1828                Err(err) => {
1829                    if first_err.is_none() {
1830                        first_err = Some(err);
1831                    }
1832                }
1833            }
1834            idx += 1;
1835        }
1836
1837        first_err.map_or(Ok(rv), Err)
1838    }
1839
1840    fn get_db(&self) -> i64 {
1841        self.db
1842    }
1843
1844    fn check_connection(&mut self) -> bool {
1845        cmd("PING").query::<String>(self).is_ok()
1846    }
1847
1848    fn is_open(&self) -> bool {
1849        self.con.is_open()
1850    }
1851}
1852
1853impl<C, T> ConnectionLike for T
1854where
1855    C: ConnectionLike,
1856    T: DerefMut<Target = C>,
1857{
1858    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1859        self.deref_mut().req_packed_command(cmd)
1860    }
1861
1862    fn req_packed_commands(
1863        &mut self,
1864        cmd: &[u8],
1865        offset: usize,
1866        count: usize,
1867    ) -> RedisResult<Vec<Value>> {
1868        self.deref_mut().req_packed_commands(cmd, offset, count)
1869    }
1870
1871    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1872        self.deref_mut().req_command(cmd)
1873    }
1874
1875    fn get_db(&self) -> i64 {
1876        self.deref().get_db()
1877    }
1878
1879    fn supports_pipelining(&self) -> bool {
1880        self.deref().supports_pipelining()
1881    }
1882
1883    fn check_connection(&mut self) -> bool {
1884        self.deref_mut().check_connection()
1885    }
1886
1887    fn is_open(&self) -> bool {
1888        self.deref().is_open()
1889    }
1890}
1891
1892/// The pubsub object provides convenient access to the redis pubsub
1893/// system.  Once created you can subscribe and unsubscribe from channels
1894/// and listen in on messages.
1895///
1896/// Example:
1897///
1898/// ```rust,no_run
1899/// # fn do_something() -> redis::RedisResult<()> {
1900/// let client = redis::Client::open("redis://127.0.0.1/")?;
1901/// let mut con = client.get_connection()?;
1902/// let mut pubsub = con.as_pubsub();
1903/// pubsub.subscribe("channel_1")?;
1904/// pubsub.subscribe("channel_2")?;
1905///
1906/// loop {
1907///     let msg = pubsub.get_message()?;
1908///     let payload : String = msg.get_payload()?;
1909///     println!("channel '{}': {}", msg.get_channel_name(), payload);
1910/// }
1911/// # }
1912/// ```
1913impl<'a> PubSub<'a> {
1914    fn new(con: &'a mut Connection) -> Self {
1915        Self {
1916            con,
1917            waiting_messages: VecDeque::new(),
1918        }
1919    }
1920
1921    fn cache_messages_until_received_response(
1922        &mut self,
1923        cmd: &mut Cmd,
1924        is_sub_unsub: bool,
1925    ) -> RedisResult<Value> {
1926        let ignore_response = self.con.protocol != ProtocolVersion::RESP2 && is_sub_unsub;
1927        cmd.set_no_response(ignore_response);
1928
1929        self.con.send_packed_command(&cmd.get_packed_command())?;
1930
1931        loop {
1932            let response = self.con.recv_response()?;
1933            if let Some(msg) = Msg::from_value(&response) {
1934                self.waiting_messages.push_back(msg);
1935            } else {
1936                return Ok(response);
1937            }
1938        }
1939    }
1940
1941    /// Subscribes to a new channel(s).    
1942    pub fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1943        self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel), true)?;
1944        Ok(())
1945    }
1946
1947    /// Subscribes to new channel(s) with pattern(s).
1948    pub fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1949        self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel), true)?;
1950        Ok(())
1951    }
1952
1953    /// Unsubscribes from a channel(s).
1954    pub fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1955        self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel), true)?;
1956        Ok(())
1957    }
1958
1959    /// Unsubscribes from channel pattern(s).
1960    pub fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1961        self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel), true)?;
1962        Ok(())
1963    }
1964
1965    /// Sends a ping with a message to the server
1966    pub fn ping_message<T: FromRedisValue>(&mut self, message: impl ToRedisArgs) -> RedisResult<T> {
1967        from_owned_redis_value(
1968            self.cache_messages_until_received_response(cmd("PING").arg(message), false)?,
1969        )
1970    }
1971    /// Sends a ping to the server
1972    pub fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
1973        from_owned_redis_value(
1974            self.cache_messages_until_received_response(&mut cmd("PING"), false)?,
1975        )
1976    }
1977
1978    /// Fetches the next message from the pubsub connection.  Blocks until
1979    /// a message becomes available.  This currently does not provide a
1980    /// wait not to block :(
1981    ///
1982    /// The message itself is still generic and can be converted into an
1983    /// appropriate type through the helper methods on it.
1984    pub fn get_message(&mut self) -> RedisResult<Msg> {
1985        if let Some(msg) = self.waiting_messages.pop_front() {
1986            return Ok(msg);
1987        }
1988        loop {
1989            if let Some(msg) = Msg::from_owned_value(self.con.read(false)?) {
1990                return Ok(msg);
1991            } else {
1992                continue;
1993            }
1994        }
1995    }
1996
1997    /// Sets the read timeout for the connection.
1998    ///
1999    /// If the provided value is `None`, then `get_message` call will
2000    /// block indefinitely. It is an error to pass the zero `Duration` to this
2001    /// method.
2002    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
2003        self.con.set_read_timeout(dur)
2004    }
2005}
2006
2007impl Drop for PubSub<'_> {
2008    fn drop(&mut self) {
2009        let _ = self.con.exit_pubsub();
2010    }
2011}
2012
2013/// This holds the data that comes from listening to a pubsub
2014/// connection.  It only contains actual message data.
2015impl Msg {
2016    /// Tries to convert provided [`Value`] into [`Msg`].
2017    pub fn from_value(value: &Value) -> Option<Self> {
2018        Self::from_owned_value(value.clone())
2019    }
2020
2021    /// Tries to convert provided [`Value`] into [`Msg`].
2022    pub fn from_owned_value(value: Value) -> Option<Self> {
2023        let mut pattern = None;
2024        let payload;
2025        let channel;
2026
2027        if let Value::Push { kind, data } = value {
2028            return Self::from_push_info(PushInfo { kind, data });
2029        } else {
2030            let raw_msg: Vec<Value> = from_owned_redis_value(value).ok()?;
2031            let mut iter = raw_msg.into_iter();
2032            let msg_type: String = from_owned_redis_value(iter.next()?).ok()?;
2033            if msg_type == "message" {
2034                channel = iter.next()?;
2035                payload = iter.next()?;
2036            } else if msg_type == "pmessage" {
2037                pattern = Some(iter.next()?);
2038                channel = iter.next()?;
2039                payload = iter.next()?;
2040            } else {
2041                return None;
2042            }
2043        };
2044        Some(Msg {
2045            payload,
2046            channel,
2047            pattern,
2048        })
2049    }
2050
2051    /// Tries to convert provided [`PushInfo`] into [`Msg`].
2052    pub fn from_push_info(push_info: PushInfo) -> Option<Self> {
2053        let mut pattern = None;
2054        let payload;
2055        let channel;
2056
2057        let mut iter = push_info.data.into_iter();
2058        if push_info.kind == PushKind::Message || push_info.kind == PushKind::SMessage {
2059            channel = iter.next()?;
2060            payload = iter.next()?;
2061        } else if push_info.kind == PushKind::PMessage {
2062            pattern = Some(iter.next()?);
2063            channel = iter.next()?;
2064            payload = iter.next()?;
2065        } else {
2066            return None;
2067        }
2068
2069        Some(Msg {
2070            payload,
2071            channel,
2072            pattern,
2073        })
2074    }
2075
2076    /// Returns the channel this message came on.
2077    pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
2078        from_redis_value(&self.channel)
2079    }
2080
2081    /// Convenience method to get a string version of the channel.  Unless
2082    /// your channel contains non utf-8 bytes you can always use this
2083    /// method.  If the channel is not a valid string (which really should
2084    /// not happen) then the return value is `"?"`.
2085    pub fn get_channel_name(&self) -> &str {
2086        match self.channel {
2087            Value::BulkString(ref bytes) => from_utf8(bytes).unwrap_or("?"),
2088            _ => "?",
2089        }
2090    }
2091
2092    /// Returns the message's payload in a specific format.
2093    pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
2094        from_redis_value(&self.payload)
2095    }
2096
2097    /// Returns the bytes that are the message's payload.  This can be used
2098    /// as an alternative to the `get_payload` function if you are interested
2099    /// in the raw bytes in it.
2100    pub fn get_payload_bytes(&self) -> &[u8] {
2101        match self.payload {
2102            Value::BulkString(ref bytes) => bytes,
2103            _ => b"",
2104        }
2105    }
2106
2107    /// Returns true if the message was constructed from a pattern
2108    /// subscription.
2109    #[allow(clippy::wrong_self_convention)]
2110    pub fn from_pattern(&self) -> bool {
2111        self.pattern.is_some()
2112    }
2113
2114    /// If the message was constructed from a message pattern this can be
2115    /// used to find out which one.  It's recommended to match against
2116    /// an `Option<String>` so that you do not need to use `from_pattern`
2117    /// to figure out if a pattern was set.
2118    pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
2119        match self.pattern {
2120            None => from_redis_value(&Value::Nil),
2121            Some(ref x) => from_redis_value(x),
2122        }
2123    }
2124}
2125
2126/// This function simplifies transaction management slightly.  What it
2127/// does is automatically watching keys and then going into a transaction
2128/// loop util it succeeds.  Once it goes through the results are
2129/// returned.
2130///
2131/// To use the transaction two pieces of information are needed: a list
2132/// of all the keys that need to be watched for modifications and a
2133/// closure with the code that should be execute in the context of the
2134/// transaction.  The closure is invoked with a fresh pipeline in atomic
2135/// mode.  To use the transaction the function needs to return the result
2136/// from querying the pipeline with the connection.
2137///
2138/// The end result of the transaction is then available as the return
2139/// value from the function call.
2140///
2141/// Example:
2142///
2143/// ```rust,no_run
2144/// use redis::Commands;
2145/// # fn do_something() -> redis::RedisResult<()> {
2146/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
2147/// # let mut con = client.get_connection().unwrap();
2148/// let key = "the_key";
2149/// let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| {
2150///     let old_val : isize = con.get(key)?;
2151///     pipe
2152///         .set(key, old_val + 1).ignore()
2153///         .get(key).query(con)
2154/// })?;
2155/// println!("The incremented number is: {}", new_val);
2156/// # Ok(()) }
2157/// ```
2158pub fn transaction<
2159    C: ConnectionLike,
2160    K: ToRedisArgs,
2161    T,
2162    F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
2163>(
2164    con: &mut C,
2165    keys: &[K],
2166    func: F,
2167) -> RedisResult<T> {
2168    let mut func = func;
2169    loop {
2170        cmd("WATCH").arg(keys).exec(con)?;
2171        let mut p = pipe();
2172        let response: Option<T> = func(con, p.atomic())?;
2173        match response {
2174            None => {
2175                continue;
2176            }
2177            Some(response) => {
2178                // make sure no watch is left in the connection, even if
2179                // someone forgot to use the pipeline.
2180                cmd("UNWATCH").exec(con)?;
2181                return Ok(response);
2182            }
2183        }
2184    }
2185}
// TODO: make both subscription-clearing helpers below also handle sharded channels.
2187
/// Common logic for clearing subscriptions in RESP2 async/sync
///
/// Records whether an unsubscribe reply (`kind` starting with `u`) or a pattern
/// unsubscribe reply (`kind` starting with `p`) has been seen. Returns `true` once both
/// have been observed and the remaining subscription count `num` is zero.
pub fn resp2_is_pub_sub_state_cleared(
    received_unsub: &mut bool,
    received_punsub: &mut bool,
    kind: &[u8],
    num: isize,
) -> bool {
    match kind {
        [b'u', ..] => *received_unsub = true,
        [b'p', ..] => *received_punsub = true,
        _ => {}
    }
    num == 0 && *received_unsub && *received_punsub
}
2202
2203/// Common logic for clearing subscriptions in RESP3 async/sync
2204pub fn resp3_is_pub_sub_state_cleared(
2205    received_unsub: &mut bool,
2206    received_punsub: &mut bool,
2207    kind: &PushKind,
2208    num: isize,
2209) -> bool {
2210    match kind {
2211        PushKind::Unsubscribe => *received_unsub = true,
2212        PushKind::PUnsubscribe => *received_punsub = true,
2213        _ => (),
2214    };
2215    *received_unsub && *received_punsub && num == 0
2216}
2217
2218pub fn no_sub_err_is_pub_sub_state_cleared(
2219    received_unsub: &mut bool,
2220    received_punsub: &mut bool,
2221    err: &ServerError,
2222) -> bool {
2223    let details = err.details();
2224    *received_unsub = *received_unsub
2225        || details
2226            .map(|details| details.starts_with("'unsub"))
2227            .unwrap_or_default();
2228    *received_punsub = *received_punsub
2229        || details
2230            .map(|details| details.starts_with("'punsub"))
2231            .unwrap_or_default();
2232    *received_unsub && *received_punsub
2233}
2234
2235/// Common logic for checking real cause of hello3 command error
2236pub fn get_resp3_hello_command_error(err: RedisError) -> RedisError {
2237    if let Some(detail) = err.detail() {
2238        if detail.starts_with("unknown command `HELLO`") {
2239            return (
2240                ErrorKind::RESP3NotSupported,
2241                "Redis Server doesn't support HELLO command therefore resp3 cannot be used",
2242            )
2243                .into();
2244        }
2245    }
2246    err
2247}
2248
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_redis_url() {
        let cases = vec![
            ("redis://127.0.0.1", true),
            ("redis://[::1]", true),
            ("rediss://127.0.0.1", true),
            ("rediss://[::1]", true),
            ("valkey://127.0.0.1", true),
            ("valkey://[::1]", true),
            ("valkeys://127.0.0.1", true),
            ("valkeys://[::1]", true),
            ("redis+unix:///run/redis.sock", true),
            ("valkey+unix:///run/valkey.sock", true),
            ("unix:///run/redis.sock", true),
            ("http://127.0.0.1", false),
            ("tcp://127.0.0.1", false),
        ];
        for (url, expected) in cases.into_iter() {
            let res = parse_redis_url(url);
            assert_eq!(
                res.is_some(),
                expected,
                "Parsed result of `{url}` is not expected",
            );
        }
    }

    #[test]
    fn test_url_to_tcp_connection_info() {
        let cases = vec![
            (
                url::Url::parse("redis://127.0.0.1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://[::1]").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("::1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("example.com".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=resp3").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        protocol: ProtocolVersion::RESP3,
                        ..Default::default()
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url.clone()).unwrap();
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
            // The `protocol=` cases above are only meaningful if the parsed protocol is checked.
            assert_eq!(
                res.redis.protocol, expected.redis.protocol,
                "protocol of {url} is not expected",
            );
        }
    }

    #[test]
    fn test_url_to_tcp_connection_info_failed() {
        let cases = vec![
            (
                url::Url::parse("redis://").unwrap(),
                "Missing hostname",
                None,
            ),
            (
                url::Url::parse("redis://127.0.0.1/db").unwrap(),
                "Invalid database number",
                None,
            ),
            (
                url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(),
                "Username is not valid UTF-8 string",
                None,
            ),
            (
                url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(),
                "Password is not valid UTF-8 string",
                None,
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=4").unwrap(),
                "Invalid protocol version",
                Some("4"),
            ),
        ];
        for (url, expected, detail) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url).unwrap_err();
            assert_eq!(
                res.kind(),
                crate::ErrorKind::InvalidClientConfig,
                "{}",
                &res,
            );
            // `Error::description` is deprecated, but it is used here to compare the static
            // description on its own, separately from `detail()`.
            #[allow(deprecated)]
            let desc = std::error::Error::description(&res);
            assert_eq!(desc, expected, "{}", &res);
            assert_eq!(res.detail(), detail, "{}", &res);
        }
    }

    #[test]
    #[cfg(unix)]
    fn test_url_to_unix_connection_info() {
        let cases = vec![
            (
                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 0,
                        username: None,
                        password: None,
                        protocol: ProtocolVersion::RESP2,
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 1,
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse(
                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse(
                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("&?= *+".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?protocol=3").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        protocol: ProtocolVersion::RESP3,
                        ..Default::default()
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            assert_eq!(
                ConnectionAddr::Unix(url.to_file_path().unwrap()),
                expected.addr,
                "addr of {url} is not expected",
            );
            let res = url_to_unix_connection_info(url.clone()).unwrap();
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
            // The `protocol=` case above is only meaningful if the parsed protocol is checked.
            assert_eq!(
                res.redis.protocol, expected.redis.protocol,
                "protocol of {url} is not expected",
            );
        }
    }
}