1use std::borrow::Cow;
2use std::collections::VecDeque;
3use std::fmt;
4use std::io::{self, Write};
5use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
6use std::ops::DerefMut;
7use std::path::PathBuf;
8use std::str::{from_utf8, FromStr};
9use std::time::{Duration, Instant};
10
11use crate::cmd::{cmd, pipe, Cmd};
12use crate::io::tcp::{stream_with_settings, TcpSettings};
13use crate::parser::Parser;
14use crate::pipeline::Pipeline;
15use crate::types::{
16 from_redis_value, ErrorKind, FromRedisValue, HashMap, PushKind, RedisError, RedisResult,
17 ServerError, ServerErrorKind, SyncPushSender, ToRedisArgs, Value,
18};
19use crate::{check_resp3, from_owned_redis_value, ProtocolVersion};
20
21#[cfg(unix)]
22use std::os::unix::net::UnixStream;
23
24use crate::commands::resp3_hello;
25#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
26use native_tls::{TlsConnector, TlsStream};
27
28#[cfg(feature = "tls-rustls")]
29use rustls::{RootCertStore, StreamOwned};
30#[cfg(feature = "tls-rustls")]
31use std::sync::Arc;
32
33use crate::PushInfo;
34
35#[cfg(all(
36 feature = "tls-rustls",
37 not(feature = "tls-native-tls"),
38 not(feature = "tls-rustls-webpki-roots")
39))]
40use rustls_native_certs::load_native_certs;
41
42#[cfg(feature = "tls-rustls")]
43use crate::tls::ClientTlsParams;
44
/// Extra TLS settings attached to a [`ConnectionAddr::TcpTls`] address.
///
/// Which fields exist depends on the enabled TLS features.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct TlsConnParams {
    /// Client certificate chain and key for mutual TLS (rustls backend only).
    #[cfg(feature = "tls-rustls")]
    pub(crate) client_tls_params: Option<ClientTlsParams>,
    /// Root certificates to trust; when `None`, `create_rustls_config` falls back
    /// to its default root store.
    #[cfg(feature = "tls-rustls")]
    pub(crate) root_cert_store: Option<RootCertStore>,
    /// When true, certificates whose name does not match the host are accepted.
    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
    pub(crate) danger_accept_invalid_hostnames: bool,
}
56
/// Port used when a TCP/TLS connection URL does not specify one.
///
/// A plain compile-time value, so `const` (inlined at each use) rather than a `static`.
const DEFAULT_PORT: u16 = 6379;
58
59#[inline(always)]
60fn connect_tcp(addr: (&str, u16)) -> io::Result<TcpStream> {
61 let socket = TcpStream::connect(addr)?;
62 stream_with_settings(socket, &TcpSettings::default())
63}
64
65#[inline(always)]
66fn connect_tcp_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
67 let socket = TcpStream::connect_timeout(addr, timeout)?;
68 stream_with_settings(socket, &TcpSettings::default())
69}
70
71pub fn parse_redis_url(input: &str) -> Option<url::Url> {
76 match url::Url::parse(input) {
77 Ok(result) => match result.scheme() {
78 "redis" | "rediss" | "valkey" | "valkeys" | "redis+unix" | "valkey+unix" | "unix" => {
79 Some(result)
80 }
81 _ => None,
82 },
83 Err(_) => None,
84 }
85}
86
/// TLS verification mode of a [`ConnectionAddr::TcpTls`] address, as reported
/// by `ConnectionAddr::tls_mode`.
#[derive(Clone, Copy, PartialEq)]
pub enum TlsMode {
    /// Corresponds to `TcpTls { insecure: false, .. }`.
    Secure,
    /// Corresponds to `TcpTls { insecure: true, .. }` (e.g. the `#insecure` URL fragment).
    Insecure,
}
97
/// Where to connect: plain TCP, TCP wrapped in TLS, or a Unix domain socket.
#[derive(Clone, Debug)]
pub enum ConnectionAddr {
    /// Plain TCP: `(host, port)`.
    Tcp(String, u16),
    /// TCP connection wrapped in TLS.
    TcpTls {
        /// Hostname to connect to; also passed to the TLS backend as the server name.
        host: String,
        /// TCP port.
        port: u16,
        /// Skip server certificate verification (set by the `#insecure` URL fragment).
        insecure: bool,

        /// Optional extra TLS settings; not considered by `PartialEq`.
        tls_params: Option<TlsConnParams>,
    },
    /// Unix domain socket at the given filesystem path.
    Unix(PathBuf),
}
129
130impl PartialEq for ConnectionAddr {
131 fn eq(&self, other: &Self) -> bool {
132 match (self, other) {
133 (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => {
134 host1 == host2 && port1 == port2
135 }
136 (
137 ConnectionAddr::TcpTls {
138 host: host1,
139 port: port1,
140 insecure: insecure1,
141 tls_params: _,
142 },
143 ConnectionAddr::TcpTls {
144 host: host2,
145 port: port2,
146 insecure: insecure2,
147 tls_params: _,
148 },
149 ) => port1 == port2 && host1 == host2 && insecure1 == insecure2,
150 (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2,
151 _ => false,
152 }
153 }
154}
155
156impl Eq for ConnectionAddr {}
157
158impl ConnectionAddr {
159 pub fn is_supported(&self) -> bool {
170 match *self {
171 ConnectionAddr::Tcp(_, _) => true,
172 ConnectionAddr::TcpTls { .. } => {
173 cfg!(any(feature = "tls-native-tls", feature = "tls-rustls"))
174 }
175 ConnectionAddr::Unix(_) => cfg!(unix),
176 }
177 }
178
179 #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
188 pub fn set_danger_accept_invalid_hostnames(&mut self, insecure: bool) {
189 if let ConnectionAddr::TcpTls { tls_params, .. } = self {
190 if let Some(ref mut params) = tls_params {
191 params.danger_accept_invalid_hostnames = insecure;
192 } else if insecure {
193 *tls_params = Some(TlsConnParams {
194 #[cfg(feature = "tls-rustls")]
195 client_tls_params: None,
196 #[cfg(feature = "tls-rustls")]
197 root_cert_store: None,
198 danger_accept_invalid_hostnames: insecure,
199 });
200 }
201 }
202 }
203
204 #[cfg(feature = "cluster")]
205 pub(crate) fn tls_mode(&self) -> Option<TlsMode> {
206 match self {
207 ConnectionAddr::TcpTls { insecure, .. } => {
208 if *insecure {
209 Some(TlsMode::Insecure)
210 } else {
211 Some(TlsMode::Secure)
212 }
213 }
214 _ => None,
215 }
216 }
217}
218
219impl fmt::Display for ConnectionAddr {
220 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
221 match *self {
223 ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
224 ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
225 ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()),
226 }
227 }
228}
229
/// Complete connection configuration: where to connect and how to set up the session.
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
    /// Transport-level address of the server.
    pub addr: ConnectionAddr,

    /// Session settings (database, credentials, protocol) applied after connecting.
    pub redis: RedisConnectionInfo,
}
239
/// Session-level settings used by the connection setup pipeline.
#[derive(Clone, Debug, Default)]
pub struct RedisConnectionInfo {
    /// Database index; a `SELECT` is issued during setup when it is non-zero.
    pub db: i64,
    /// Optional username used for authentication.
    pub username: Option<String>,
    /// Optional password; with RESP2, a password causes an `AUTH` during setup.
    pub password: Option<String>,
    /// Wire protocol; anything other than RESP2 sends `HELLO` during setup.
    pub protocol: ProtocolVersion,
}
252
253impl FromStr for ConnectionInfo {
254 type Err = RedisError;
255
256 fn from_str(s: &str) -> Result<Self, Self::Err> {
257 s.into_connection_info()
258 }
259}
260
/// Conversion into a [`ConnectionInfo`].
///
/// Implemented in this module for `ConnectionInfo` itself, URL strings (`&str`,
/// `String`), parsed `url::Url`s and `(host, port)` tuples.
pub trait IntoConnectionInfo {
    /// Performs the conversion, failing with `InvalidClientConfig` on bad input.
    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
}
268
269impl IntoConnectionInfo for ConnectionInfo {
270 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
271 Ok(self)
272 }
273}
274
275impl IntoConnectionInfo for &str {
285 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
286 match parse_redis_url(self) {
287 Some(u) => u.into_connection_info(),
288 None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
289 }
290 }
291}
292
293impl<T> IntoConnectionInfo for (T, u16)
294where
295 T: Into<String>,
296{
297 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
298 Ok(ConnectionInfo {
299 addr: ConnectionAddr::Tcp(self.0.into(), self.1),
300 redis: RedisConnectionInfo::default(),
301 })
302 }
303}
304
305impl IntoConnectionInfo for String {
315 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
316 match parse_redis_url(&self) {
317 Some(u) => u.into_connection_info(),
318 None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
319 }
320 }
321}
322
323fn parse_protocol(query: &HashMap<Cow<str>, Cow<str>>) -> RedisResult<ProtocolVersion> {
324 Ok(match query.get("protocol") {
325 Some(protocol) => {
326 if protocol == "2" || protocol == "resp2" {
327 ProtocolVersion::RESP2
328 } else if protocol == "3" || protocol == "resp3" {
329 ProtocolVersion::RESP3
330 } else {
331 fail!((
332 ErrorKind::InvalidClientConfig,
333 "Invalid protocol version",
334 protocol.to_string()
335 ))
336 }
337 }
338 None => ProtocolVersion::RESP2,
339 })
340}
341
/// Builds a [`ConnectionInfo`] from a `redis`, `rediss`, `valkey` or `valkeys` URL.
///
/// - The host is required; the port defaults to `DEFAULT_PORT`.
/// - `rediss`/`valkeys` produce a TLS address. A `#insecure` fragment marks it
///   insecure, and any other fragment is rejected. Without a TLS feature these
///   schemes are an error.
/// - The URL path selects the database (`""` or `/` means 0).
/// - Username and password are percent-decoded and must be valid UTF-8.
/// - The `protocol` query parameter is handled by `parse_protocol`.
fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
    let host = match url.host() {
        Some(host) => {
            match host {
                url::Host::Domain(path) => path.to_string(),
                url::Host::Ipv4(v4) => v4.to_string(),
                // std's Ipv6Addr display has no `[...]` brackets, which is the form
                // socket-address resolution expects for a bare host.
                url::Host::Ipv6(v6) => v6.to_string(),
            }
        }
        None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
    };
    let port = url.port().unwrap_or(DEFAULT_PORT);
    let addr = if url.scheme() == "rediss" || url.scheme() == "valkeys" {
        // With a TLS backend compiled in, this block is the value of the branch...
        #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
        {
            match url.fragment() {
                Some("insecure") => ConnectionAddr::TcpTls {
                    host,
                    port,
                    insecure: true,
                    tls_params: None,
                },
                Some(_) => fail!((
                    ErrorKind::InvalidClientConfig,
                    "only #insecure is supported as URL fragment"
                )),
                _ => ConnectionAddr::TcpTls {
                    host,
                    port,
                    insecure: false,
                    tls_params: None,
                },
            }
        }

        // ...otherwise the branch always returns an error.
        #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
        fail!((
            ErrorKind::InvalidClientConfig,
            "can't connect with TLS, the feature is not enabled"
        ));
    } else {
        ConnectionAddr::Tcp(host, port)
    };
    let query: HashMap<_, _> = url.query_pairs().collect();
    Ok(ConnectionInfo {
        addr,
        redis: RedisConnectionInfo {
            // "/3", "3/" and "/3/" all select database 3.
            db: match url.path().trim_matches('/') {
                "" => 0,
                path => path.parse::<i64>().map_err(|_| -> RedisError {
                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
                })?,
            },
            username: if url.username().is_empty() {
                None
            } else {
                // The url crate keeps userinfo percent-encoded; decode it here.
                match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
                    Ok(decoded) => Some(decoded.into_owned()),
                    Err(_) => fail!((
                        ErrorKind::InvalidClientConfig,
                        "Username is not valid UTF-8 string"
                    )),
                }
            },
            password: match url.password() {
                Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
                    Ok(decoded) => Some(decoded.into_owned()),
                    Err(_) => fail!((
                        ErrorKind::InvalidClientConfig,
                        "Password is not valid UTF-8 string"
                    )),
                },
                None => None,
            },
            protocol: parse_protocol(&query)?,
        },
    })
}
431
432#[cfg(unix)]
433fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
434 let query: HashMap<_, _> = url.query_pairs().collect();
435 Ok(ConnectionInfo {
436 addr: ConnectionAddr::Unix(url.to_file_path().map_err(|_| -> RedisError {
437 (ErrorKind::InvalidClientConfig, "Missing path").into()
438 })?),
439 redis: RedisConnectionInfo {
440 db: match query.get("db") {
441 Some(db) => db.parse::<i64>().map_err(|_| -> RedisError {
442 (ErrorKind::InvalidClientConfig, "Invalid database number").into()
443 })?,
444
445 None => 0,
446 },
447 username: query.get("user").map(|username| username.to_string()),
448 password: query.get("pass").map(|password| password.to_string()),
449 protocol: parse_protocol(&query)?,
450 },
451 })
452}
453
/// Non-unix fallback: Unix-socket URLs can never be used on this platform.
#[cfg(not(unix))]
fn url_to_unix_connection_info(_: url::Url) -> RedisResult<ConnectionInfo> {
    Err(RedisError::from((
        ErrorKind::InvalidClientConfig,
        "Unix sockets are not available on this platform.",
    )))
}
461
462impl IntoConnectionInfo for url::Url {
463 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
464 match self.scheme() {
465 "redis" | "rediss" | "valkey" | "valkeys" => url_to_tcp_connection_info(self),
466 "unix" | "redis+unix" | "valkey+unix" => url_to_unix_connection_info(self),
467 _ => fail!((
468 ErrorKind::InvalidClientConfig,
469 "URL provided is not a redis URL"
470 )),
471 }
472 }
473}
474
/// Plain TCP transport. `open` is cleared by `send_bytes` after an unrecoverable
/// write error.
struct TcpConnection {
    reader: TcpStream,
    open: bool,
}

