Skip to main content

redis/
connection.rs

1use std::borrow::Cow;
2use std::collections::VecDeque;
3use std::fmt;
4use std::io::{self, Write};
5use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
6use std::ops::DerefMut;
7use std::path::PathBuf;
8use std::str::{FromStr, from_utf8};
9use std::time::{Duration, Instant};
10
11use crate::cmd::{Cmd, cmd, pipe};
12use crate::errors::{ErrorKind, RedisError, ServerError, ServerErrorKind};
13use crate::io::tcp::{TcpSettings, stream_with_settings};
14use crate::parser::Parser;
15use crate::pipeline::Pipeline;
16use crate::types::{
17    FromRedisValue, HashMap, PushKind, RedisResult, SyncPushSender, ToRedisArgs, Value,
18    from_redis_value_ref,
19};
20use crate::{ProtocolVersion, check_resp3, from_redis_value};
21
22#[cfg(unix)]
23use std::os::unix::net::UnixStream;
24
25use crate::commands::resp3_hello;
26use arcstr::ArcStr;
27#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
28use native_tls::{TlsConnector, TlsStream};
29
30#[cfg(feature = "tls-rustls")]
31use rustls::{RootCertStore, StreamOwned};
32#[cfg(feature = "tls-rustls")]
33use std::sync::Arc;
34
35use crate::PushInfo;
36
37#[cfg(all(
38    feature = "tls-rustls",
39    not(feature = "tls-native-tls"),
40    not(feature = "tls-rustls-webpki-roots")
41))]
42use rustls_native_certs::load_native_certs;
43
44#[cfg(feature = "tls-rustls")]
45use crate::tls::ClientTlsParams;
46
/// Optional TLS settings attached to a [`ConnectionAddr::TcpTls`] address.
///
/// Which fields exist depends on the enabled TLS features.
// All fields are `pub(crate)`, so code outside this crate cannot construct this type
// while any TLS feature adds a field (there is no `#[non_exhaustive]` attribute here).
#[derive(Clone, Debug)]
pub struct TlsConnParams {
    /// Client-side TLS parameters (`crate::tls::ClientTlsParams`), if configured.
    #[cfg(feature = "tls-rustls")]
    pub(crate) client_tls_params: Option<ClientTlsParams>,
    /// Root certificate store to verify the server against, if configured.
    /// NOTE(review): presumably `None` means a default store is used — confirm at the
    /// rustls connection site.
    #[cfg(feature = "tls-rustls")]
    pub(crate) root_cert_store: Option<RootCertStore>,
    /// When `true`, a server certificate whose name does not match the host is accepted.
    /// Set via `ConnectionAddr::set_danger_accept_invalid_hostnames`.
    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
    pub(crate) danger_accept_invalid_hostnames: bool,
}
57
/// Port used when a Redis URL does not specify one (the standard Redis port).
///
/// A `const` rather than a `static`: it is a plain value whose address is never needed.
const DEFAULT_PORT: u16 = 6379;
59
60#[inline(always)]
61fn connect_tcp(addr: (&str, u16), tcp_settings: &TcpSettings) -> io::Result<TcpStream> {
62    let socket = TcpStream::connect(addr)?;
63    stream_with_settings(socket, tcp_settings)
64}
65
66#[inline(always)]
67fn connect_tcp_timeout(
68    addr: &SocketAddr,
69    timeout: Duration,
70    tcp_settings: &TcpSettings,
71) -> io::Result<TcpStream> {
72    let socket = TcpStream::connect_timeout(addr, timeout)?;
73    stream_with_settings(socket, tcp_settings)
74}
75
76/// This function takes a redis URL string and parses it into a URL
77/// as used by rust-url.
78///
79/// This is necessary as the default parser does not understand how redis URLs function.
80pub fn parse_redis_url(input: &str) -> Option<url::Url> {
81    match url::Url::parse(input) {
82        Ok(result) => match result.scheme() {
83            "redis" | "rediss" | "valkey" | "valkeys" | "redis+unix" | "valkey+unix" | "unix" => {
84                Some(result)
85            }
86            _ => None,
87        },
88        Err(_) => None,
89    }
90}
91
/// Whether the server's TLS certificate is verified.
///
/// This mirrors the `insecure` flag of [`ConnectionAddr::TcpTls`]: `insecure == false`
/// corresponds to [`TlsMode::Secure`] and `insecure == true` to [`TlsMode::Insecure`].
/// See that field's documentation for details.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum TlsMode {
    /// Verify the server's certificate.
    Secure,
    /// Do not verify the server's certificate. This is open to man-in-the-middle attacks.
    Insecure,
}
103
/// Defines the connection address.
///
/// Not all connection addresses are supported on all platforms.  For instance
/// to connect to a unix socket you need to run this on an operating system
/// that supports them. [`ConnectionAddr::is_supported`] reports whether the current
/// build can use a given address.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum ConnectionAddr {
    /// Plain (unencrypted) TCP. Format for this is `(host, port)`.
    Tcp(String, u16),
    /// TCP secured with TLS. Only usable when a TLS feature (`tls-native-tls` or
    /// `tls-rustls`) is enabled.
    TcpTls {
        /// Hostname
        host: String,
        /// Port
        port: u16,
        /// Disable verification of the server's certificate when connecting.
        ///
        /// With the native-tls backend this accepts invalid certificates *and* invalid
        /// hostnames, and disables SNI. URLs enable it with the `#insecure` fragment,
        /// e.g. `rediss://127.0.0.1:6379/#insecure`.
        ///
        /// # Warning
        ///
        /// You should think very carefully before you use this option. Without
        /// verification the client cannot distinguish the real server from an impostor.
        /// This introduces a significant vulnerability to man-in-the-middle attacks.
        insecure: bool,

        /// Optional TLS settings (certificates, client key, hostname-verification override;
        /// the available fields depend on enabled features).
        ///
        /// Ignored by this type's `PartialEq` implementation.
        tls_params: Option<TlsConnParams>,
    },
    /// Format for this is the path to the unix socket.
    Unix(PathBuf),
}
136
137impl PartialEq for ConnectionAddr {
138    fn eq(&self, other: &Self) -> bool {
139        match (self, other) {
140            (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => {
141                host1 == host2 && port1 == port2
142            }
143            (
144                ConnectionAddr::TcpTls {
145                    host: host1,
146                    port: port1,
147                    insecure: insecure1,
148                    tls_params: _,
149                },
150                ConnectionAddr::TcpTls {
151                    host: host2,
152                    port: port2,
153                    insecure: insecure2,
154                    tls_params: _,
155                },
156            ) => port1 == port2 && host1 == host2 && insecure1 == insecure2,
157            (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2,
158            _ => false,
159        }
160    }
161}
162
163impl Eq for ConnectionAddr {}
164
165impl ConnectionAddr {
166    /// Checks if this address is supported.
167    ///
168    /// Because not all platforms support all connection addresses this is a
169    /// quick way to figure out if a connection method is supported. Currently
170    /// this affects:
171    ///
172    /// - Unix socket addresses, which are supported only on Unix
173    ///
174    /// - TLS addresses, which are supported only if a TLS feature is enabled
175    ///   (either `tls-native-tls` or `tls-rustls`).
176    pub fn is_supported(&self) -> bool {
177        match *self {
178            ConnectionAddr::Tcp(_, _) => true,
179            ConnectionAddr::TcpTls { .. } => {
180                cfg!(any(feature = "tls-native-tls", feature = "tls-rustls"))
181            }
182            ConnectionAddr::Unix(_) => cfg!(unix),
183        }
184    }
185
186    /// Configure this address to connect without checking certificate hostnames.
187    ///
188    /// # Warning
189    ///
190    /// You should think very carefully before you use this method. If hostname
191    /// verification is not used, any valid certificate for any site will be
192    /// trusted for use from any other. This introduces a significant
193    /// vulnerability to man-in-the-middle attacks.
194    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
195    pub fn set_danger_accept_invalid_hostnames(&mut self, insecure: bool) {
196        if let ConnectionAddr::TcpTls { tls_params, .. } = self {
197            if let Some(params) = tls_params {
198                params.danger_accept_invalid_hostnames = insecure;
199            } else if insecure {
200                *tls_params = Some(TlsConnParams {
201                    #[cfg(feature = "tls-rustls")]
202                    client_tls_params: None,
203                    #[cfg(feature = "tls-rustls")]
204                    root_cert_store: None,
205                    danger_accept_invalid_hostnames: insecure,
206                });
207            }
208        }
209    }
210
211    #[cfg(feature = "cluster")]
212    pub(crate) fn tls_mode(&self) -> Option<TlsMode> {
213        match self {
214            ConnectionAddr::TcpTls { insecure, .. } => {
215                if *insecure {
216                    Some(TlsMode::Insecure)
217                } else {
218                    Some(TlsMode::Secure)
219                }
220            }
221            _ => None,
222        }
223    }
224}
225
226impl fmt::Display for ConnectionAddr {
227    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
228        // Cluster::get_connection_info depends on the return value from this function
229        match *self {
230            ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
231            ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
232            ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()),
233        }
234    }
235}
236
/// Holds the connection information that redis should use for connecting:
/// where to connect, how to configure the TCP socket, and how to handshake.
///
/// Usually obtained through [`IntoConnectionInfo`] (e.g. from a URL string) or `str::parse`,
/// then adjusted with the consuming `set_*` methods.
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
    /// A connection address for where to connect to.
    pub(crate) addr: ConnectionAddr,

    /// The settings for the TCP connection
    pub(crate) tcp_settings: TcpSettings,
    /// A redis connection info for how to handshake with redis
    /// (database, credentials, protocol version).
    pub(crate) redis: RedisConnectionInfo,
}
248
249impl ConnectionInfo {
250    /// Returns the connection address.
251    pub fn addr(&self) -> &ConnectionAddr {
252        &self.addr
253    }
254
255    /// Returns the settings for the TCP connection.
256    pub fn tcp_settings(&self) -> &TcpSettings {
257        &self.tcp_settings
258    }
259
260    /// Returns the redis connection info for how to handshake with redis.
261    pub fn redis_settings(&self) -> &RedisConnectionInfo {
262        &self.redis
263    }
264
265    /// Sets the connection address for where to connect to.
266    pub fn set_addr(mut self, addr: ConnectionAddr) -> Self {
267        self.addr = addr;
268        self
269    }
270
271    /// Sets the TCP settings for the connection.
272    pub fn set_tcp_settings(mut self, tcp_settings: TcpSettings) -> Self {
273        self.tcp_settings = tcp_settings;
274        self
275    }
276
277    /// Set all redis connection info fields at once.
278    pub fn set_redis_settings(mut self, redis: RedisConnectionInfo) -> Self {
279        self.redis = redis;
280        self
281    }
282}
283
/// Redis specific/connection independent information used to establish a connection to redis.
///
/// The derived `Default` gives database `0`, no credentials, the default protocol version,
/// and `skip_set_lib_name == false`.
#[derive(Clone, Default)]
pub struct RedisConnectionInfo {
    /// The database number to use.  This is usually `0`.
    pub(crate) db: i64,
    /// Optionally a username that should be used for connection.
    pub(crate) username: Option<ArcStr>,
    /// Optionally a password that should be used for connection.
    /// Never printed verbatim by this type's `Debug` implementation.
    pub(crate) password: Option<ArcStr>,
    /// Version of the protocol to use.
    pub(crate) protocol: ProtocolVersion,
    /// If set, the connection shouldn't send the library name to the server
    /// (the `CLIENT SETINFO` call is skipped).
    pub(crate) skip_set_lib_name: bool,
}
298
299impl RedisConnectionInfo {
300    /// Returns the username that should be used for connection.
301    pub fn username(&self) -> Option<&str> {
302        self.username.as_deref()
303    }
304
305    /// Returns the password that should be used for connection.
306    pub fn password(&self) -> Option<&str> {
307        self.password.as_deref()
308    }
309
310    /// Returns version of the protocol to use.
311    pub fn protocol(&self) -> ProtocolVersion {
312        self.protocol
313    }
314
315    /// Returns `true` if the `CLIENT SETINFO` command should be skipped.
316    pub fn skip_set_lib_name(&self) -> bool {
317        self.skip_set_lib_name
318    }
319
320    /// Returns the database number to use.
321    pub fn db(&self) -> i64 {
322        self.db
323    }
324
325    /// Sets the username for the connection's ACL.
326    pub fn set_username(mut self, username: impl AsRef<str>) -> Self {
327        self.username = Some(username.as_ref().into());
328        self
329    }
330
331    /// Sets the password for the connection's ACL.
332    pub fn set_password(mut self, password: impl AsRef<str>) -> Self {
333        self.password = Some(password.as_ref().into());
334        self
335    }
336
337    /// Sets the version of the RESP to use.
338    pub fn set_protocol(mut self, protocol: ProtocolVersion) -> Self {
339        self.protocol = protocol;
340        self
341    }
342
343    /// Removes the pipelined `CLIENT SETINFO` call from the connection creation.
344    pub fn set_skip_set_lib_name(mut self) -> Self {
345        self.skip_set_lib_name = true;
346        self
347    }
348
349    /// Sets the database number to use.
350    pub fn set_db(mut self, db: i64) -> Self {
351        self.db = db;
352        self
353    }
354}
355
356impl std::fmt::Debug for RedisConnectionInfo {
357    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358        let RedisConnectionInfo {
359            db,
360            username,
361            password,
362            protocol,
363            skip_set_lib_name,
364        } = self;
365        let mut debug_info = f.debug_struct("RedisConnectionInfo");
366
367        debug_info.field("db", &db);
368        debug_info.field("username", &username);
369        debug_info.field("password", &password.as_ref().map(|_| "<redacted>"));
370        debug_info.field("protocol", &protocol);
371        debug_info.field("skip_set_lib_name", &skip_set_lib_name);
372
373        debug_info.finish()
374    }
375}
376
377impl FromStr for ConnectionInfo {
378    type Err = RedisError;
379
380    fn from_str(s: &str) -> Result<Self, Self::Err> {
381        s.into_connection_info()
382    }
383}
384
/// Converts an object into a connection info struct.  This allows the
/// constructor of the client to accept connection information in a
/// range of different formats.
///
/// Implementations provided alongside this trait include [`ConnectionInfo`] (passed
/// through unchanged), [`ConnectionAddr`], `(host, port)` tuples, and URLs given as
/// `&str`, `String` or `url::Url`.
pub trait IntoConnectionInfo {
    /// Converts the object into a connection info object.
    ///
    /// # Errors
    ///
    /// Returns an error when the input cannot be turned into a valid [`ConnectionInfo`],
    /// for example a malformed or unsupported URL.
    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
}
392
393impl IntoConnectionInfo for ConnectionInfo {
394    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
395        Ok(self)
396    }
397}
398
399impl IntoConnectionInfo for ConnectionAddr {
400    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
401        Ok(ConnectionInfo {
402            addr: self,
403            redis: Default::default(),
404            tcp_settings: Default::default(),
405        })
406    }
407}
408
409/// URL format: `{redis|rediss|valkey|valkeys}://[<username>][:<password>@]<hostname>[:port][/<db>]`
410///
411/// - Basic: `redis://127.0.0.1:6379`
412/// - Username & Password: `redis://user:password@127.0.0.1:6379`
413/// - Password only: `redis://:password@127.0.0.1:6379`
414/// - Specifying DB: `redis://127.0.0.1:6379/0`
415/// - Enabling TLS: `rediss://127.0.0.1:6379`
416/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
417/// - Enabling RESP3: `redis://127.0.0.1:6379/?protocol=resp3`
418impl IntoConnectionInfo for &str {
419    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
420        match parse_redis_url(self) {
421            Some(u) => u.into_connection_info(),
422            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
423        }
424    }
425}
426
427impl<T> IntoConnectionInfo for (T, u16)
428where
429    T: Into<String>,
430{
431    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
432        Ok(ConnectionInfo {
433            addr: ConnectionAddr::Tcp(self.0.into(), self.1),
434            redis: RedisConnectionInfo::default(),
435            tcp_settings: TcpSettings::default(),
436        })
437    }
438}
439
440/// URL format: `{redis|rediss|valkey|valkeys}://[<username>][:<password>@]<hostname>[:port][/<db>]`
441///
442/// - Basic: `redis://127.0.0.1:6379`
443/// - Username & Password: `redis://user:password@127.0.0.1:6379`
444/// - Password only: `redis://:password@127.0.0.1:6379`
445/// - Specifying DB: `redis://127.0.0.1:6379/0`
446/// - Enabling TLS: `rediss://127.0.0.1:6379`
447/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
448/// - Enabling RESP3: `redis://127.0.0.1:6379/?protocol=resp3`
449impl IntoConnectionInfo for String {
450    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
451        match parse_redis_url(&self) {
452            Some(u) => u.into_connection_info(),
453            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
454        }
455    }
456}
457
458fn parse_protocol(query: &HashMap<Cow<str>, Cow<str>>) -> RedisResult<ProtocolVersion> {
459    Ok(match query.get("protocol") {
460        Some(protocol) => {
461            if protocol == "2" || protocol == "resp2" {
462                ProtocolVersion::RESP2
463            } else if protocol == "3" || protocol == "resp3" {
464                ProtocolVersion::RESP3
465            } else {
466                fail!((
467                    ErrorKind::InvalidClientConfig,
468                    "Invalid protocol version",
469                    protocol.to_string()
470                ))
471            }
472        }
473        None => ProtocolVersion::RESP2,
474    })
475}
476
/// Returns `true` if `address` is an IP literal for the unspecified ("wildcard") address.
///
/// This covers `0.0.0.0` and every spelling of the IPv6 unspecified address (`::`, `::0`,
/// `0:0:0:0:0:0:0:0`, …). The IPv6 form may be wrapped in `[` `]`. Wildcard addresses
/// are rejected as connection targets.
///
/// Host names are never treated as wildcards; no DNS resolution is performed.
#[inline]
pub(crate) fn is_wildcard_address(address: &str) -> bool {
    // Accept the bracketed form used for IPv6 literals in URLs and `host:port` strings.
    let unbracketed = address
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(address);
    // Parsing (instead of comparing strings) catches every textual form of the
    // unspecified address, not just the canonical ones.
    matches!(unbracketed.parse::<net::IpAddr>(), Ok(ip) if ip.is_unspecified())
}
481
482fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
483    let host = match url.host() {
484        Some(host) => {
485            // Here we manually match host's enum arms and call their to_string().
486            // Because url.host().to_string() will add `[` and `]` for ipv6:
487            // https://docs.rs/url/latest/src/url/host.rs.html#170
488            // And these brackets will break host.parse::<Ipv6Addr>() when
489            // `client.open()` - `ActualConnection::new()` - `addr.to_socket_addrs()`:
490            // https://doc.rust-lang.org/src/std/net/addr.rs.html#963
491            // https://doc.rust-lang.org/src/std/net/parser.rs.html#158
492            // IpAddr string with brackets can ONLY parse to SocketAddrV6:
493            // https://doc.rust-lang.org/src/std/net/parser.rs.html#255
494            // But if we call Ipv6Addr.to_string directly, it follows rfc5952 without brackets:
495            // https://doc.rust-lang.org/src/std/net/ip.rs.html#1755
496            let host_str = match host {
497                url::Host::Domain(path) => path.to_string(),
498                url::Host::Ipv4(v4) => v4.to_string(),
499                url::Host::Ipv6(v6) => v6.to_string(),
500            };
501
502            if is_wildcard_address(&host_str) {
503                return Err(RedisError::from((
504                    ErrorKind::InvalidClientConfig,
505                    "Cannot connect to a wildcard address (0.0.0.0 or ::)",
506                )));
507            }
508            host_str
509        }
510        None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
511    };
512    let port = url.port().unwrap_or(DEFAULT_PORT);
513    let addr = if url.scheme() == "rediss" || url.scheme() == "valkeys" {
514        #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
515        {
516            match url.fragment() {
517                Some("insecure") => ConnectionAddr::TcpTls {
518                    host,
519                    port,
520                    insecure: true,
521                    tls_params: None,
522                },
523                Some(_) => fail!((
524                    ErrorKind::InvalidClientConfig,
525                    "only #insecure is supported as URL fragment"
526                )),
527                _ => ConnectionAddr::TcpTls {
528                    host,
529                    port,
530                    insecure: false,
531                    tls_params: None,
532                },
533            }
534        }
535
536        #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
537        fail!((
538            ErrorKind::InvalidClientConfig,
539            "can't connect with TLS, the feature is not enabled"
540        ));
541    } else {
542        ConnectionAddr::Tcp(host, port)
543    };
544    let query: HashMap<_, _> = url.query_pairs().collect();
545    Ok(ConnectionInfo {
546        addr,
547        redis: RedisConnectionInfo {
548            db: match url.path().trim_matches('/') {
549                "" => 0,
550                path => path.parse::<i64>().map_err(|_| -> RedisError {
551                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
552                })?,
553            },
554            username: if url.username().is_empty() {
555                None
556            } else {
557                match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
558                    Ok(decoded) => Some(decoded.into()),
559                    Err(_) => fail!((
560                        ErrorKind::InvalidClientConfig,
561                        "Username is not valid UTF-8 string"
562                    )),
563                }
564            },
565            password: match url.password() {
566                Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
567                    Ok(decoded) => Some(decoded.into()),
568                    Err(_) => fail!((
569                        ErrorKind::InvalidClientConfig,
570                        "Password is not valid UTF-8 string"
571                    )),
572                },
573                None => None,
574            },
575            protocol: parse_protocol(&query)?,
576            skip_set_lib_name: false,
577        },
578        tcp_settings: TcpSettings::default(),
579    })
580}
581
582#[cfg(unix)]
583fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
584    let query: HashMap<_, _> = url.query_pairs().collect();
585    Ok(ConnectionInfo {
586        addr: ConnectionAddr::Unix(url.to_file_path().map_err(|_| -> RedisError {
587            (ErrorKind::InvalidClientConfig, "Missing path").into()
588        })?),
589        redis: RedisConnectionInfo {
590            db: match query.get("db") {
591                Some(db) => db.parse::<i64>().map_err(|_| -> RedisError {
592                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
593                })?,
594
595                None => 0,
596            },
597            username: query.get("user").map(|username| username.as_ref().into()),
598            password: query.get("pass").map(|password| password.as_ref().into()),
599            protocol: parse_protocol(&query)?,
600            ..Default::default()
601        },
602        tcp_settings: TcpSettings::default(),
603    })
604}
605
/// Non-Unix fallback: Unix domain sockets are unavailable, so every Unix-socket URL is
/// rejected with `ErrorKind::InvalidClientConfig`.
#[cfg(not(unix))]
fn url_to_unix_connection_info(_url: url::Url) -> RedisResult<ConnectionInfo> {
    Err(RedisError::from((
        ErrorKind::InvalidClientConfig,
        "Unix sockets are not available on this platform.",
    )))
}
613
614impl IntoConnectionInfo for url::Url {
615    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
616        match self.scheme() {
617            "redis" | "rediss" | "valkey" | "valkeys" => url_to_tcp_connection_info(self),
618            "unix" | "redis+unix" | "valkey+unix" => url_to_unix_connection_info(self),
619            _ => fail!((
620                ErrorKind::InvalidClientConfig,
621                "URL provided is not a redis URL"
622            )),
623        }
624    }
625}
626
/// Plain TCP transport.
struct TcpConnection {
    reader: TcpStream,
    // NOTE(review): presumably tracks whether the socket is still usable — confirm in the impl.
    open: bool,
}