/// TCP transport wrapped in native-tls (used only when rustls is not enabled).
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
struct TcpNativeTlsConnection {
    reader: TlsStream<TcpStream>,
    open: bool,
}

/// TCP transport wrapped in rustls.
#[cfg(feature = "tls-rustls")]
struct TcpRustlsConnection {
    reader: StreamOwned<rustls::ClientConnection, TcpStream>,
    open: bool,
}

/// Unix domain socket transport.
#[cfg(unix)]
struct UnixConnection {
    sock: UnixStream,
    open: bool,
}

/// The concrete transport behind a `Connection`. TLS variants are boxed,
/// presumably to keep the enum small — the TLS stream types are much larger
/// than a bare socket.
enum ActualConnection {
    Tcp(TcpConnection),
    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
    TcpNativeTls(Box<TcpNativeTlsConnection>),
    #[cfg(feature = "tls-rustls")]
    TcpRustls(Box<TcpRustlsConnection>),
    #[cfg(unix)]
    Unix(UnixConnection),
}
507
/// rustls verifier that accepts every server certificate and handshake signature
/// without checking them.
///
/// Installed by `create_rustls_config` only for `insecure` addresses when the
/// `tls-rustls-insecure` feature is enabled. It provides no authentication of
/// the server.
#[cfg(feature = "tls-rustls-insecure")]
struct NoCertificateVerification {
    /// Signature algorithms of the default crypto provider, reported via
    /// `supported_verify_schemes`.
    supported: rustls::crypto::WebPkiSupportedAlgorithms,
}

#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
    // Unconditionally accepts the certificate chain.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    // Unconditionally accepts TLS 1.2 handshake signatures.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    // Unconditionally accepts TLS 1.3 handshake signatures.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.supported.supported_schemes()
    }
}

// Manual Debug: only the type name is printed.
#[cfg(feature = "tls-rustls-insecure")]
impl fmt::Debug for NoCertificateVerification {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("NoCertificateVerification").finish()
    }
}
555
/// rustls verifier that performs full WebPKI validation through `inner` but
/// forgives certificate/hostname mismatches.
///
/// Installed by `create_rustls_config` when `danger_accept_invalid_hostnames`
/// is set on a non-insecure address.
#[cfg(feature = "tls-rustls-insecure")]
#[derive(Debug)]
struct AcceptInvalidHostnamesCertVerifier {
    inner: Arc<rustls::client::WebPkiServerVerifier>,
}
562
/// True when `err` means only that the certificate is not valid for the
/// requested server name (with or without extra context).
#[cfg(feature = "tls-rustls-insecure")]
fn is_hostname_error(err: &rustls::Error) -> bool {
    use rustls::CertificateError as CertErr;
    match err {
        rustls::Error::InvalidCertificate(CertErr::NotValidForName) => true,
        rustls::Error::InvalidCertificate(CertErr::NotValidForNameContext { .. }) => true,
        _ => false,
    }
}
573
/// Delegates all checks to the inner WebPKI verifier. The only exception is
/// certificate/hostname mismatch during chain validation, which is forgiven.
#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for AcceptInvalidHostnamesCertVerifier {
    /// Validates the chain (trust anchors, expiry, usage, ...) with the inner
    /// verifier, turning only hostname-mismatch errors into success.
    fn verify_server_cert(
        &self,
        end_entity: &rustls::pki_types::CertificateDer<'_>,
        intermediates: &[rustls::pki_types::CertificateDer<'_>],
        server_name: &rustls::pki_types::ServerName<'_>,
        ocsp_response: &[u8],
        now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        self.inner
            .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)
            .or_else(|err| {
                if is_hostname_error(&err) {
                    Ok(rustls::client::danger::ServerCertVerified::assertion())
                } else {
                    Err(err)
                }
            })
    }

    /// Verifies the TLS 1.2 handshake signature with the inner verifier, unchanged.
    ///
    /// Signature verification does not involve the server name, so no error
    /// from it may be forgiven. Converting "hostname" errors to success here was
    /// dead code at best, and at worst would accept a bad signature.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        self.inner.verify_tls12_signature(message, cert, dss)
    }

    /// Verifies the TLS 1.3 handshake signature with the inner verifier, unchanged
    /// (see `verify_tls12_signature` for why nothing is forgiven here).
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        self.inner.verify_tls13_signature(message, cert, dss)
    }

    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.inner.supported_verify_schemes()
    }
}
633
/// A synchronous connection to a Redis/Valkey server.
pub struct Connection {
    /// Underlying transport (TCP, TLS or Unix socket).
    con: ActualConnection,
    /// Parser for server replies.
    parser: Parser,
    /// Database index; presumably the currently selected DB — TODO confirm.
    db: i64,

    /// NOTE(review): looks like a flag for pub/sub mode — confirm against usage.
    pubsub: bool,

    /// Protocol version negotiated for this connection.
    protocol: ProtocolVersion,

    /// Optional sink for push messages (presumably RESP3 pushes — confirm).
    push_sender: Option<SyncPushSender>,

    /// NOTE(review): presumably a count of replies to discard before reading
    /// the next real response — confirm against the methods that use it.
    messages_to_skip: usize,
}

/// Pub/sub view over a borrowed [`Connection`].
pub struct PubSub<'a> {
    con: &'a mut Connection,
    /// Messages queued for later delivery (presumably received while waiting
    /// for other replies — TODO confirm).
    waiting_messages: VecDeque<Msg>,
}