/// TLS-over-TCP transport backed by native-tls.
///
/// Compiled only when `tls-native-tls` is enabled and `tls-rustls` is not.
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
struct TcpNativeTlsConnection {
    reader: TlsStream<TcpStream>,
    open: bool,
}

/// TLS-over-TCP transport backed by rustls. When both TLS features are enabled, this one
/// is used (the native-tls variant is compiled out).
#[cfg(feature = "tls-rustls")]
struct TcpRustlsConnection {
    reader: StreamOwned<rustls::ClientConnection, TcpStream>,
    open: bool,
}

/// Unix domain socket transport (Unix targets only).
#[cfg(unix)]
struct UnixConnection {
    sock: UnixStream,
    open: bool,
}

/// The concrete transport behind a [`Connection`], one variant per connection kind.
enum ActualConnection {
    Tcp(TcpConnection),
    // The TLS variants are boxed — NOTE(review): presumably to keep the enum small; confirm.
    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
    TcpNativeTls(Box<TcpNativeTlsConnection>),
    #[cfg(feature = "tls-rustls")]
    TcpRustls(Box<TcpRustlsConnection>),
    #[cfg(unix)]
    Unix(UnixConnection),
}
659
/// Rustls verifier that accepts every server certificate and every handshake signature
/// without checking them.
///
/// # Warning
///
/// This disables TLS authentication of the server entirely. Only the list of advertised
/// signature schemes comes from real data (`supported`).
#[cfg(feature = "tls-rustls-insecure")]
struct NoCertificateVerification {
    /// Algorithms whose signature schemes are reported by `supported_verify_schemes`.
    supported: rustls::crypto::WebPkiSupportedAlgorithms,
}