/// A pub/sub message.
#[derive(Debug, Clone)]
pub struct Msg {
    payload: Value,
    channel: Value,
    /// Present for pattern subscriptions (presumably; confirm against the parser).
    pattern: Option<Value>,
}
672
673impl ActualConnection {
674 pub fn new(addr: &ConnectionAddr, timeout: Option<Duration>) -> RedisResult<ActualConnection> {
675 Ok(match *addr {
676 ConnectionAddr::Tcp(ref host, ref port) => {
677 let addr = (host.as_str(), *port);
678 let tcp = match timeout {
679 None => connect_tcp(addr)?,
680 Some(timeout) => {
681 let mut tcp = None;
682 let mut last_error = None;
683 for addr in addr.to_socket_addrs()? {
684 match connect_tcp_timeout(&addr, timeout) {
685 Ok(l) => {
686 tcp = Some(l);
687 break;
688 }
689 Err(e) => {
690 last_error = Some(e);
691 }
692 };
693 }
694 match (tcp, last_error) {
695 (Some(tcp), _) => tcp,
696 (None, Some(e)) => {
697 fail!(e);
698 }
699 (None, None) => {
700 fail!((
701 ErrorKind::InvalidClientConfig,
702 "could not resolve to any addresses"
703 ));
704 }
705 }
706 }
707 };
708 ActualConnection::Tcp(TcpConnection {
709 reader: tcp,
710 open: true,
711 })
712 }
713 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
714 ConnectionAddr::TcpTls {
715 ref host,
716 port,
717 insecure,
718 ref tls_params,
719 } => {
720 let tls_connector = if insecure {
721 TlsConnector::builder()
722 .danger_accept_invalid_certs(true)
723 .danger_accept_invalid_hostnames(true)
724 .use_sni(false)
725 .build()?
726 } else if let Some(params) = tls_params {
727 TlsConnector::builder()
728 .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames)
729 .build()?
730 } else {
731 TlsConnector::new()?
732 };
733 let addr = (host.as_str(), port);
734 let tls = match timeout {
735 None => {
736 let tcp = connect_tcp(addr)?;
737 match tls_connector.connect(host, tcp) {
738 Ok(res) => res,
739 Err(e) => {
740 fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string()));
741 }
742 }
743 }
744 Some(timeout) => {
745 let mut tcp = None;
746 let mut last_error = None;
747 for addr in (host.as_str(), port).to_socket_addrs()? {
748 match connect_tcp_timeout(&addr, timeout) {
749 Ok(l) => {
750 tcp = Some(l);
751 break;
752 }
753 Err(e) => {
754 last_error = Some(e);
755 }
756 };
757 }
758 match (tcp, last_error) {
759 (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(),
760 (None, Some(e)) => {
761 fail!(e);
762 }
763 (None, None) => {
764 fail!((
765 ErrorKind::InvalidClientConfig,
766 "could not resolve to any addresses"
767 ));
768 }
769 }
770 }
771 };
772 ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection {
773 reader: tls,
774 open: true,
775 }))
776 }
777 #[cfg(feature = "tls-rustls")]
778 ConnectionAddr::TcpTls {
779 ref host,
780 port,
781 insecure,
782 ref tls_params,
783 } => {
784 let host: &str = host;
785 let config = create_rustls_config(insecure, tls_params.clone())?;
786 let conn = rustls::ClientConnection::new(
787 Arc::new(config),
788 rustls::pki_types::ServerName::try_from(host)?.to_owned(),
789 )?;
790 let reader = match timeout {
791 None => {
792 let tcp = connect_tcp((host, port))?;
793 StreamOwned::new(conn, tcp)
794 }
795 Some(timeout) => {
796 let mut tcp = None;
797 let mut last_error = None;
798 for addr in (host, port).to_socket_addrs()? {
799 match connect_tcp_timeout(&addr, timeout) {
800 Ok(l) => {
801 tcp = Some(l);
802 break;
803 }
804 Err(e) => {
805 last_error = Some(e);
806 }
807 };
808 }
809 match (tcp, last_error) {
810 (Some(tcp), _) => StreamOwned::new(conn, tcp),
811 (None, Some(e)) => {
812 fail!(e);
813 }
814 (None, None) => {
815 fail!((
816 ErrorKind::InvalidClientConfig,
817 "could not resolve to any addresses"
818 ));
819 }
820 }
821 }
822 };
823
824 ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true }))
825 }
826 #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
827 ConnectionAddr::TcpTls { .. } => {
828 fail!((
829 ErrorKind::InvalidClientConfig,
830 "Cannot connect to TCP with TLS without the tls feature"
831 ));
832 }
833 #[cfg(unix)]
834 ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection {
835 sock: UnixStream::connect(path)?,
836 open: true,
837 }),
838 #[cfg(not(unix))]
839 ConnectionAddr::Unix(ref _path) => {
840 fail!((
841 ErrorKind::InvalidClientConfig,
842 "Cannot connect to unix sockets \
843 on this platform"
844 ));
845 }
846 })
847 }
848
849 pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
850 match *self {
851 ActualConnection::Tcp(ref mut connection) => {
852 let res = connection.reader.write_all(bytes).map_err(RedisError::from);
853 match res {
854 Err(e) => {
855 if e.is_unrecoverable_error() {
856 connection.open = false;
857 }
858 Err(e)
859 }
860 Ok(_) => Ok(Value::Okay),
861 }
862 }
863 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
864 ActualConnection::TcpNativeTls(ref mut connection) => {
865 let res = connection.reader.write_all(bytes).map_err(RedisError::from);
866 match res {
867 Err(e) => {
868 if e.is_unrecoverable_error() {
869 connection.open = false;
870 }
871 Err(e)
872 }
873 Ok(_) => Ok(Value::Okay),
874 }
875 }
876 #[cfg(feature = "tls-rustls")]
877 ActualConnection::TcpRustls(ref mut connection) => {
878 let res = connection.reader.write_all(bytes).map_err(RedisError::from);
879 match res {
880 Err(e) => {
881 if e.is_unrecoverable_error() {
882 connection.open = false;
883 }
884 Err(e)
885 }
886 Ok(_) => Ok(Value::Okay),
887 }
888 }
889 #[cfg(unix)]
890 ActualConnection::Unix(ref mut connection) => {
891 let result = connection.sock.write_all(bytes).map_err(RedisError::from);
892 match result {
893 Err(e) => {
894 if e.is_unrecoverable_error() {
895 connection.open = false;
896 }
897 Err(e)
898 }
899 Ok(_) => Ok(Value::Okay),
900 }
901 }
902 }
903 }
904
905 pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
906 match *self {
907 ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
908 reader.set_write_timeout(dur)?;
909 }
910 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
911 ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
912 let reader = &(boxed_tls_connection.reader);
913 reader.get_ref().set_write_timeout(dur)?;
914 }
915 #[cfg(feature = "tls-rustls")]
916 ActualConnection::TcpRustls(ref boxed_tls_connection) => {
917 let reader = &(boxed_tls_connection.reader);
918 reader.get_ref().set_write_timeout(dur)?;
919 }
920 #[cfg(unix)]
921 ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
922 sock.set_write_timeout(dur)?;
923 }
924 }
925 Ok(())
926 }
927
928 pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
929 match *self {
930 ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
931 reader.set_read_timeout(dur)?;
932 }
933 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
934 ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
935 let reader = &(boxed_tls_connection.reader);
936 reader.get_ref().set_read_timeout(dur)?;
937 }
938 #[cfg(feature = "tls-rustls")]
939 ActualConnection::TcpRustls(ref boxed_tls_connection) => {
940 let reader = &(boxed_tls_connection.reader);
941 reader.get_ref().set_read_timeout(dur)?;
942 }
943 #[cfg(unix)]
944 ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
945 sock.set_read_timeout(dur)?;
946 }
947 }
948 Ok(())
949 }
950
951 pub fn is_open(&self) -> bool {
952 match *self {
953 ActualConnection::Tcp(TcpConnection { open, .. }) => open,
954 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
955 ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open,
956 #[cfg(feature = "tls-rustls")]
957 ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open,
958 #[cfg(unix)]
959 ActualConnection::Unix(UnixConnection { open, .. }) => open,
960 }
961 }
962}
963
/// Builds the rustls client configuration for a TLS connection.
///
/// Root certificates come from `tls_params.root_cert_store` if given. Otherwise
/// they come from the bundled webpki roots (`tls-rustls-webpki-roots`) or the
/// platform's native store, depending on features.
///
/// When `tls_params` is present:
/// * optional client-auth certificates are configured;
/// * `danger_accept_invalid_hostnames` installs a hostname-forgiving verifier
///   (requires `tls-rustls-insecure`).
///
/// `insecure` disables SNI and all certificate verification, which also
/// requires `tls-rustls-insecure`.
///
/// # Errors
/// Returns `InvalidClientConfig` on unusable parameters or missing features.
/// Also fails if native certificates fail to load, if a root certificate is
/// rejected, or if no default crypto provider is installed.
#[cfg(feature = "tls-rustls")]
pub(crate) fn create_rustls_config(
    insecure: bool,
    tls_params: Option<TlsConnParams>,
) -> RedisResult<rustls::ClientConfig> {
    #[allow(unused_mut)]
    let mut root_store = RootCertStore::empty();
    #[cfg(feature = "tls-rustls-webpki-roots")]
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    #[cfg(all(
        feature = "tls-rustls",
        not(feature = "tls-native-tls"),
        not(feature = "tls-rustls-webpki-roots")
    ))]
    {
        let mut certificate_result = load_native_certs();
        // Any load error aborts; only the last reported error is surfaced.
        if let Some(error) = certificate_result.errors.pop() {
            return Err(error.into());
        }
        for cert in certificate_result.certs {
            root_store.add(cert)?;
        }
    }

    let config = rustls::ClientConfig::builder();
    let config = if let Some(tls_params) = tls_params {
        let root_cert_store = tls_params.root_cert_store.unwrap_or(root_store);
        // Cloned because the hostname-forgiving verifier below needs its own copy.
        let config_builder = config.with_root_certificates(root_cert_store.clone());

        let config_builder = if let Some(ClientTlsParams {
            client_cert_chain: client_cert,
            client_key,
        }) = tls_params.client_tls_params
        {
            config_builder
                .with_client_auth_cert(client_cert, client_key)
                .map_err(|err| {
                    RedisError::from((
                        ErrorKind::InvalidClientConfig,
                        "Unable to build client with TLS parameters provided.",
                        err.to_string(),
                    ))
                })?
        } else {
            config_builder.with_no_client_auth()
        };

        // Only relevant for secure connections: `insecure` replaces the verifier
        // wholesale further down.
        #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
        let config_builder = if !insecure && tls_params.danger_accept_invalid_hostnames {
            #[cfg(not(feature = "tls-rustls-insecure"))]
            {
                fail!((
                    ErrorKind::InvalidClientConfig,
                    "Cannot create insecure client via danger_accept_invalid_hostnames without tls-rustls-insecure feature"
                ));
            }

            #[cfg(feature = "tls-rustls-insecure")]
            {
                let mut config = config_builder;
                config.dangerous().set_certificate_verifier(Arc::new(
                    AcceptInvalidHostnamesCertVerifier {
                        inner: rustls::client::WebPkiServerVerifier::builder(Arc::new(
                            root_cert_store,
                        ))
                        .build()
                        .map_err(|err| rustls::Error::from(rustls::OtherError(Arc::new(err))))?,
                    },
                ));
                config
            }
        } else {
            config_builder
        };

        config_builder
    } else {
        config
            .with_root_certificates(root_store)
            .with_no_client_auth()
    };

    match (insecure, cfg!(feature = "tls-rustls-insecure")) {
        // This arm only exists with the feature, and the `cfg!` flag keeps
        // `(true, true)` unreachable without it.
        #[cfg(feature = "tls-rustls-insecure")]
        (true, true) => {
            let mut config = config;
            config.enable_sni = false;
            let Some(crypto_provider) = rustls::crypto::CryptoProvider::get_default() else {
                return Err(RedisError::from((
                    ErrorKind::InvalidClientConfig,
                    "No crypto provider available for rustls",
                )));
            };
            config
                .dangerous()
                .set_certificate_verifier(Arc::new(NoCertificateVerification {
                    supported: crypto_provider.signature_verification_algorithms,
                }));

            Ok(config)
        }
        (true, false) => {
            fail!((
                ErrorKind::InvalidClientConfig,
                "Cannot create insecure client without tls-rustls-insecure feature"
            ));
        }
        _ => Ok(config),
    }
}
1081
1082fn authenticate_cmd(
1083 connection_info: &RedisConnectionInfo,
1084 check_username: bool,
1085 password: &str,
1086) -> Cmd {
1087 let mut command = cmd("AUTH");
1088 if check_username {
1089 if let Some(username) = &connection_info.username {
1090 command.arg(username);
1091 }
1092 }
1093 command.arg(password);
1094 command
1095}
1096
1097pub fn connect(
1098 connection_info: &ConnectionInfo,
1099 timeout: Option<Duration>,
1100) -> RedisResult<Connection> {
1101 let start = Instant::now();
1102 let con: ActualConnection = ActualConnection::new(&connection_info.addr, timeout)?;
1103
1104 let remaining_timeout = timeout.and_then(|timeout| timeout.checked_sub(start.elapsed()));
1106 if timeout.is_some() && remaining_timeout.is_none() {
1108 return Err(RedisError::from(std::io::Error::new(
1109 std::io::ErrorKind::TimedOut,
1110 "Connection timed out",
1111 )));
1112 }
1113 con.set_read_timeout(remaining_timeout)?;
1114 con.set_write_timeout(remaining_timeout)?;
1115
1116 let con = setup_connection(
1117 con,
1118 &connection_info.redis,
1119 #[cfg(feature = "cache-aio")]
1120 None,
1121 )?;
1122
1123 con.set_read_timeout(None)?;
1125 con.set_write_timeout(None)?;
1126
1127 Ok(con)
1128}
1129
/// Positions, within the setup pipeline, of the commands whose replies must be
/// checked. Produced by `connection_setup_pipeline` and consumed by
/// `check_connection_setup`; `None` means the command was not sent.
pub(crate) struct ConnectionSetupComponents {
    // Index of the `HELLO` command (non-RESP2 protocols).
    resp3_auth_cmd_idx: Option<usize>,
    // Index of the `AUTH` command (RESP2 with a password).
    resp2_auth_cmd_idx: Option<usize>,
    // Index of the `SELECT` command (non-zero database).
    select_cmd_idx: Option<usize>,
    // Index of the `CLIENT TRACKING ON` command (client-side caching).
    #[cfg(feature = "cache-aio")]
    cache_cmd_idx: Option<usize>,
}
1137
1138pub(crate) fn connection_setup_pipeline(
1139 connection_info: &RedisConnectionInfo,
1140 check_username: bool,
1141 #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1142) -> (crate::Pipeline, ConnectionSetupComponents) {
1143 let mut pipeline = pipe();
1144 let (authenticate_with_resp3_cmd_index, authenticate_with_resp2_cmd_index) =
1145 if connection_info.protocol != ProtocolVersion::RESP2 {
1146 pipeline.add_command(resp3_hello(connection_info));
1147 (Some(0), None)
1148 } else if connection_info.password.is_some() {
1149 pipeline.add_command(authenticate_cmd(
1150 connection_info,
1151 check_username,
1152 connection_info.password.as_ref().unwrap(),
1153 ));
1154 (None, Some(0))
1155 } else {
1156 (None, None)
1157 };
1158
1159 let select_db_cmd_index = (connection_info.db != 0)
1160 .then(|| pipeline.len())
1161 .inspect(|_| {
1162 pipeline.cmd("SELECT").arg(connection_info.db);
1163 });
1164
1165 #[cfg(feature = "cache-aio")]
1166 let cache_cmd_index = cache_config.map(|cache_config| {
1167 pipeline.cmd("CLIENT").arg("TRACKING").arg("ON");
1168 match cache_config.mode {
1169 crate::caching::CacheMode::All => {}
1170 crate::caching::CacheMode::OptIn => {
1171 pipeline.arg("OPTIN");
1172 }
1173 }
1174 pipeline.len() - 1
1175 });
1176
1177 #[cfg(not(feature = "disable-client-setinfo"))]
1180 pipeline
1181 .cmd("CLIENT")
1182 .arg("SETINFO")
1183 .arg("LIB-NAME")
1184 .arg("redis-rs")
1185 .ignore();
1186 #[cfg(not(feature = "disable-client-setinfo"))]
1187 pipeline
1188 .cmd("CLIENT")
1189 .arg("SETINFO")
1190 .arg("LIB-VER")
1191 .arg(env!("CARGO_PKG_VERSION"))
1192 .ignore();
1193
1194 (
1195 pipeline,
1196 ConnectionSetupComponents {
1197 resp3_auth_cmd_idx: authenticate_with_resp3_cmd_index,
1198 resp2_auth_cmd_idx: authenticate_with_resp2_cmd_index,
1199 select_cmd_idx: select_db_cmd_index,
1200 #[cfg(feature = "cache-aio")]
1201 cache_cmd_idx: cache_cmd_index,
1202 },
1203 )
1204}
1205
1206fn check_resp3_auth(result: &Value) -> RedisResult<()> {
1207 if let Value::ServerError(err) = result {
1208 return Err(get_resp3_hello_command_error(err.clone().into()));
1209 }
1210 Ok(())
1211}
1212
/// Outcome of validating the connection-setup authentication reply.
#[derive(PartialEq)]
pub(crate) enum AuthResult {
    /// Authentication (or the absence of it) is fine; setup may continue.
    Succeeded,
    /// The server rejected `AUTH username password` for its argument count, so
    /// the caller should retry with a password-only `AUTH`.
    ShouldRetryWithoutUsername,
}
1218
1219fn check_resp2_auth(result: &Value) -> RedisResult<AuthResult> {
1220 let err = match result {
1221 Value::Okay => {
1222 return Ok(AuthResult::Succeeded);
1223 }
1224 Value::ServerError(err) => err,
1225 _ => {
1226 return Err((
1227 ErrorKind::ResponseError,
1228 "Redis server refused to authenticate, returns Ok() != Value::Okay",
1229 )
1230 .into());
1231 }
1232 };
1233
1234 let err_msg = err.details().ok_or((
1235 ErrorKind::AuthenticationFailed,
1236 "Password authentication failed",
1237 ))?;
1238 if !err_msg.contains("wrong number of arguments for 'auth' command") {
1239 return Err((
1240 ErrorKind::AuthenticationFailed,
1241 "Password authentication failed",
1242 )
1243 .into());
1244 }
1245 Ok(AuthResult::ShouldRetryWithoutUsername)
1246}
1247
1248fn check_db_select(value: &Value) -> RedisResult<()> {
1249 let Value::ServerError(err) = value else {
1250 return Ok(());
1251 };
1252
1253 match err.details() {
1254 Some(err_msg) => Err((
1255 ErrorKind::ResponseError,
1256 "Redis server refused to switch database",
1257 err_msg.to_string(),
1258 )
1259 .into()),
1260 None => Err((
1261 ErrorKind::ResponseError,
1262 "Redis server refused to switch database",
1263 )
1264 .into()),
1265 }
1266}
1267
/// Validates the `CLIENT TRACKING` reply; only `OK` is accepted.
#[cfg(feature = "cache-aio")]
fn check_caching(result: &Value) -> RedisResult<()> {
    if matches!(result, Value::Okay) {
        return Ok(());
    }
    Err((
        ErrorKind::ResponseError,
        "Client-side caching returned unknown response",
    )
        .into())
}
1279
1280pub(crate) fn check_connection_setup(
1281 results: Vec<Value>,
1282 ConnectionSetupComponents {
1283 resp3_auth_cmd_idx,
1284 resp2_auth_cmd_idx,
1285 select_cmd_idx,
1286 #[cfg(feature = "cache-aio")]
1287 cache_cmd_idx,
1288 }: ConnectionSetupComponents,
1289) -> RedisResult<AuthResult> {
1290 assert!(!(resp2_auth_cmd_idx.is_some() && resp3_auth_cmd_idx.is_some()));
1292
1293 if let Some(index) = resp3_auth_cmd_idx {
1294 let Some(value) = results.get(index) else {
1295 return Err((ErrorKind::ClientError, "Missing RESP3 auth response").into());
1296 };
1297 check_resp3_auth(value)?;
1298 } else if let Some(index) = resp2_auth_cmd_idx {
1299 let Some(value) = results.get(index) else {
1300 return Err((ErrorKind::ClientError, "Missing RESP2 auth response").into());
1301 };
1302 if check_resp2_auth(value)? == AuthResult::ShouldRetryWithoutUsername {
1303 return Ok(AuthResult::ShouldRetryWithoutUsername);
1304 }
1305 }
1306
1307 if let Some(index) = select_cmd_idx {
1308 let Some(value) = results.get(index) else {
1309 return Err((ErrorKind::ClientError, "Missing SELECT DB response").into());
1310 };
1311 check_db_select(value)?;
1312 }
1313
1314 #[cfg(feature = "cache-aio")]
1315 if let Some(index) = cache_cmd_idx {
1316 let Some(value) = results.get(index) else {
1317 return Err((ErrorKind::ClientError, "Missing Caching response").into());
1318 };
1319 check_caching(value)?;
1320 }
1321
1322 Ok(AuthResult::Succeeded)
1323}
1324
1325fn execute_connection_pipeline(
1326 rv: &mut Connection,
1327 (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
1328) -> RedisResult<AuthResult> {
1329 if pipeline.is_empty() {
1330 return Ok(AuthResult::Succeeded);
1331 }
1332 let results = rv.req_packed_commands(&pipeline.get_packed_pipeline(), 0, pipeline.len())?;
1333
1334 check_connection_setup(results, instructions)
1335}
1336
1337fn setup_connection(
1338 con: ActualConnection,
1339 connection_info: &RedisConnectionInfo,
1340 #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1341) -> RedisResult<Connection> {
1342 let mut rv = Connection {
1343 con,
1344 parser: Parser::new(),
1345 db: connection_info.db,
1346 pubsub: false,
1347 protocol: connection_info.protocol,
1348 push_sender: None,
1349 messages_to_skip: 0,
1350 };
1351
1352 if execute_connection_pipeline(
1353 &mut rv,
1354 connection_setup_pipeline(
1355 connection_info,
1356 true,
1357 #[cfg(feature = "cache-aio")]
1358 cache_config,
1359 ),
1360 )? == AuthResult::ShouldRetryWithoutUsername
1361 {
1362 execute_connection_pipeline(
1363 &mut rv,
1364 connection_setup_pipeline(
1365 connection_info,
1366 false,
1367 #[cfg(feature = "cache-aio")]
1368 cache_config,
1369 ),
1370 )?;
1371 }
1372
1373 Ok(rv)
1374}
1375
1376pub trait ConnectionLike {
1388 fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value>;
1391
1392 #[doc(hidden)]
1400 fn req_packed_commands(
1401 &mut self,
1402 cmd: &[u8],
1403 offset: usize,
1404 count: usize,
1405 ) -> RedisResult<Vec<Value>>;
1406
1407 fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1409 let pcmd = cmd.get_packed_command();
1410 self.req_packed_command(&pcmd)
1411 }
1412
1413 fn get_db(&self) -> i64;
1418
1419 #[doc(hidden)]
1421 fn supports_pipelining(&self) -> bool {
1422 true
1423 }
1424
1425 fn check_connection(&mut self) -> bool;
1427
1428 fn is_open(&self) -> bool;
1436}
1437
1438impl Connection {
1446 pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> {
1451 self.send_bytes(cmd)?;
1452 Ok(())
1453 }
1454
1455 pub fn recv_response(&mut self) -> RedisResult<Value> {
1458 self.read(true)
1459 }
1460
1461 pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1467 self.con.set_write_timeout(dur)
1468 }
1469
1470 pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1476 self.con.set_read_timeout(dur)
1477 }
1478
1479 pub fn as_pubsub(&mut self) -> PubSub<'_> {
1481 PubSub::new(self)
1485 }
1486
1487 fn exit_pubsub(&mut self) -> RedisResult<()> {
1488 let res = self.clear_active_subscriptions();
1489 if res.is_ok() {
1490 self.pubsub = false;
1491 } else {
1492 self.pubsub = true;
1494 }
1495
1496 res
1497 }
1498
1499 fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
1504 {
1510 let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
1512 let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();
1513
1514 self.send_bytes(&unsubscribe)?;
1516 self.send_bytes(&punsubscribe)?;
1517 }
1518
1519 let mut received_unsub = false;
1525 let mut received_punsub = false;
1526
1527 loop {
1528 let resp = self.recv_response()?;
1529
1530 match resp {
1531 Value::Push { kind, data } => {
1532 if data.len() >= 2 {
1533 if let Value::Int(num) = data[1] {
1534 if resp3_is_pub_sub_state_cleared(
1535 &mut received_unsub,
1536 &mut received_punsub,
1537 &kind,
1538 num as isize,
1539 ) {
1540 break;
1541 }
1542 }
1543 }
1544 }
1545 Value::ServerError(err) => {
1546 if err.kind() == Some(ServerErrorKind::NoSub) {
1549 if no_sub_err_is_pub_sub_state_cleared(
1550 &mut received_unsub,
1551 &mut received_punsub,
1552 &err,
1553 ) {
1554 break;
1555 } else {
1556 continue;
1557 }
1558 }
1559
1560 return Err(err.into());
1561 }
1562 Value::Array(vec) => {
1563 let res: (Vec<u8>, (), isize) = from_owned_redis_value(Value::Array(vec))?;
1564 if resp2_is_pub_sub_state_cleared(
1565 &mut received_unsub,
1566 &mut received_punsub,
1567 &res.0,
1568 res.2,
1569 ) {
1570 break;
1571 }
1572 }
1573 _ => {
1574 return Err((
1575 ErrorKind::ClientError,
1576 "Unexpected unsubscribe response",
1577 format!("{resp:?}"),
1578 )
1579 .into())
1580 }
1581 }
1582 }
1583
1584 Ok(())
1587 }
1588
1589 fn send_push(&self, push: PushInfo) {
1590 if let Some(sender) = &self.push_sender {
1591 let _ = sender.send(push);
1592 }
1593 }
1594
1595 fn try_send(&self, value: &RedisResult<Value>) {
1596 if let Ok(Value::Push { kind, data }) = value {
1597 self.send_push(PushInfo {
1598 kind: kind.clone(),
1599 data: data.clone(),
1600 });
1601 }
1602 }
1603
1604 fn send_disconnect(&self) {
1605 self.send_push(PushInfo::disconnect())
1606 }
1607
1608 fn close_connection(&mut self) {
1609 self.send_disconnect();
1611 match self.con {
1612 ActualConnection::Tcp(ref mut connection) => {
1613 let _ = connection.reader.shutdown(net::Shutdown::Both);
1614 connection.open = false;
1615 }
1616 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1617 ActualConnection::TcpNativeTls(ref mut connection) => {
1618 let _ = connection.reader.shutdown();
1619 connection.open = false;
1620 }
1621 #[cfg(feature = "tls-rustls")]
1622 ActualConnection::TcpRustls(ref mut connection) => {
1623 let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both);
1624 connection.open = false;
1625 }
1626 #[cfg(unix)]
1627 ActualConnection::Unix(ref mut connection) => {
1628 let _ = connection.sock.shutdown(net::Shutdown::Both);
1629 connection.open = false;
1630 }
1631 }
1632 }
1633
1634 fn read(&mut self, is_response: bool) -> RedisResult<Value> {
1637 loop {
1638 let result = match self.con {
1639 ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => {
1640 self.parser.parse_value(reader)
1641 }
1642 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1643 ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => {
1644 let reader = &mut boxed_tls_connection.reader;
1645 self.parser.parse_value(reader)
1646 }
1647 #[cfg(feature = "tls-rustls")]
1648 ActualConnection::TcpRustls(ref mut boxed_tls_connection) => {
1649 let reader = &mut boxed_tls_connection.reader;
1650 self.parser.parse_value(reader)
1651 }
1652 #[cfg(unix)]
1653 ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => {
1654 self.parser.parse_value(sock)
1655 }
1656 };
1657 self.try_send(&result);
1658
1659 let Err(err) = &result else {
1660 if self.messages_to_skip > 0 {
1661 self.messages_to_skip -= 1;
1662 continue;
1663 }
1664 return result;
1665 };
1666 let Some(io_error) = err.as_io_error() else {
1667 if self.messages_to_skip > 0 {
1668 self.messages_to_skip -= 1;
1669 continue;
1670 }
1671 return result;
1672 };
1673 if io_error.kind() == io::ErrorKind::UnexpectedEof {
1675 self.close_connection();
1676 } else if is_response {
1677 self.messages_to_skip += 1;
1678 }
1679
1680 return result;
1681 }
1682 }
1683
1684 pub fn set_push_sender(&mut self, sender: SyncPushSender) {
1686 self.push_sender = Some(sender);
1687 }
1688
1689 fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
1690 let result = self.con.send_bytes(bytes);
1691 if self.protocol != ProtocolVersion::RESP2 {
1692 if let Err(e) = &result {
1693 if e.is_connection_dropped() {
1694 self.send_disconnect();
1695 }
1696 }
1697 }
1698 result
1699 }
1700
1701 pub fn subscribe_resp3<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1705 check_resp3!(self.protocol);
1706 cmd("SUBSCRIBE")
1707 .arg(channel)
1708 .set_no_response(true)
1709 .exec(self)
1710 }
1711
1712 pub fn psubscribe_resp3<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1716 check_resp3!(self.protocol);
1717 cmd("PSUBSCRIBE")
1718 .arg(pchannel)
1719 .set_no_response(true)
1720 .exec(self)
1721 }
1722
1723 pub fn unsubscribe_resp3<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1727 check_resp3!(self.protocol);
1728 cmd("UNSUBSCRIBE")
1729 .arg(channel)
1730 .set_no_response(true)
1731 .exec(self)
1732 }
1733
1734 pub fn punsubscribe_resp3<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1738 check_resp3!(self.protocol);
1739 cmd("PUNSUBSCRIBE")
1740 .arg(pchannel)
1741 .set_no_response(true)
1742 .exec(self)
1743 }
1744}
1745
1746impl ConnectionLike for Connection {
1747 fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1749 let pcmd = cmd.get_packed_command();
1750 if self.pubsub {
1751 self.exit_pubsub()?;
1752 }
1753
1754 self.send_bytes(&pcmd)?;
1755 if cmd.is_no_response() {
1756 return Ok(Value::Nil);
1757 }
1758 loop {
1759 match self.read(true)? {
1760 Value::Push {
1761 kind: _kind,
1762 data: _data,
1763 } => continue,
1764 val => return Ok(val),
1765 }
1766 }
1767 }
1768 fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1769 if self.pubsub {
1770 self.exit_pubsub()?;
1771 }
1772
1773 self.send_bytes(cmd)?;
1774 loop {
1775 match self.read(true)? {
1776 Value::Push {
1777 kind: _kind,
1778 data: _data,
1779 } => continue,
1780 val => return Ok(val),
1781 }
1782 }
1783 }
1784
1785 fn req_packed_commands(
1786 &mut self,
1787 cmd: &[u8],
1788 offset: usize,
1789 count: usize,
1790 ) -> RedisResult<Vec<Value>> {
1791 if self.pubsub {
1792 self.exit_pubsub()?;
1793 }
1794 self.send_bytes(cmd)?;
1795 let mut rv = vec![];
1796 let mut first_err = None;
1797 let mut count = count;
1798 let mut idx = 0;
1799 while idx < (offset + count) {
1800 let response = self.read(true);
1805 match response {
1806 Ok(Value::ServerError(err)) => {
1807 if idx < offset {
1808 if first_err.is_none() {
1809 first_err = Some(err.into());
1810 }
1811 } else {
1812 rv.push(Value::ServerError(err));
1813 }
1814 }
1815 Ok(item) => {
1816 if let Value::Push {
1818 kind: _kind,
1819 data: _data,
1820 } = item
1821 {
1822 count += 1;
1824 } else if idx >= offset {
1825 rv.push(item);
1826 }
1827 }
1828 Err(err) => {
1829 if first_err.is_none() {
1830 first_err = Some(err);
1831 }
1832 }
1833 }
1834 idx += 1;
1835 }
1836
1837 first_err.map_or(Ok(rv), Err)
1838 }
1839
1840 fn get_db(&self) -> i64 {
1841 self.db
1842 }
1843
1844 fn check_connection(&mut self) -> bool {
1845 cmd("PING").query::<String>(self).is_ok()
1846 }
1847
1848 fn is_open(&self) -> bool {
1849 self.con.is_open()
1850 }
1851}
1852
1853impl<C, T> ConnectionLike for T
1854where
1855 C: ConnectionLike,
1856 T: DerefMut<Target = C>,
1857{
1858 fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1859 self.deref_mut().req_packed_command(cmd)
1860 }
1861
1862 fn req_packed_commands(
1863 &mut self,
1864 cmd: &[u8],
1865 offset: usize,
1866 count: usize,
1867 ) -> RedisResult<Vec<Value>> {
1868 self.deref_mut().req_packed_commands(cmd, offset, count)
1869 }
1870
1871 fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1872 self.deref_mut().req_command(cmd)
1873 }
1874
1875 fn get_db(&self) -> i64 {
1876 self.deref().get_db()
1877 }
1878
1879 fn supports_pipelining(&self) -> bool {
1880 self.deref().supports_pipelining()
1881 }
1882
1883 fn check_connection(&mut self) -> bool {
1884 self.deref_mut().check_connection()
1885 }
1886
1887 fn is_open(&self) -> bool {
1888 self.deref().is_open()
1889 }
1890}
1891
1892impl<'a> PubSub<'a> {
1914 fn new(con: &'a mut Connection) -> Self {
1915 Self {
1916 con,
1917 waiting_messages: VecDeque::new(),
1918 }
1919 }
1920
1921 fn cache_messages_until_received_response(
1922 &mut self,
1923 cmd: &mut Cmd,
1924 is_sub_unsub: bool,
1925 ) -> RedisResult<Value> {
1926 let ignore_response = self.con.protocol != ProtocolVersion::RESP2 && is_sub_unsub;
1927 cmd.set_no_response(ignore_response);
1928
1929 self.con.send_packed_command(&cmd.get_packed_command())?;
1930
1931 loop {
1932 let response = self.con.recv_response()?;
1933 if let Some(msg) = Msg::from_value(&response) {
1934 self.waiting_messages.push_back(msg);
1935 } else {
1936 return Ok(response);
1937 }
1938 }
1939 }
1940
1941 pub fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1943 self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel), true)?;
1944 Ok(())
1945 }
1946
1947 pub fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1949 self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel), true)?;
1950 Ok(())
1951 }
1952
1953 pub fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1955 self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel), true)?;
1956 Ok(())
1957 }
1958
1959 pub fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1961 self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel), true)?;
1962 Ok(())
1963 }
1964
1965 pub fn ping_message<T: FromRedisValue>(&mut self, message: impl ToRedisArgs) -> RedisResult<T> {
1967 from_owned_redis_value(
1968 self.cache_messages_until_received_response(cmd("PING").arg(message), false)?,
1969 )
1970 }
1971 pub fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
1973 from_owned_redis_value(
1974 self.cache_messages_until_received_response(&mut cmd("PING"), false)?,
1975 )
1976 }
1977
1978 pub fn get_message(&mut self) -> RedisResult<Msg> {
1985 if let Some(msg) = self.waiting_messages.pop_front() {
1986 return Ok(msg);
1987 }
1988 loop {
1989 if let Some(msg) = Msg::from_owned_value(self.con.read(false)?) {
1990 return Ok(msg);
1991 } else {
1992 continue;
1993 }
1994 }
1995 }
1996
1997 pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
2003 self.con.set_read_timeout(dur)
2004 }
2005}
2006
2007impl Drop for PubSub<'_> {
2008 fn drop(&mut self) {
2009 let _ = self.con.exit_pubsub();
2010 }
2011}
2012
2013impl Msg {
2016 pub fn from_value(value: &Value) -> Option<Self> {
2018 Self::from_owned_value(value.clone())
2019 }
2020
2021 pub fn from_owned_value(value: Value) -> Option<Self> {
2023 let mut pattern = None;
2024 let payload;
2025 let channel;
2026
2027 if let Value::Push { kind, data } = value {
2028 return Self::from_push_info(PushInfo { kind, data });
2029 } else {
2030 let raw_msg: Vec<Value> = from_owned_redis_value(value).ok()?;
2031 let mut iter = raw_msg.into_iter();
2032 let msg_type: String = from_owned_redis_value(iter.next()?).ok()?;
2033 if msg_type == "message" {
2034 channel = iter.next()?;
2035 payload = iter.next()?;
2036 } else if msg_type == "pmessage" {
2037 pattern = Some(iter.next()?);
2038 channel = iter.next()?;
2039 payload = iter.next()?;
2040 } else {
2041 return None;
2042 }
2043 };
2044 Some(Msg {
2045 payload,
2046 channel,
2047 pattern,
2048 })
2049 }
2050
2051 pub fn from_push_info(push_info: PushInfo) -> Option<Self> {
2053 let mut pattern = None;
2054 let payload;
2055 let channel;
2056
2057 let mut iter = push_info.data.into_iter();
2058 if push_info.kind == PushKind::Message || push_info.kind == PushKind::SMessage {
2059 channel = iter.next()?;
2060 payload = iter.next()?;
2061 } else if push_info.kind == PushKind::PMessage {
2062 pattern = Some(iter.next()?);
2063 channel = iter.next()?;
2064 payload = iter.next()?;
2065 } else {
2066 return None;
2067 }
2068
2069 Some(Msg {
2070 payload,
2071 channel,
2072 pattern,
2073 })
2074 }
2075
2076 pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
2078 from_redis_value(&self.channel)
2079 }
2080
2081 pub fn get_channel_name(&self) -> &str {
2086 match self.channel {
2087 Value::BulkString(ref bytes) => from_utf8(bytes).unwrap_or("?"),
2088 _ => "?",
2089 }
2090 }
2091
2092 pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
2094 from_redis_value(&self.payload)
2095 }
2096
2097 pub fn get_payload_bytes(&self) -> &[u8] {
2101 match self.payload {
2102 Value::BulkString(ref bytes) => bytes,
2103 _ => b"",
2104 }
2105 }
2106
2107 #[allow(clippy::wrong_self_convention)]
2110 pub fn from_pattern(&self) -> bool {
2111 self.pattern.is_some()
2112 }
2113
2114 pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
2119 match self.pattern {
2120 None => from_redis_value(&Value::Nil),
2121 Some(ref x) => from_redis_value(x),
2122 }
2123 }
2124}
2125
2126pub fn transaction<
2159 C: ConnectionLike,
2160 K: ToRedisArgs,
2161 T,
2162 F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
2163>(
2164 con: &mut C,
2165 keys: &[K],
2166 func: F,
2167) -> RedisResult<T> {
2168 let mut func = func;
2169 loop {
2170 cmd("WATCH").arg(keys).exec(con)?;
2171 let mut p = pipe();
2172 let response: Option<T> = func(con, p.atomic())?;
2173 match response {
2174 None => {
2175 continue;
2176 }
2177 Some(response) => {
2178 cmd("UNWATCH").exec(con)?;
2181 return Ok(response);
2182 }
2183 }
2184 }
2185}
/// Records a RESP2 unsubscribe acknowledgement and reports whether pub/sub state is cleared.
///
/// `kind` is the reply kind: a leading `u` marks an `unsubscribe` reply and a
/// leading `p` marks a `punsubscribe` reply. `num` is the remaining subscription
/// count. Returns `true` once both kinds have been seen and `num` is 0.
pub fn resp2_is_pub_sub_state_cleared(
    received_unsub: &mut bool,
    received_punsub: &mut bool,
    kind: &[u8],
    num: isize,
) -> bool {
    if let Some(&first_byte) = kind.first() {
        if first_byte == b'u' {
            *received_unsub = true;
        } else if first_byte == b'p' {
            *received_punsub = true;
        }
    }
    num == 0 && *received_unsub && *received_punsub
}
2202
2203pub fn resp3_is_pub_sub_state_cleared(
2205 received_unsub: &mut bool,
2206 received_punsub: &mut bool,
2207 kind: &PushKind,
2208 num: isize,
2209) -> bool {
2210 match kind {
2211 PushKind::Unsubscribe => *received_unsub = true,
2212 PushKind::PUnsubscribe => *received_punsub = true,
2213 _ => (),
2214 };
2215 *received_unsub && *received_punsub && num == 0
2216}
2217
2218pub fn no_sub_err_is_pub_sub_state_cleared(
2219 received_unsub: &mut bool,
2220 received_punsub: &mut bool,
2221 err: &ServerError,
2222) -> bool {
2223 let details = err.details();
2224 *received_unsub = *received_unsub
2225 || details
2226 .map(|details| details.starts_with("'unsub"))
2227 .unwrap_or_default();
2228 *received_punsub = *received_punsub
2229 || details
2230 .map(|details| details.starts_with("'punsub"))
2231 .unwrap_or_default();
2232 *received_unsub && *received_punsub
2233}
2234
2235pub fn get_resp3_hello_command_error(err: RedisError) -> RedisError {
2237 if let Some(detail) = err.detail() {
2238 if detail.starts_with("unknown command `HELLO`") {
2239 return (
2240 ErrorKind::RESP3NotSupported,
2241 "Redis Server doesn't support HELLO command therefore resp3 cannot be used",
2242 )
2243 .into();
2244 }
2245 }
2246 err
2247}
2248
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_redis_url() {
        let cases = vec![
            ("redis://127.0.0.1", true),
            ("redis://[::1]", true),
            ("rediss://127.0.0.1", true),
            ("rediss://[::1]", true),
            ("valkey://127.0.0.1", true),
            ("valkey://[::1]", true),
            ("valkeys://127.0.0.1", true),
            ("valkeys://[::1]", true),
            ("redis+unix:///run/redis.sock", true),
            ("valkey+unix:///run/valkey.sock", true),
            ("unix:///run/redis.sock", true),
            ("http://127.0.0.1", false),
            ("tcp://127.0.0.1", false),
        ];
        for (url, expected) in cases.into_iter() {
            let res = parse_redis_url(url);
            assert_eq!(
                res.is_some(),
                expected,
                "Parsed result of `{url}` is not expected",
            );
        }
    }

    #[test]
    fn test_url_to_tcp_connection_info() {
        let cases = vec![
            (
                url::Url::parse("redis://127.0.0.1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://[::1]").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("::1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("example.com".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=resp3").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        protocol: ProtocolVersion::RESP3,
                        ..Default::default()
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url.clone()).unwrap();
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
            // The `protocol=` cases above are only meaningful if the parsed protocol
            // is actually compared.
            assert_eq!(
                res.redis.protocol, expected.redis.protocol,
                "protocol of {url} is not expected",
            );
        }
    }

    #[test]
    fn test_url_to_tcp_connection_info_failed() {
        let cases = vec![
            (
                url::Url::parse("redis://").unwrap(),
                "Missing hostname",
                None,
            ),
            (
                url::Url::parse("redis://127.0.0.1/db").unwrap(),
                "Invalid database number",
                None,
            ),
            (
                url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(),
                "Username is not valid UTF-8 string",
                None,
            ),
            (
                url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(),
                "Password is not valid UTF-8 string",
                None,
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=4").unwrap(),
                "Invalid protocol version",
                Some("4"),
            ),
        ];
        for (url, expected, detail) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url).unwrap_err();
            assert_eq!(
                res.kind(),
                crate::ErrorKind::InvalidClientConfig,
                "{}",
                &res,
            );
            #[allow(deprecated)]
            let desc = std::error::Error::description(&res);
            assert_eq!(desc, expected, "{}", &res);
            assert_eq!(res.detail(), detail, "{}", &res);
        }
    }

    #[test]
    #[cfg(unix)]
    fn test_url_to_unix_connection_info() {
        let cases = vec![
            (
                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 0,
                        username: None,
                        password: None,
                        protocol: ProtocolVersion::RESP2,
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 1,
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse(
                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse(
                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("&?= *+".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?protocol=3").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        protocol: ProtocolVersion::RESP3,
                        ..Default::default()
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            assert_eq!(
                ConnectionAddr::Unix(url.to_file_path().unwrap()),
                expected.addr,
                "addr of {url} is not expected",
            );
            let res = url_to_unix_connection_info(url.clone()).unwrap();
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
            // The `protocol=` case above is only meaningful if the parsed protocol
            // is actually compared.
            assert_eq!(
                res.redis.protocol, expected.redis.protocol,
                "protocol of {url} is not expected",
            );
        }
    }
}