#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
    /// Unconditionally reports the certificate as verified; all inputs are ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Unconditionally reports the TLS 1.2 handshake signature as valid.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Unconditionally reports the TLS 1.3 handshake signature as valid.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Advertises the schemes supported by `self.supported`.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.supported.supported_schemes()
    }
}

/// Prints only the type name; the `supported` field is omitted.
#[cfg(feature = "tls-rustls-insecure")]
impl fmt::Debug for NoCertificateVerification {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("NoCertificateVerification").finish()
    }
}
707
/// Insecure `ServerCertVerifier` for rustls that implements `danger_accept_invalid_hostnames`.
///
/// Wraps a standard WebPKI verifier and forwards checks to it. A server-certificate
/// failure that [`is_hostname_error`] classifies as a name mismatch is turned into success.
#[cfg(feature = "tls-rustls-insecure")]
#[derive(Debug)]
struct AcceptInvalidHostnamesCertVerifier {
    /// The verifier that performs the actual checks.
    inner: Arc<rustls::client::WebPkiServerVerifier>,
}

/// Returns `true` if `err` says the certificate is not valid for the requested server
/// name (`NotValidForName`, or `NotValidForNameContext` with its extra detail).
/// Every other error returns `false`.
#[cfg(feature = "tls-rustls-insecure")]
fn is_hostname_error(err: &rustls::Error) -> bool {
    matches!(
        err,
        rustls::Error::InvalidCertificate(
            rustls::CertificateError::NotValidForName
                | rustls::CertificateError::NotValidForNameContext { .. }
        )
    )
}
725
#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for AcceptInvalidHostnamesCertVerifier {
    /// Verifies the server certificate with the wrapped WebPKI verifier.
    ///
    /// A failure is waived only when [`is_hostname_error`] classifies it as a server-name
    /// mismatch; every other verification error is returned unchanged.
    fn verify_server_cert(
        &self,
        end_entity: &rustls::pki_types::CertificateDer<'_>,
        intermediates: &[rustls::pki_types::CertificateDer<'_>],
        server_name: &rustls::pki_types::ServerName<'_>,
        ocsp_response: &[u8],
        now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        // NOTE(review): relies on the inner verifier validating the chain before it checks
        // the server name, so a name-mismatch error implies an otherwise trusted chain.
        self.inner
            .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)
            .or_else(|err| {
                if is_hostname_error(&err) {
                    Ok(rustls::client::danger::ServerCertVerified::assertion())
                } else {
                    Err(err)
                }
            })
    }

    /// Delegates TLS 1.2 handshake-signature verification unchanged.
    ///
    /// Handshake signatures have nothing to do with the server name, so no error from
    /// this check may be waived. Masking errors here would only risk accepting a forged
    /// handshake.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        self.inner.verify_tls12_signature(message, cert, dss)
    }

    /// Delegates TLS 1.3 handshake-signature verification unchanged (see
    /// `verify_tls12_signature` for why nothing is waived).
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        self.inner.verify_tls13_signature(message, cert, dss)
    }

    /// Advertises exactly the schemes the wrapped verifier supports.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.inner.supported_verify_schemes()
    }
}
785
/// Represents a stateful redis connection.
///
/// The transport may be plain TCP, TLS over TCP (when a TLS feature is enabled) or a
/// Unix domain socket (on Unix targets).
pub struct Connection {
    /// The underlying transport.
    con: ActualConnection,
    /// Reply parser — NOTE(review): inferred from the type name; confirm in the impl.
    parser: Parser,
    /// Database number — NOTE(review): presumably the currently selected DB; confirm.
    db: i64,

    /// Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`.
    ///
    /// This flag is checked when attempting to send a command, and if it's raised, we attempt to
    /// exit the pubsub state before executing the new request.
    pubsub: bool,

    /// The protocol version (RESP2 or RESP3) used for server communications.
    protocol: ProtocolVersion,

    /// This is used to manage Push messages in RESP3 mode.
    push_sender: Option<SyncPushSender>,

    /// The number of messages that are expected to be returned from the server,
    /// but the user no longer waits for - answers for requests that already returned a transient error.
    messages_to_skip: usize,
}
808
/// Represents a RESP2 pubsub connection.
///
/// If you're using a DB that supports RESP3, consider using a regular connection and
/// setting a push sender on it using [Connection::set_push_sender].
pub struct PubSub<'a> {
    /// The connection, borrowed mutably for the lifetime of this `PubSub`.
    con: &'a mut Connection,
    /// Queued messages — NOTE(review): presumably received but not yet returned to the
    /// caller; confirm in the impl.
    waiting_messages: VecDeque<Msg>,
}

/// Represents a pubsub message.
#[derive(Debug, Clone)]
pub struct Msg {
    payload: Value,
    channel: Value,
    // NOTE(review): presumably `Some` only for pattern-subscription deliveries — confirm.
    pattern: Option<Value>,
}
824
825impl ActualConnection {
826    pub fn new(
827        addr: &ConnectionAddr,
828        timeout: Option<Duration>,
829        tcp_settings: &TcpSettings,
830    ) -> RedisResult<ActualConnection> {
831        Ok(match *addr {
832            ConnectionAddr::Tcp(ref host, ref port) => {
833                if is_wildcard_address(host) {
834                    fail!((
835                        ErrorKind::InvalidClientConfig,
836                        "Cannot connect to a wildcard address (0.0.0.0 or ::)"
837                    ));
838                }
839                let addr = (host.as_str(), *port);
840                let tcp = match timeout {
841                    None => connect_tcp(addr, tcp_settings)?,
842                    Some(timeout) => {
843                        let mut tcp = None;
844                        let mut last_error = None;
845                        for addr in addr.to_socket_addrs()? {
846                            match connect_tcp_timeout(&addr, timeout, tcp_settings) {
847                                Ok(l) => {
848                                    tcp = Some(l);
849                                    break;
850                                }
851                                Err(e) => {
852                                    last_error = Some(e);
853                                }
854                            };
855                        }
856                        match (tcp, last_error) {
857                            (Some(tcp), _) => tcp,
858                            (None, Some(e)) => {
859                                fail!(e);
860                            }
861                            (None, None) => {
862                                fail!((
863                                    ErrorKind::InvalidClientConfig,
864                                    "could not resolve to any addresses"
865                                ));
866                            }
867                        }
868                    }
869                };
870                ActualConnection::Tcp(TcpConnection {
871                    reader: tcp,
872                    open: true,
873                })
874            }
875            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
876            ConnectionAddr::TcpTls {
877                ref host,
878                port,
879                insecure,
880                ref tls_params,
881            } => {
882                let tls_connector = if insecure {
883                    TlsConnector::builder()
884                        .danger_accept_invalid_certs(true)
885                        .danger_accept_invalid_hostnames(true)
886                        .use_sni(false)
887                        .build()?
888                } else if let Some(params) = tls_params {
889                    TlsConnector::builder()
890                        .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames)
891                        .build()?
892                } else {
893                    TlsConnector::new()?
894                };
895                let addr = (host.as_str(), port);
896                let tls = match timeout {
897                    None => {
898                        let tcp = connect_tcp(addr, tcp_settings)?;
899                        match tls_connector.connect(host, tcp) {
900                            Ok(res) => res,
901                            Err(e) => {
902                                fail!((ErrorKind::Io, "SSL Handshake error", e.to_string()));
903                            }
904                        }
905                    }
906                    Some(timeout) => {
907                        let mut tcp = None;
908                        let mut last_error = None;
909                        for addr in (host.as_str(), port).to_socket_addrs()? {
910                            match connect_tcp_timeout(&addr, timeout, tcp_settings) {
911                                Ok(l) => {
912                                    tcp = Some(l);
913                                    break;
914                                }
915                                Err(e) => {
916                                    last_error = Some(e);
917                                }
918                            };
919                        }
920                        match (tcp, last_error) {
921                            (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(),
922                            (None, Some(e)) => {
923                                fail!(e);
924                            }
925                            (None, None) => {
926                                fail!((
927                                    ErrorKind::InvalidClientConfig,
928                                    "could not resolve to any addresses"
929                                ));
930                            }
931                        }
932                    }
933                };
934                ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection {
935                    reader: tls,
936                    open: true,
937                }))
938            }
939            #[cfg(feature = "tls-rustls")]
940            ConnectionAddr::TcpTls {
941                ref host,
942                port,
943                insecure,
944                ref tls_params,
945            } => {
946                let host: &str = host;
947                let config = create_rustls_config(insecure, tls_params.clone())?;
948                let conn = rustls::ClientConnection::new(
949                    Arc::new(config),
950                    rustls::pki_types::ServerName::try_from(host)?.to_owned(),
951                )?;
952                let reader = match timeout {
953                    None => {
954                        let tcp = connect_tcp((host, port), tcp_settings)?;
955                        StreamOwned::new(conn, tcp)
956                    }
957                    Some(timeout) => {
958                        let mut tcp = None;
959                        let mut last_error = None;
960                        for addr in (host, port).to_socket_addrs()? {
961                            match connect_tcp_timeout(&addr, timeout, tcp_settings) {
962                                Ok(l) => {
963                                    tcp = Some(l);
964                                    break;
965                                }
966                                Err(e) => {
967                                    last_error = Some(e);
968                                }
969                            };
970                        }
971                        match (tcp, last_error) {
972                            (Some(tcp), _) => StreamOwned::new(conn, tcp),
973                            (None, Some(e)) => {
974                                fail!(e);
975                            }
976                            (None, None) => {
977                                fail!((
978                                    ErrorKind::InvalidClientConfig,
979                                    "could not resolve to any addresses"
980                                ));
981                            }
982                        }
983                    }
984                };
985
986                ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true }))
987            }
988            #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
989            ConnectionAddr::TcpTls { .. } => {
990                fail!((
991                    ErrorKind::InvalidClientConfig,
992                    "Cannot connect to TCP with TLS without the tls feature"
993                ));
994            }
995            #[cfg(unix)]
996            ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection {
997                sock: UnixStream::connect(path)?,
998                open: true,
999            }),
1000            #[cfg(not(unix))]
1001            ConnectionAddr::Unix(ref _path) => {
1002                fail!((
1003                    ErrorKind::InvalidClientConfig,
1004                    "Cannot connect to unix sockets \
1005                     on this platform"
1006                ));
1007            }
1008        })
1009    }
1010
1011    pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
1012        match *self {
1013            ActualConnection::Tcp(ref mut connection) => {
1014                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
1015                match res {
1016                    Err(e) => {
1017                        if e.is_unrecoverable_error() {
1018                            connection.open = false;
1019                        }
1020                        Err(e)
1021                    }
1022                    Ok(_) => Ok(Value::Okay),
1023                }
1024            }
1025            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1026            ActualConnection::TcpNativeTls(ref mut connection) => {
1027                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
1028                match res {
1029                    Err(e) => {
1030                        if e.is_unrecoverable_error() {
1031                            connection.open = false;
1032                        }
1033                        Err(e)
1034                    }
1035                    Ok(_) => Ok(Value::Okay),
1036                }
1037            }
1038            #[cfg(feature = "tls-rustls")]
1039            ActualConnection::TcpRustls(ref mut connection) => {
1040                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
1041                match res {
1042                    Err(e) => {
1043                        if e.is_unrecoverable_error() {
1044                            connection.open = false;
1045                        }
1046                        Err(e)
1047                    }
1048                    Ok(_) => Ok(Value::Okay),
1049                }
1050            }
1051            #[cfg(unix)]
1052            ActualConnection::Unix(ref mut connection) => {
1053                let result = connection.sock.write_all(bytes).map_err(RedisError::from);
1054                match result {
1055                    Err(e) => {
1056                        if e.is_unrecoverable_error() {
1057                            connection.open = false;
1058                        }
1059                        Err(e)
1060                    }
1061                    Ok(_) => Ok(Value::Okay),
1062                }
1063            }
1064        }
1065    }
1066
1067    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1068        match *self {
1069            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
1070                reader.set_write_timeout(dur)?;
1071            }
1072            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1073            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
1074                let reader = &(boxed_tls_connection.reader);
1075                reader.get_ref().set_write_timeout(dur)?;
1076            }
1077            #[cfg(feature = "tls-rustls")]
1078            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
1079                let reader = &(boxed_tls_connection.reader);
1080                reader.get_ref().set_write_timeout(dur)?;
1081            }
1082            #[cfg(unix)]
1083            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
1084                sock.set_write_timeout(dur)?;
1085            }
1086        }
1087        Ok(())
1088    }
1089
1090    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1091        match *self {
1092            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
1093                reader.set_read_timeout(dur)?;
1094            }
1095            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1096            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
1097                let reader = &(boxed_tls_connection.reader);
1098                reader.get_ref().set_read_timeout(dur)?;
1099            }
1100            #[cfg(feature = "tls-rustls")]
1101            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
1102                let reader = &(boxed_tls_connection.reader);
1103                reader.get_ref().set_read_timeout(dur)?;
1104            }
1105            #[cfg(unix)]
1106            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
1107                sock.set_read_timeout(dur)?;
1108            }
1109        }
1110        Ok(())
1111    }
1112
1113    pub fn is_open(&self) -> bool {
1114        match *self {
1115            ActualConnection::Tcp(TcpConnection { open, .. }) => open,
1116            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1117            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open,
1118            #[cfg(feature = "tls-rustls")]
1119            ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open,
1120            #[cfg(unix)]
1121            ActualConnection::Unix(UnixConnection { open, .. }) => open,
1122        }
1123    }
1124}
1125
/// Builds the rustls `ClientConfig` used for `TcpTls` connections.
///
/// Root certificates come from one of these sources:
/// * The store starts empty.
/// * With `tls-rustls-webpki-roots`, the bundled webpki roots are added.
/// * When neither `tls-rustls-webpki-roots` nor `tls-native-tls` is enabled, the platform's
///   native certificates are added instead.
/// * A `root_cert_store` supplied in `tls_params` replaces this default store entirely.
///
/// * `insecure` - disable certificate verification and SNI. This requires the
///   `tls-rustls-insecure` feature; without it an `InvalidClientConfig` error is returned.
/// * `tls_params` - optional settings:
///   - custom roots;
///   - a client certificate chain and key for mutual TLS;
///   - (with `tls-rustls-insecure` or `tls-native-tls`) the `danger_accept_invalid_hostnames`
///     flag.
///
/// # Errors
///
/// * Native certificate loading reports any error (the last reported one is returned).
/// * A native certificate cannot be added to the store.
/// * The client certificate/key pair is rejected.
/// * An insecure option is requested without the `tls-rustls-insecure` feature.
/// * In insecure mode, no process-default rustls `CryptoProvider` is installed.
#[cfg(feature = "tls-rustls")]
pub(crate) fn create_rustls_config(
    insecure: bool,
    tls_params: Option<TlsConnParams>,
) -> RedisResult<rustls::ClientConfig> {
    // `mut` is only exercised by the cfg-gated blocks below; with neither certificate source
    // enabled the store is never mutated, hence the allow.
    #[allow(unused_mut)]
    let mut root_store = RootCertStore::empty();
    #[cfg(feature = "tls-rustls-webpki-roots")]
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    #[cfg(all(
        feature = "tls-rustls",
        not(feature = "tls-native-tls"),
        not(feature = "tls-rustls-webpki-roots")
    ))]
    {
        // Any error reported while loading native certificates aborts config creation,
        // even if some certificates were loaded successfully.
        let mut certificate_result = load_native_certs();
        if let Some(error) = certificate_result.errors.pop() {
            return Err(error.into());
        }
        for cert in certificate_result.certs {
            root_store.add(cert)?;
        }
    }

    let config = rustls::ClientConfig::builder();
    let config = if let Some(tls_params) = tls_params {
        // Caller-provided roots take precedence over the default store built above.
        let root_cert_store = tls_params.root_cert_store.unwrap_or(root_store);
        // Cloned because the insecure-hostname verifier below needs its own copy of the roots.
        let config_builder = config.with_root_certificates(root_cert_store.clone());

        // Mutual TLS: present the client certificate chain when one was configured.
        let config_builder = if let Some(ClientTlsParams {
            client_cert_chain: client_cert,
            client_key,
        }) = tls_params.client_tls_params
        {
            config_builder
                .with_client_auth_cert(client_cert, client_key)
                .map_err(|err| {
                    RedisError::from((
                        ErrorKind::InvalidClientConfig,
                        "Unable to build client with TLS parameters provided.",
                        err.to_string(),
                    ))
                })?
        } else {
            config_builder.with_no_client_auth()
        };

        // Implement `danger_accept_invalid_hostnames`.
        //
        // The strange cfg here is to handle a specific unusual combination of features: if
        // `tls-native-tls` and `tls-rustls` are enabled, but `tls-rustls-insecure` is not, and the
        // application tries to use the danger flag.
        //
        // The flag is only honoured when `insecure` is false; in insecure mode the verifier is
        // replaced wholesale further down.
        #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
        let config_builder = if !insecure && tls_params.danger_accept_invalid_hostnames {
            #[cfg(not(feature = "tls-rustls-insecure"))]
            {
                // This code should not enable an insecure mode if the `insecure` feature is not
                // set, but it shouldn't silently ignore the flag either. So return an error.
                fail!((
                    ErrorKind::InvalidClientConfig,
                    "Cannot create insecure client via danger_accept_invalid_hostnames without tls-rustls-insecure feature"
                ));
            }

            #[cfg(feature = "tls-rustls-insecure")]
            {
                // Wrap a regular WebPKI verifier over the same roots in a verifier that
                // tolerates hostname mismatches.
                let mut config = config_builder;
                config.dangerous().set_certificate_verifier(Arc::new(
                    AcceptInvalidHostnamesCertVerifier {
                        inner: rustls::client::WebPkiServerVerifier::builder(Arc::new(
                            root_cert_store,
                        ))
                        .build()
                        .map_err(|err| rustls::Error::from(rustls::OtherError(Arc::new(err))))?,
                    },
                ));
                config
            }
        } else {
            config_builder
        };

        config_builder
    } else {
        // No TLS parameters: default roots, no client certificate.
        config
            .with_root_certificates(root_store)
            .with_no_client_auth()
    };

    // Insecure mode: turn off SNI and accept any server certificate, using the signature
    // algorithms of the process-default crypto provider. Without the feature this is an error.
    match (insecure, cfg!(feature = "tls-rustls-insecure")) {
        #[cfg(feature = "tls-rustls-insecure")]
        (true, true) => {
            let mut config = config;
            config.enable_sni = false;
            let Some(crypto_provider) = rustls::crypto::CryptoProvider::get_default() else {
                return Err(RedisError::from((
                    ErrorKind::InvalidClientConfig,
                    "No crypto provider available for rustls",
                )));
            };
            config
                .dangerous()
                .set_certificate_verifier(Arc::new(NoCertificateVerification {
                    supported: crypto_provider.signature_verification_algorithms,
                }));

            Ok(config)
        }
        (true, false) => {
            fail!((
                ErrorKind::InvalidClientConfig,
                "Cannot create insecure client without tls-rustls-insecure feature"
            ));
        }
        _ => Ok(config),
    }
}
1243
1244pub(crate) fn authenticate_cmd(username: Option<&str>, password: &str) -> Cmd {
1245    let mut command = cmd("AUTH");
1246
1247    if let Some(username) = &username {
1248        command.arg(username);
1249    }
1250
1251    command.arg(password);
1252    command
1253}
1254
1255pub fn connect(
1256    connection_info: &ConnectionInfo,
1257    timeout: Option<Duration>,
1258) -> RedisResult<Connection> {
1259    let start = Instant::now();
1260    let con: ActualConnection = ActualConnection::new(
1261        &connection_info.addr,
1262        timeout,
1263        &connection_info.tcp_settings,
1264    )?;
1265
1266    // we temporarily set the timeout, and will remove it after finishing setup.
1267    let remaining_timeout = timeout.and_then(|timeout| timeout.checked_sub(start.elapsed()));
1268    // TLS could run logic that doesn't contain a timeout, and should fail if it takes too long.
1269    if timeout.is_some() && remaining_timeout.is_none() {
1270        return Err(RedisError::from(std::io::Error::new(
1271            std::io::ErrorKind::TimedOut,
1272            "Connection timed out",
1273        )));
1274    }
1275    con.set_read_timeout(remaining_timeout)?;
1276    con.set_write_timeout(remaining_timeout)?;
1277
1278    let con = setup_connection(
1279        con,
1280        &connection_info.redis,
1281        #[cfg(feature = "cache-aio")]
1282        None,
1283    )?;
1284
1285    // remove the temporary timeout.
1286    con.set_read_timeout(None)?;
1287    con.set_write_timeout(None)?;
1288
1289    Ok(con)
1290}
1291
/// Positions of the replies that `check_connection_setup` validates, within the setup
/// pipeline built by `connection_setup_pipeline`.
///
/// `None` means the corresponding command was not added to the pipeline.
pub(crate) struct ConnectionSetupComponents {
    /// Index of the RESP3 `HELLO` reply (built by `resp3_hello`).
    resp3_auth_cmd_idx: Option<usize>,
    /// Index of the RESP2 `AUTH` reply.
    resp2_auth_cmd_idx: Option<usize>,
    /// Index of the `SELECT <db>` reply (only added for a non-zero database).
    select_cmd_idx: Option<usize>,
    /// Index of the `CLIENT TRACKING ON` reply used for client-side caching.
    #[cfg(feature = "cache-aio")]
    cache_cmd_idx: Option<usize>,
}
1299
1300pub(crate) fn connection_setup_pipeline(
1301    connection_info: &RedisConnectionInfo,
1302    check_username: bool,
1303    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1304) -> (crate::Pipeline, ConnectionSetupComponents) {
1305    let mut pipeline = pipe();
1306    let (authenticate_with_resp3_cmd_index, authenticate_with_resp2_cmd_index) =
1307        if connection_info.protocol.supports_resp3() {
1308            pipeline.add_command(resp3_hello(connection_info));
1309            (Some(0), None)
1310        } else if let Some(password) = connection_info.password.as_ref() {
1311            pipeline.add_command(authenticate_cmd(
1312                check_username.then(|| connection_info.username()).flatten(),
1313                password,
1314            ));
1315            (None, Some(0))
1316        } else {
1317            (None, None)
1318        };
1319
1320    let select_db_cmd_index = (connection_info.db != 0)
1321        .then(|| pipeline.len())
1322        .inspect(|_| {
1323            pipeline.cmd("SELECT").arg(connection_info.db);
1324        });
1325
1326    #[cfg(feature = "cache-aio")]
1327    let cache_cmd_index = cache_config.map(|cache_config| {
1328        pipeline.cmd("CLIENT").arg("TRACKING").arg("ON");
1329        match cache_config.mode {
1330            crate::caching::CacheMode::All => {}
1331            crate::caching::CacheMode::OptIn => {
1332                pipeline.arg("OPTIN");
1333            }
1334        }
1335        pipeline.len() - 1
1336    });
1337
1338    // result is ignored, as per the command's instructions.
1339    // https://redis.io/commands/client-setinfo/
1340    if !connection_info.skip_set_lib_name {
1341        pipeline
1342            .cmd("CLIENT")
1343            .arg("SETINFO")
1344            .arg("LIB-NAME")
1345            .arg("redis-rs")
1346            .ignore();
1347        pipeline
1348            .cmd("CLIENT")
1349            .arg("SETINFO")
1350            .arg("LIB-VER")
1351            .arg(env!("CARGO_PKG_VERSION"))
1352            .ignore();
1353    }
1354
1355    (
1356        pipeline,
1357        ConnectionSetupComponents {
1358            resp3_auth_cmd_idx: authenticate_with_resp3_cmd_index,
1359            resp2_auth_cmd_idx: authenticate_with_resp2_cmd_index,
1360            select_cmd_idx: select_db_cmd_index,
1361            #[cfg(feature = "cache-aio")]
1362            cache_cmd_idx: cache_cmd_index,
1363        },
1364    )
1365}
1366
1367fn check_resp3_auth(result: &Value) -> RedisResult<()> {
1368    if let Value::ServerError(err) = result {
1369        return Err(get_resp3_hello_command_error(err.clone().into()));
1370    }
1371    Ok(())
1372}
1373
/// Outcome of checking the authentication part of the connection-setup replies.
#[derive(PartialEq)]
pub(crate) enum AuthResult {
    /// Authentication succeeded, or no authentication command was sent at all.
    Succeeded,
    /// The server answered `AUTH <username> <password>` with a "wrong number of arguments"
    /// error. Setup should be retried with the password-only `AUTH` form.
    ShouldRetryWithoutUsername,
}
1379
1380fn check_resp2_auth(result: &Value) -> RedisResult<AuthResult> {
1381    let err = match result {
1382        Value::Okay => {
1383            return Ok(AuthResult::Succeeded);
1384        }
1385        Value::ServerError(err) => err,
1386        _ => {
1387            return Err((
1388                ServerErrorKind::ResponseError.into(),
1389                "Redis server refused to authenticate, returns Ok() != Value::Okay",
1390            )
1391                .into());
1392        }
1393    };
1394
1395    let err_msg = err.details().ok_or((
1396        ErrorKind::AuthenticationFailed,
1397        "Password authentication failed",
1398    ))?;
1399    if !err_msg.contains("wrong number of arguments for 'auth' command") {
1400        return Err((
1401            ErrorKind::AuthenticationFailed,
1402            "Password authentication failed",
1403        )
1404            .into());
1405    }
1406    Ok(AuthResult::ShouldRetryWithoutUsername)
1407}
1408
1409fn check_db_select(value: &Value) -> RedisResult<()> {
1410    let Value::ServerError(err) = value else {
1411        return Ok(());
1412    };
1413
1414    match err.details() {
1415        Some(err_msg) => Err((
1416            ServerErrorKind::ResponseError.into(),
1417            "Redis server refused to switch database",
1418            err_msg.to_string(),
1419        )
1420            .into()),
1421        None => Err((
1422            ServerErrorKind::ResponseError.into(),
1423            "Redis server refused to switch database",
1424        )
1425            .into()),
1426    }
1427}
1428
/// Validates the reply to `CLIENT TRACKING ON [OPTIN]`, which is sent when client-side caching
/// is enabled.
///
/// Only `+OK` is accepted. Anything else becomes a `ResponseError` that includes the reply's
/// debug representation.
#[cfg(feature = "cache-aio")]
fn check_caching(result: &Value) -> RedisResult<()> {
    if let Value::Okay = result {
        Ok(())
    } else {
        Err((
            ServerErrorKind::ResponseError.into(),
            "Client-side caching returned unknown response",
            format!("{result:?}"),
        )
            .into())
    }
}
1441
1442pub(crate) fn check_connection_setup(
1443    results: Vec<Value>,
1444    ConnectionSetupComponents {
1445        resp3_auth_cmd_idx,
1446        resp2_auth_cmd_idx,
1447        select_cmd_idx,
1448        #[cfg(feature = "cache-aio")]
1449        cache_cmd_idx,
1450    }: ConnectionSetupComponents,
1451) -> RedisResult<AuthResult> {
1452    // can't have both values set
1453    assert!(!(resp2_auth_cmd_idx.is_some() && resp3_auth_cmd_idx.is_some()));
1454
1455    if let Some(index) = resp3_auth_cmd_idx {
1456        let Some(value) = results.get(index) else {
1457            return Err((ErrorKind::Client, "Missing RESP3 auth response").into());
1458        };
1459        check_resp3_auth(value)?;
1460    } else if let Some(index) = resp2_auth_cmd_idx {
1461        let Some(value) = results.get(index) else {
1462            return Err((ErrorKind::Client, "Missing RESP2 auth response").into());
1463        };
1464        if check_resp2_auth(value)? == AuthResult::ShouldRetryWithoutUsername {
1465            return Ok(AuthResult::ShouldRetryWithoutUsername);
1466        }
1467    }
1468
1469    if let Some(index) = select_cmd_idx {
1470        let Some(value) = results.get(index) else {
1471            return Err((ErrorKind::Client, "Missing SELECT DB response").into());
1472        };
1473        check_db_select(value)?;
1474    }
1475
1476    #[cfg(feature = "cache-aio")]
1477    if let Some(index) = cache_cmd_idx {
1478        let Some(value) = results.get(index) else {
1479            return Err((ErrorKind::Client, "Missing Caching response").into());
1480        };
1481        check_caching(value)?;
1482    }
1483
1484    Ok(AuthResult::Succeeded)
1485}
1486
1487fn execute_connection_pipeline(
1488    rv: &mut Connection,
1489    (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
1490) -> RedisResult<AuthResult> {
1491    if pipeline.is_empty() {
1492        return Ok(AuthResult::Succeeded);
1493    }
1494    let results = rv.req_packed_commands(&pipeline.get_packed_pipeline(), 0, pipeline.len())?;
1495
1496    check_connection_setup(results, instructions)
1497}
1498
1499fn setup_connection(
1500    con: ActualConnection,
1501    connection_info: &RedisConnectionInfo,
1502    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1503) -> RedisResult<Connection> {
1504    let mut rv = Connection {
1505        con,
1506        parser: Parser::new(),
1507        db: connection_info.db,
1508        pubsub: false,
1509        protocol: connection_info.protocol,
1510        push_sender: None,
1511        messages_to_skip: 0,
1512    };
1513
1514    if execute_connection_pipeline(
1515        &mut rv,
1516        connection_setup_pipeline(
1517            connection_info,
1518            true,
1519            #[cfg(feature = "cache-aio")]
1520            cache_config,
1521        ),
1522    )? == AuthResult::ShouldRetryWithoutUsername
1523    {
1524        execute_connection_pipeline(
1525            &mut rv,
1526            connection_setup_pipeline(
1527                connection_info,
1528                false,
1529                #[cfg(feature = "cache-aio")]
1530                cache_config,
1531            ),
1532        )?;
1533    }
1534
1535    Ok(rv)
1536}
1537
1538/// Implements the "stateless" part of the connection interface that is used by the
1539/// different objects in redis-rs.
1540///
1541/// Primarily it obviously applies to `Connection` object but also some other objects
1542///  implement the interface (for instance whole clients or certain redis results).
1543///
1544/// Generally clients and connections (as well as redis results of those) implement
1545/// this trait.  Actual connections provide more functionality which can be used
1546/// to implement things like `PubSub` but they also can modify the intrinsic
1547/// state of the TCP connection.  This is not possible with `ConnectionLike`
1548/// implementors because that functionality is not exposed.
1549pub trait ConnectionLike {
1550    /// Sends an already encoded (packed) command into the TCP socket and
1551    /// reads the single response from it.
1552    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value>;
1553
1554    /// Sends multiple already encoded (packed) command into the TCP socket
1555    /// and reads `count` responses from it.  This is used to implement
1556    /// pipelining.
1557    /// Important - this function is meant for internal usage, since it's
1558    /// easy to pass incorrect `offset` & `count` parameters, which might
1559    /// cause the connection to enter an erroneous state. Users shouldn't
1560    /// call it, instead using the Pipeline::query function.
1561    #[doc(hidden)]
1562    fn req_packed_commands(
1563        &mut self,
1564        cmd: &[u8],
1565        offset: usize,
1566        count: usize,
1567    ) -> RedisResult<Vec<Value>>;
1568
1569    /// Sends a [Cmd] into the TCP socket and reads a single response from it.
1570    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1571        let pcmd = cmd.get_packed_command();
1572        self.req_packed_command(&pcmd)
1573    }
1574
1575    /// Returns the database this connection is bound to.  Note that this
1576    /// information might be unreliable because it's initially cached and
1577    /// also might be incorrect if the connection like object is not
1578    /// actually connected.
1579    fn get_db(&self) -> i64;
1580
1581    /// Does this connection support pipelining?
1582    #[doc(hidden)]
1583    fn supports_pipelining(&self) -> bool {
1584        true
1585    }
1586
1587    /// Check that all connections it has are available (`PING` internally).
1588    fn check_connection(&mut self) -> bool;
1589
1590    /// Returns the connection status.
1591    ///
1592    /// The connection is open until any `read` call received an
1593    /// invalid response from the server (most likely a closed or dropped
1594    /// connection, otherwise a Redis protocol error). When using unix
1595    /// sockets the connection is open until writing a command failed with a
1596    /// `BrokenPipe` error.
1597    fn is_open(&self) -> bool;
1598}
1599
1600/// A connection is an object that represents a single redis connection.  It
1601/// provides basic support for sending encoded commands into a redis connection
1602/// and to read a response from it.  It's bound to a single database and can
1603/// only be created from the client.
1604///
/// You generally do not do much with this object other than passing it to
1606/// `Cmd` objects.
1607impl Connection {
1608    /// Sends an already encoded (packed) command into the TCP socket and
1609    /// does not read a response.  This is useful for commands like
1610    /// `MONITOR` which yield multiple items.  This needs to be used with
1611    /// care because it changes the state of the connection.
1612    pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> {
1613        self.send_bytes(cmd)?;
1614        Ok(())
1615    }
1616
1617    /// Fetches a single response from the connection.  This is useful
1618    /// if used in combination with `send_packed_command`.
1619    pub fn recv_response(&mut self) -> RedisResult<Value> {
1620        self.read(true)
1621    }
1622
1623    /// Sets the write timeout for the connection.
1624    ///
1625    /// If the provided value is `None`, then `send_packed_command` call will
1626    /// block indefinitely. It is an error to pass the zero `Duration` to this
1627    /// method.
1628    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1629        self.con.set_write_timeout(dur)
1630    }
1631
1632    /// Sets the read timeout for the connection.
1633    ///
1634    /// If the provided value is `None`, then `recv_response` call will
1635    /// block indefinitely. It is an error to pass the zero `Duration` to this
1636    /// method.
1637    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1638        self.con.set_read_timeout(dur)
1639    }
1640
1641    /// Creates a [`PubSub`] instance for this connection.
1642    pub fn as_pubsub(&mut self) -> PubSub<'_> {
1643        // NOTE: The pubsub flag is intentionally not raised at this time since
1644        // running commands within the pubsub state should not try and exit from
1645        // the pubsub state.
1646        PubSub::new(self)
1647    }
1648
1649    fn exit_pubsub(&mut self) -> RedisResult<()> {
1650        let res = self.clear_active_subscriptions();
1651        if res.is_ok() {
1652            self.pubsub = false;
1653        } else {
1654            // Raise the pubsub flag to indicate the connection is "stuck" in that state.
1655            self.pubsub = true;
1656        }
1657
1658        res
1659    }
1660
1661    /// Get the inner connection out of a PubSub
1662    ///
1663    /// Any active subscriptions are unsubscribed. In the event of an error, the connection is
1664    /// dropped.
1665    fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
1666        // Responses to unsubscribe commands return in a 3-tuple with values
1667        // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs).
1668        // The "count of remaining subs" includes both pattern subscriptions and non pattern
1669        // subscriptions. Thus, to accurately drain all unsubscribe messages received from the
1670        // server, both commands need to be executed at once.
1671        {
1672            // Prepare both unsubscribe commands
1673            let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
1674            let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();
1675
1676            // Execute commands
1677            self.send_bytes(&unsubscribe)?;
1678            self.send_bytes(&punsubscribe)?;
1679        }
1680
1681        // Receive responses
1682        //
1683        // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
1684        // commands. There may be more responses if there are active subscriptions. In this case,
1685        // messages are received until the _subscription count_ in the responses reach zero.
1686        let mut received_unsub = false;
1687        let mut received_punsub = false;
1688
1689        loop {
1690            let resp = self.recv_response()?;
1691
1692            match resp {
1693                Value::Push { kind, data } => {
1694                    if data.len() >= 2 {
1695                        if let Value::Int(num) = data[1] {
1696                            if resp3_is_pub_sub_state_cleared(
1697                                &mut received_unsub,
1698                                &mut received_punsub,
1699                                &kind,
1700                                num as isize,
1701                            ) {
1702                                break;
1703                            }
1704                        }
1705                    }
1706                }
1707                Value::ServerError(err) => {
1708                    // a new error behavior, introduced in valkey 8.
1709                    // https://github.com/valkey-io/valkey/pull/759
1710                    if err.kind() == Some(ServerErrorKind::NoSub) {
1711                        if no_sub_err_is_pub_sub_state_cleared(
1712                            &mut received_unsub,
1713                            &mut received_punsub,
1714                            &err,
1715                        ) {
1716                            break;
1717                        } else {
1718                            continue;
1719                        }
1720                    }
1721
1722                    return Err(err.into());
1723                }
1724                Value::Array(vec) => {
1725                    let res: (Vec<u8>, (), isize) = from_redis_value(Value::Array(vec))?;
1726                    if resp2_is_pub_sub_state_cleared(
1727                        &mut received_unsub,
1728                        &mut received_punsub,
1729                        &res.0,
1730                        res.2,
1731                    ) {
1732                        break;
1733                    }
1734                }
1735                _ => {
1736                    return Err((
1737                        ErrorKind::Client,
1738                        "Unexpected unsubscribe response",
1739                        format!("{resp:?}"),
1740                    )
1741                        .into());
1742                }
1743            }
1744        }
1745
1746        // Finally, the connection is back in its normal state since all subscriptions were
1747        // cancelled *and* all unsubscribe messages were received.
1748        Ok(())
1749    }
1750
1751    fn send_push(&self, push: PushInfo) {
1752        if let Some(sender) = &self.push_sender {
1753            let _ = sender.send(push);
1754        }
1755    }
1756
1757    fn try_send(&self, value: &RedisResult<Value>) {
1758        if let Ok(Value::Push { kind, data }) = value {
1759            self.send_push(PushInfo {
1760                kind: kind.clone(),
1761                data: data.clone(),
1762            });
1763        }
1764    }
1765
1766    fn send_disconnect(&self) {
1767        self.send_push(PushInfo::disconnect())
1768    }
1769
1770    fn close_connection(&mut self) {
1771        // Notify the PushManager that the connection was lost
1772        self.send_disconnect();
1773        match self.con {
1774            ActualConnection::Tcp(ref mut connection) => {
1775                let _ = connection.reader.shutdown(net::Shutdown::Both);
1776                connection.open = false;
1777            }
1778            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1779            ActualConnection::TcpNativeTls(ref mut connection) => {
1780                let _ = connection.reader.shutdown();
1781                connection.open = false;
1782            }
1783            #[cfg(feature = "tls-rustls")]
1784            ActualConnection::TcpRustls(ref mut connection) => {
1785                let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both);
1786                connection.open = false;
1787            }
1788            #[cfg(unix)]
1789            ActualConnection::Unix(ref mut connection) => {
1790                let _ = connection.sock.shutdown(net::Shutdown::Both);
1791                connection.open = false;
1792            }
1793        }
1794    }
1795
1796    /// Fetches a single message from the connection. If the message is a response,
1797    /// increment `messages_to_skip` if it wasn't received before a timeout.
1798    fn read(&mut self, is_response: bool) -> RedisResult<Value> {
1799        loop {
1800            let result = match self.con {
1801                ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => {
1802                    self.parser.parse_value(reader)
1803                }
1804                #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1805                ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => {
1806                    let reader = &mut boxed_tls_connection.reader;
1807                    self.parser.parse_value(reader)
1808                }
1809                #[cfg(feature = "tls-rustls")]
1810                ActualConnection::TcpRustls(ref mut boxed_tls_connection) => {
1811                    let reader = &mut boxed_tls_connection.reader;
1812                    self.parser.parse_value(reader)
1813                }
1814                #[cfg(unix)]
1815                ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => {
1816                    self.parser.parse_value(sock)
1817                }
1818            };
1819            self.try_send(&result);
1820
1821            let Err(err) = &result else {
1822                if self.messages_to_skip > 0 {
1823                    self.messages_to_skip -= 1;
1824                    continue;
1825                }
1826                return result;
1827            };
1828            let Some(io_error) = err.as_io_error() else {
1829                if self.messages_to_skip > 0 {
1830                    self.messages_to_skip -= 1;
1831                    continue;
1832                }
1833                return result;
1834            };
1835            // shutdown connection on protocol error
1836            if io_error.kind() == io::ErrorKind::UnexpectedEof {
1837                self.close_connection();
1838            } else if is_response {
1839                self.messages_to_skip += 1;
1840            }
1841
1842            return result;
1843        }
1844    }
1845
1846    /// Sets sender channel for push values.
1847    pub fn set_push_sender(&mut self, sender: SyncPushSender) {
1848        self.push_sender = Some(sender);
1849    }
1850
1851    fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
1852        if bytes.is_empty() {
1853            return Err(RedisError::make_empty_command());
1854        }
1855        let result = self.con.send_bytes(bytes);
1856        if self.protocol.supports_resp3() {
1857            if let Err(e) = &result {
1858                if e.is_connection_dropped() {
1859                    self.send_disconnect();
1860                }
1861            }
1862        }
1863        result
1864    }
1865
1866    /// Subscribes to a new channel(s).
1867    ///
1868    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1869    pub fn subscribe_resp3<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1870        check_resp3!(self.protocol);
1871        cmd("SUBSCRIBE")
1872            .arg(channel)
1873            .set_no_response(true)
1874            .exec(self)
1875    }
1876
1877    /// Subscribes to new channel(s) with pattern(s).
1878    ///
1879    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1880    pub fn psubscribe_resp3<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1881        check_resp3!(self.protocol);
1882        cmd("PSUBSCRIBE")
1883            .arg(pchannel)
1884            .set_no_response(true)
1885            .exec(self)
1886    }
1887
1888    /// Unsubscribes from a channel(s).
1889    ///
1890    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1891    pub fn unsubscribe_resp3<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1892        check_resp3!(self.protocol);
1893        cmd("UNSUBSCRIBE")
1894            .arg(channel)
1895            .set_no_response(true)
1896            .exec(self)
1897    }
1898
1899    /// Unsubscribes from channel pattern(s).
1900    ///
1901    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1902    pub fn punsubscribe_resp3<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1903        check_resp3!(self.protocol);
1904        cmd("PUNSUBSCRIBE")
1905            .arg(pchannel)
1906            .set_no_response(true)
1907            .exec(self)
1908    }
1909}
1910
1911impl ConnectionLike for Connection {
1912    /// Sends a [Cmd] into the TCP socket and reads a single response from it.
1913    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1914        let pcmd = cmd.get_packed_command();
1915        if self.pubsub {
1916            self.exit_pubsub()?;
1917        }
1918
1919        self.send_bytes(&pcmd)?;
1920        if cmd.is_no_response() {
1921            return Ok(Value::Nil);
1922        }
1923        loop {
1924            match self.read(true)? {
1925                Value::Push {
1926                    kind: _kind,
1927                    data: _data,
1928                } => continue,
1929                val => return Ok(val),
1930            }
1931        }
1932    }
1933    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1934        if self.pubsub {
1935            self.exit_pubsub()?;
1936        }
1937
1938        self.send_bytes(cmd)?;
1939        loop {
1940            match self.read(true)? {
1941                Value::Push {
1942                    kind: _kind,
1943                    data: _data,
1944                } => continue,
1945                val => return Ok(val),
1946            }
1947        }
1948    }
1949
1950    fn req_packed_commands(
1951        &mut self,
1952        cmd: &[u8],
1953        offset: usize,
1954        count: usize,
1955    ) -> RedisResult<Vec<Value>> {
1956        if self.pubsub {
1957            self.exit_pubsub()?;
1958        }
1959        self.send_bytes(cmd)?;
1960        let mut rv = vec![];
1961        let mut first_err = None;
1962        let mut server_errors = vec![];
1963        let mut count = count;
1964        let mut idx = 0;
1965        while idx < (offset + count) {
1966            // When processing a transaction, some responses may be errors.
1967            // We need to keep processing the rest of the responses in that case,
1968            // so bailing early with `?` would not be correct.
1969            // See: https://github.com/redis-rs/redis-rs/issues/436
1970            let response = self.read(true);
1971            match response {
1972                Ok(Value::ServerError(err)) => {
1973                    if idx < offset {
1974                        server_errors.push((idx - 1, err)); // -1, to offset the added MULTI call.
1975                    } else {
1976                        rv.push(Value::ServerError(err));
1977                    }
1978                }
1979                Ok(item) => {
1980                    // RESP3 can insert push data between command replies
1981                    if let Value::Push {
1982                        kind: _kind,
1983                        data: _data,
1984                    } = item
1985                    {
1986                        // if that is the case we have to extend the loop and handle push data
1987                        count += 1;
1988                    } else if idx >= offset {
1989                        rv.push(item);
1990                    }
1991                }
1992                Err(err) => {
1993                    if first_err.is_none() {
1994                        first_err = Some(err);
1995                    }
1996                }
1997            }
1998            idx += 1;
1999        }
2000
2001        if !server_errors.is_empty() {
2002            return Err(RedisError::make_aborted_transaction(server_errors));
2003        }
2004
2005        first_err.map_or(Ok(rv), Err)
2006    }
2007
2008    fn get_db(&self) -> i64 {
2009        self.db
2010    }
2011
2012    fn check_connection(&mut self) -> bool {
2013        cmd("PING").query::<String>(self).is_ok()
2014    }
2015
2016    fn is_open(&self) -> bool {
2017        self.con.is_open()
2018    }
2019}
2020
2021impl<C, T> ConnectionLike for T
2022where
2023    C: ConnectionLike,
2024    T: DerefMut<Target = C>,
2025{
2026    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
2027        self.deref_mut().req_packed_command(cmd)
2028    }
2029
2030    fn req_packed_commands(
2031        &mut self,
2032        cmd: &[u8],
2033        offset: usize,
2034        count: usize,
2035    ) -> RedisResult<Vec<Value>> {
2036        self.deref_mut().req_packed_commands(cmd, offset, count)
2037    }
2038
2039    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
2040        self.deref_mut().req_command(cmd)
2041    }
2042
2043    fn get_db(&self) -> i64 {
2044        self.deref().get_db()
2045    }
2046
2047    fn supports_pipelining(&self) -> bool {
2048        self.deref().supports_pipelining()
2049    }
2050
2051    fn check_connection(&mut self) -> bool {
2052        self.deref_mut().check_connection()
2053    }
2054
2055    fn is_open(&self) -> bool {
2056        self.deref().is_open()
2057    }
2058}
2059
2060/// The pubsub object provides convenient access to the redis pubsub
2061/// system.  Once created you can subscribe and unsubscribe from channels
2062/// and listen in on messages.
2063///
2064/// Example:
2065///
2066/// ```rust,no_run
2067/// # fn do_something() -> redis::RedisResult<()> {
2068/// let client = redis::Client::open("redis://127.0.0.1/")?;
2069/// let mut con = client.get_connection()?;
2070/// let mut pubsub = con.as_pubsub();
2071/// pubsub.subscribe("channel_1")?;
2072/// pubsub.subscribe("channel_2")?;
2073///
2074/// loop {
2075///     let msg = pubsub.get_message()?;
2076///     let payload : String = msg.get_payload()?;
2077///     println!("channel '{}': {}", msg.get_channel_name(), payload);
2078/// }
2079/// # }
2080/// ```
2081impl<'a> PubSub<'a> {
2082    fn new(con: &'a mut Connection) -> Self {
2083        Self {
2084            con,
2085            waiting_messages: VecDeque::new(),
2086        }
2087    }
2088
2089    fn cache_messages_until_received_response(
2090        &mut self,
2091        cmd: &mut Cmd,
2092        is_sub_unsub: bool,
2093    ) -> RedisResult<Value> {
2094        let ignore_response = self.con.protocol.supports_resp3() && is_sub_unsub;
2095        cmd.set_no_response(ignore_response);
2096
2097        self.con.send_packed_command(&cmd.get_packed_command())?;
2098
2099        loop {
2100            let response = self.con.recv_response()?;
2101            if let Some(msg) = Msg::from_value(&response) {
2102                self.waiting_messages.push_back(msg);
2103            } else {
2104                return Ok(response);
2105            }
2106        }
2107    }
2108
2109    /// Subscribes to a new channel(s).
2110    pub fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
2111        self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel), true)?;
2112        Ok(())
2113    }
2114
2115    /// Subscribes to new channel(s) with pattern(s).
2116    pub fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
2117        self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel), true)?;
2118        Ok(())
2119    }
2120
2121    /// Unsubscribes from a channel(s).
2122    pub fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
2123        self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel), true)?;
2124        Ok(())
2125    }
2126
2127    /// Unsubscribes from channel pattern(s).
2128    pub fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
2129        self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel), true)?;
2130        Ok(())
2131    }
2132
2133    /// Sends a ping with a message to the server
2134    pub fn ping_message<T: FromRedisValue>(&mut self, message: impl ToRedisArgs) -> RedisResult<T> {
2135        Ok(from_redis_value(
2136            self.cache_messages_until_received_response(cmd("PING").arg(message), false)?,
2137        )?)
2138    }
2139    /// Sends a ping to the server
2140    pub fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
2141        Ok(from_redis_value(
2142            self.cache_messages_until_received_response(&mut cmd("PING"), false)?,
2143        )?)
2144    }
2145
2146    /// Fetches the next message from the pubsub connection.  Blocks until
2147    /// a message becomes available.  This currently does not provide a
2148    /// wait not to block :(
2149    ///
2150    /// The message itself is still generic and can be converted into an
2151    /// appropriate type through the helper methods on it.
2152    pub fn get_message(&mut self) -> RedisResult<Msg> {
2153        if let Some(msg) = self.waiting_messages.pop_front() {
2154            return Ok(msg);
2155        }
2156        loop {
2157            if let Some(msg) = Msg::from_owned_value(self.con.read(false)?) {
2158                return Ok(msg);
2159            } else {
2160                continue;
2161            }
2162        }
2163    }
2164
2165    /// Sets the read timeout for the connection.
2166    ///
2167    /// If the provided value is `None`, then `get_message` call will
2168    /// block indefinitely. It is an error to pass the zero `Duration` to this
2169    /// method.
2170    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
2171        self.con.set_read_timeout(dur)
2172    }
2173}
2174
2175impl Drop for PubSub<'_> {
2176    fn drop(&mut self) {
2177        let _ = self.con.exit_pubsub();
2178    }
2179}
2180
2181/// This holds the data that comes from listening to a pubsub
2182/// connection.  It only contains actual message data.
2183impl Msg {
2184    /// Tries to convert provided [`Value`] into [`Msg`].
2185    pub fn from_value(value: &Value) -> Option<Self> {
2186        Self::from_owned_value(value.clone())
2187    }
2188
2189    /// Tries to convert provided [`Value`] into [`Msg`].
2190    pub fn from_owned_value(value: Value) -> Option<Self> {
2191        let mut pattern = None;
2192        let payload;
2193        let channel;
2194
2195        if let Value::Push { kind, data } = value {
2196            return Self::from_push_info(PushInfo { kind, data });
2197        } else {
2198            let raw_msg: Vec<Value> = from_redis_value(value).ok()?;
2199            let mut iter = raw_msg.into_iter();
2200            let msg_type: String = from_redis_value(iter.next()?).ok()?;
2201            if msg_type == "message" {
2202                channel = iter.next()?;
2203                payload = iter.next()?;
2204            } else if msg_type == "pmessage" {
2205                pattern = Some(iter.next()?);
2206                channel = iter.next()?;
2207                payload = iter.next()?;
2208            } else {
2209                return None;
2210            }
2211        };
2212        Some(Msg {
2213            payload,
2214            channel,
2215            pattern,
2216        })
2217    }
2218
2219    /// Tries to convert provided [`PushInfo`] into [`Msg`].
2220    pub fn from_push_info(push_info: PushInfo) -> Option<Self> {
2221        let mut pattern = None;
2222        let payload;
2223        let channel;
2224
2225        let mut iter = push_info.data.into_iter();
2226        if push_info.kind == PushKind::Message || push_info.kind == PushKind::SMessage {
2227            channel = iter.next()?;
2228            payload = iter.next()?;
2229        } else if push_info.kind == PushKind::PMessage {
2230            pattern = Some(iter.next()?);
2231            channel = iter.next()?;
2232            payload = iter.next()?;
2233        } else {
2234            return None;
2235        }
2236
2237        Some(Msg {
2238            payload,
2239            channel,
2240            pattern,
2241        })
2242    }
2243
2244    /// Returns the channel this message came on.
2245    pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
2246        Ok(from_redis_value_ref(&self.channel)?)
2247    }
2248
2249    /// Convenience method to get a string version of the channel.  Unless
2250    /// your channel contains non utf-8 bytes you can always use this
2251    /// method.  If the channel is not a valid string (which really should
2252    /// not happen) then the return value is `"?"`.
2253    pub fn get_channel_name(&self) -> &str {
2254        match self.channel {
2255            Value::BulkString(ref bytes) => from_utf8(bytes).unwrap_or("?"),
2256            _ => "?",
2257        }
2258    }
2259
2260    /// Returns the message's payload in a specific format.
2261    pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
2262        Ok(from_redis_value_ref(&self.payload)?)
2263    }
2264
2265    /// Returns the bytes that are the message's payload.  This can be used
2266    /// as an alternative to the `get_payload` function if you are interested
2267    /// in the raw bytes in it.
2268    pub fn get_payload_bytes(&self) -> &[u8] {
2269        match self.payload {
2270            Value::BulkString(ref bytes) => bytes,
2271            _ => b"",
2272        }
2273    }
2274
2275    /// Returns true if the message was constructed from a pattern
2276    /// subscription.
2277    #[allow(clippy::wrong_self_convention)]
2278    pub fn from_pattern(&self) -> bool {
2279        self.pattern.is_some()
2280    }
2281
2282    /// If the message was constructed from a message pattern this can be
2283    /// used to find out which one.  It's recommended to match against
2284    /// an `Option<String>` so that you do not need to use `from_pattern`
2285    /// to figure out if a pattern was set.
2286    pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
2287        Ok(match self.pattern {
2288            None => from_redis_value_ref(&Value::Nil),
2289            Some(ref x) => from_redis_value_ref(x),
2290        }?)
2291    }
2292}
2293
2294/// This function simplifies transaction management slightly.  What it
2295/// does is automatically watching keys and then going into a transaction
2296/// loop util it succeeds.  Once it goes through the results are
2297/// returned.
2298///
2299/// To use the transaction two pieces of information are needed: a list
2300/// of all the keys that need to be watched for modifications and a
2301/// closure with the code that should be execute in the context of the
2302/// transaction.  The closure is invoked with a fresh pipeline in atomic
2303/// mode.  To use the transaction the function needs to return the result
2304/// from querying the pipeline with the connection.
2305///
2306/// The end result of the transaction is then available as the return
2307/// value from the function call.
2308///
2309/// Example:
2310///
2311/// ```rust,no_run
2312/// use redis::Commands;
2313/// # fn do_something() -> redis::RedisResult<()> {
2314/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
2315/// # let mut con = client.get_connection().unwrap();
2316/// let key = "the_key";
2317/// let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| {
2318///     let old_val : isize = con.get(key)?;
2319///     pipe
2320///         .set(key, old_val + 1).ignore()
2321///         .get(key).query(con)
2322/// })?;
2323/// println!("The incremented number is: {}", new_val);
2324/// # Ok(()) }
2325/// ```
2326pub fn transaction<
2327    C: ConnectionLike,
2328    K: ToRedisArgs,
2329    T,
2330    F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
2331>(
2332    con: &mut C,
2333    keys: &[K],
2334    func: F,
2335) -> RedisResult<T> {
2336    let mut func = func;
2337    loop {
2338        cmd("WATCH").arg(keys).exec(con)?;
2339        let mut p = pipe();
2340        let response: Option<T> = func(con, p.atomic())?;
2341        match response {
2342            None => {
2343                continue;
2344            }
2345            Some(response) => {
2346                // make sure no watch is left in the connection, even if
2347                // someone forgot to use the pipeline.
2348                cmd("UNWATCH").exec(con)?;
2349                return Ok(response);
2350            }
2351        }
2352    }
2353}
// TODO: make both subscription-clearing helpers (RESP2 and RESP3) support sharded channels.
2355
/// Common logic for clearing subscriptions in RESP2 async/sync
///
/// `kind` is the reply type (e.g. `unsubscribe` / `punsubscribe`); only its first
/// byte is inspected: `u` marks `received_unsub`, `p` marks `received_punsub`, and flags
/// that are already set stay set. `num` is the remaining-subscription count carried by
/// the reply. Returns `true` once both acknowledgements were seen and `num` is zero.
pub fn resp2_is_pub_sub_state_cleared(
    received_unsub: &mut bool,
    received_punsub: &mut bool,
    kind: &[u8],
    num: isize,
) -> bool {
    let leading = kind.first().copied();
    *received_unsub |= leading == Some(b'u');
    *received_punsub |= leading == Some(b'p');
    num == 0 && *received_unsub && *received_punsub
}
2370
2371/// Common logic for clearing subscriptions in RESP3 async/sync
2372pub fn resp3_is_pub_sub_state_cleared(
2373    received_unsub: &mut bool,
2374    received_punsub: &mut bool,
2375    kind: &PushKind,
2376    num: isize,
2377) -> bool {
2378    match kind {
2379        PushKind::Unsubscribe => *received_unsub = true,
2380        PushKind::PUnsubscribe => *received_punsub = true,
2381        _ => (),
2382    };
2383    *received_unsub && *received_punsub && num == 0
2384}
2385
2386pub fn no_sub_err_is_pub_sub_state_cleared(
2387    received_unsub: &mut bool,
2388    received_punsub: &mut bool,
2389    err: &ServerError,
2390) -> bool {
2391    let details = err.details();
2392    *received_unsub = *received_unsub
2393        || details
2394            .map(|details| details.starts_with("'unsub"))
2395            .unwrap_or_default();
2396    *received_punsub = *received_punsub
2397        || details
2398            .map(|details| details.starts_with("'punsub"))
2399            .unwrap_or_default();
2400    *received_unsub && *received_punsub
2401}
2402
2403/// Common logic for checking real cause of hello3 command error
2404pub fn get_resp3_hello_command_error(err: RedisError) -> RedisError {
2405    if let Some(detail) = err.detail() {
2406        if detail.starts_with("unknown command `HELLO`") {
2407            return (
2408                ErrorKind::RESP3NotSupported,
2409                "Redis Server doesn't support HELLO command therefore resp3 cannot be used",
2410            )
2411                .into();
2412        }
2413    }
2414    err
2415}
2416
2417#[cfg(test)]
2418mod tests {
2419    use super::*;
2420
2421    #[test]
2422    fn test_parse_redis_url() {
2423        let cases = vec![
2424            ("redis://127.0.0.1", true),
2425            ("redis://[::1]", true),
2426            ("rediss://127.0.0.1", true),
2427            ("rediss://[::1]", true),
2428            ("valkey://127.0.0.1", true),
2429            ("valkey://[::1]", true),
2430            ("valkeys://127.0.0.1", true),
2431            ("valkeys://[::1]", true),
2432            ("redis+unix:///run/redis.sock", true),
2433            ("valkey+unix:///run/valkey.sock", true),
2434            ("unix:///run/redis.sock", true),
2435            ("http://127.0.0.1", false),
2436            ("tcp://127.0.0.1", false),
2437        ];
2438        for (url, expected) in cases.into_iter() {
2439            let res = parse_redis_url(url);
2440            assert_eq!(
2441                res.is_some(),
2442                expected,
2443                "Parsed result of `{url}` is not expected",
2444            );
2445        }
2446    }
2447
2448    #[test]
2449    fn test_url_to_tcp_connection_info() {
2450        let cases = vec![
2451            (
2452                url::Url::parse("redis://127.0.0.1").unwrap(),
2453                ConnectionInfo {
2454                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
2455                    redis: Default::default(),
2456                    tcp_settings: TcpSettings::default(),
2457                },
2458            ),
2459            (
2460                url::Url::parse("redis://[::1]").unwrap(),
2461                ConnectionInfo {
2462                    addr: ConnectionAddr::Tcp("::1".to_string(), 6379),
2463                    redis: Default::default(),
2464                    tcp_settings: TcpSettings::default(),
2465                },
2466            ),
2467            (
2468                url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(),
2469                ConnectionInfo {
2470                    addr: ConnectionAddr::Tcp("example.com".to_string(), 6379),
2471                    redis: RedisConnectionInfo {
2472                        db: 2,
2473                        username: Some("%johndoe%".into()),
2474                        password: Some("#@<>$".into()),
2475                        ..Default::default()
2476                    },
2477                    tcp_settings: TcpSettings::default(),
2478                },
2479            ),
2480            (
2481                url::Url::parse("redis://127.0.0.1/?protocol=2").unwrap(),
2482                ConnectionInfo {
2483                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
2484                    redis: Default::default(),
2485                    tcp_settings: TcpSettings::default(),
2486                },
2487            ),
2488            (
2489                url::Url::parse("redis://127.0.0.1/?protocol=resp3").unwrap(),
2490                ConnectionInfo {
2491                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
2492                    redis: RedisConnectionInfo {
2493                        protocol: ProtocolVersion::RESP3,
2494                        ..Default::default()
2495                    },
2496                    tcp_settings: TcpSettings::default(),
2497                },
2498            ),
2499        ];
2500        for (url, expected) in cases.into_iter() {
2501            let res = url_to_tcp_connection_info(url.clone()).unwrap();
2502            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
2503            assert_eq!(
2504                res.redis.db, expected.redis.db,
2505                "db of {url} is not expected",
2506            );
2507            assert_eq!(
2508                res.redis.username, expected.redis.username,
2509                "username of {url} is not expected",
2510            );
2511            assert_eq!(
2512                res.redis.password, expected.redis.password,
2513                "password of {url} is not expected",
2514            );
2515        }
2516    }
2517
2518    #[test]
2519    fn test_url_to_tcp_connection_info_failed() {
2520        let cases = vec![
2521            (
2522                url::Url::parse("redis://").unwrap(),
2523                "Missing hostname",
2524                None,
2525            ),
2526            (
2527                url::Url::parse("redis://127.0.0.1/db").unwrap(),
2528                "Invalid database number",
2529                None,
2530            ),
2531            (
2532                url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(),
2533                "Username is not valid UTF-8 string",
2534                None,
2535            ),
2536            (
2537                url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(),
2538                "Password is not valid UTF-8 string",
2539                None,
2540            ),
2541            (
2542                url::Url::parse("redis://127.0.0.1/?protocol=4").unwrap(),
2543                "Invalid protocol version",
2544                Some("4"),
2545            ),
2546        ];
2547        for (url, expected, detail) in cases.into_iter() {
2548            let res = url_to_tcp_connection_info(url).unwrap_err();
2549            assert_eq!(res.kind(), crate::ErrorKind::InvalidClientConfig,);
2550            let desc = res.to_string();
2551            assert!(desc.contains(expected), "{desc}");
2552            assert_eq!(res.detail(), detail);
2553        }
2554    }
2555
2556    #[test]
2557    #[cfg(unix)]
2558    fn test_url_to_unix_connection_info() {
2559        let cases = vec![
2560            (
2561                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
2562                ConnectionInfo {
2563                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
2564                    redis: RedisConnectionInfo {
2565                        db: 0,
2566                        username: None,
2567                        password: None,
2568                        protocol: ProtocolVersion::RESP2,
2569                        skip_set_lib_name: false,
2570                    },
2571                    tcp_settings: Default::default(),
2572                },
2573            ),
2574            (
2575                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
2576                ConnectionInfo {
2577                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
2578                    redis: RedisConnectionInfo {
2579                        db: 1,
2580                        ..Default::default()
2581                    },
2582                    tcp_settings: TcpSettings::default(),
2583                },
2584            ),
2585            (
2586                url::Url::parse(
2587                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
2588                )
2589                .unwrap(),
2590                ConnectionInfo {
2591                    addr: ConnectionAddr::Unix("/example.sock".into()),
2592                    redis: RedisConnectionInfo {
2593                        db: 2,
2594                        username: Some("%johndoe%".into()),
2595                        password: Some("#@<>$".into()),
2596                        ..Default::default()
2597                    },
2598                    tcp_settings: TcpSettings::default(),
2599                },
2600            ),
2601            (
2602                url::Url::parse(
2603                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
2604                )
2605                .unwrap(),
2606                ConnectionInfo {
2607                    addr: ConnectionAddr::Unix("/example.sock".into()),
2608                    redis: RedisConnectionInfo {
2609                        db: 2,
2610                        username: Some("%johndoe%".into()),
2611                        password: Some("&?= *+".into()),
2612                        ..Default::default()
2613                    },
2614                    tcp_settings: TcpSettings::default(),
2615                },
2616            ),
2617            (
2618                url::Url::parse("redis+unix:///var/run/redis.sock?protocol=3").unwrap(),
2619                ConnectionInfo {
2620                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
2621                    redis: RedisConnectionInfo {
2622                        protocol: ProtocolVersion::RESP3,
2623                        ..Default::default()
2624                    },
2625                    tcp_settings: TcpSettings::default(),
2626                },
2627            ),
2628        ];
2629        for (url, expected) in cases.into_iter() {
2630            assert_eq!(
2631                ConnectionAddr::Unix(url.to_file_path().unwrap()),
2632                expected.addr,
2633                "addr of {url} is not expected",
2634            );
2635            let res = url_to_unix_connection_info(url.clone()).unwrap();
2636            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
2637            assert_eq!(
2638                res.redis.db, expected.redis.db,
2639                "db of {url} is not expected",
2640            );
2641            assert_eq!(
2642                res.redis.username, expected.redis.username,
2643                "username of {url} is not expected",
2644            );
2645            assert_eq!(
2646                res.redis.password, expected.redis.password,
2647                "password of {url} is not expected",
2648            );
2649        }
2650    }
2651}