1use super::{AsyncPushSender, ConnectionLike, Runtime, SharedHandleContainer, TaskHandle};
2#[cfg(feature = "cache-aio")]
3use crate::caching::{CacheManager, CacheStatistics, PrepareCacheResult};
4use crate::{
5 AsyncConnectionConfig, ProtocolVersion, PushInfo, RedisConnectionInfo, ServerError,
6 ToRedisArgs,
7 aio::setup_connection,
8 check_resp3, cmd,
9 cmd::Cmd,
10 errors::{RedisError, closed_connection_error},
11 parser::ValueCodec,
12 types::{RedisFuture, RedisResult, Value},
13};
14use ::tokio::{
15 io::{AsyncRead, AsyncWrite},
16 sync::{mpsc, oneshot},
17};
18#[cfg(feature = "token-based-authentication")]
19use {
20 crate::errors::ErrorKind,
21 arcstr::ArcStr,
22 log::{debug, error},
23 std::sync::atomic::{AtomicBool, Ordering},
24};
25
26use futures_util::{
27 future::{Future, FutureExt},
28 ready,
29 sink::Sink,
30 stream::{self, Stream, StreamExt},
31};
32use pin_project_lite::pin_project;
33use std::collections::VecDeque;
34use std::fmt;
35use std::fmt::Debug;
36use std::pin::Pin;
37use std::sync::Arc;
38use std::task::{self, Poll};
39use std::time::Duration;
40use tokio_util::codec::Decoder;
41
/// One-shot channel on which the driver delivers the final result of a request
/// (a single reply, or the aggregated array/error of a pipeline).
type PipelineOutput = oneshot::Sender<RedisResult<Value>>;
44
/// Error state accumulated while a pipeline's replies are being collected.
enum ErrorOrErrors {
    /// Server errors seen among the skipped replies of a transaction, each paired with an
    /// index computed as `seen_responses - 2` in `send_result` (presumably the 0-based
    /// position of the queued command, not counting the leading MULTI reply).
    Errors(Vec<(usize, ServerError)>),
    /// The first `Err` item received for this pipeline; once set it is not overwritten.
    FirstError(RedisError),
}
50
/// How replies for one in-flight request are gathered before being delivered.
enum ResponseAggregate {
    /// The first reply is delivered as-is.
    SingleCommand,
    /// Replies are skipped/buffered according to `expectation` and delivered together.
    Pipeline {
        // Replies collected after the skipped ones.
        buffer: Vec<Value>,
        // Errors observed so far (see `ErrorOrErrors`).
        error_or_errors: ErrorOrErrors,
        // Counters describing how many replies to skip and to collect.
        expectation: PipelineResponseExpectation,
    },
}
59
/// Describes which replies of a packed pipeline the caller is interested in.
struct PipelineResponseExpectation {
    /// Number of leading replies to consume without buffering (decremented as they arrive).
    skipped_response_count: usize,
    /// Number of replies to buffer after the skipped ones before answering the caller.
    expected_response_count: usize,
    /// When true, server errors among the skipped replies are recorded so the
    /// transaction can be reported as aborted.
    is_transaction: bool,
    /// Running count of replies received for this pipeline (incremented per reply).
    seen_responses: usize,
}
70
71impl ResponseAggregate {
72 fn new(expectation: Option<PipelineResponseExpectation>) -> Self {
73 match expectation {
74 Some(expectation) => ResponseAggregate::Pipeline {
75 buffer: Vec::new(),
76 error_or_errors: ErrorOrErrors::Errors(Vec::new()),
77 expectation,
78 },
79 None => ResponseAggregate::SingleCommand,
80 }
81 }
82}
83
/// A request that has been written to the connection and is awaiting its reply(ies).
struct InFlight {
    /// Where to deliver the result; `None` when the caller asked not to wait for a reply.
    output: Option<PipelineOutput>,
    /// How the replies for this request are gathered.
    response_aggregate: ResponseAggregate,
}
88
/// A request sent from a `Pipeline` handle to the driver task.
struct PipelineMessage {
    /// Already-packed command bytes to write to the connection.
    input: Vec<u8>,
    /// Reply channel; `None` for fire-and-forget requests.
    output: Option<PipelineOutput>,
    /// `None` for a single command; `Some` for a pipeline/transaction.
    expectation: Option<PipelineResponseExpectation>,
}
99
/// Cloneable handle used to queue requests for the driver task.
///
/// All clones share one bounded channel; the driver created alongside it in
/// `Pipeline::new` consumes the messages.
#[derive(Clone)]
struct Pipeline {
    sender: mpsc::Sender<PipelineMessage>,
}
108
109impl Debug for Pipeline {
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 f.debug_tuple("Pipeline").field(&self.sender).finish()
112 }
113}
114
// Sink half of the connection driver (build with client-side caching).
// Writes queued requests to `sink_stream`, reads replies from it, and hands each reply to
// the oldest entry of `in_flight`.
#[cfg(feature = "cache-aio")]
pin_project! {
    struct PipelineSink<T> {
        // Framed connection: sink for packed commands, stream of decoded replies.
        #[pin]
        sink_stream: T,
        // Requests already written, oldest first; replies are matched in FIFO order.
        in_flight: VecDeque<InFlight>,
        // Error reported by `poll_ready`, surfaced on the next `start_send`.
        error: Option<RedisError>,
        // Receives out-of-band push messages and the disconnect notification.
        push_sender: Option<Arc<dyn AsyncPushSender>>,
        // Notified of push values that are not replies (see `send_result`).
        cache_manager: Option<CacheManager>,
    }
}
126
// Sink half of the connection driver (build without client-side caching).
// Writes queued requests to `sink_stream`, reads replies from it, and hands each reply to
// the oldest entry of `in_flight`.
#[cfg(not(feature = "cache-aio"))]
pin_project! {
    struct PipelineSink<T> {
        // Framed connection: sink for packed commands, stream of decoded replies.
        #[pin]
        sink_stream: T,
        // Requests already written, oldest first; replies are matched in FIFO order.
        in_flight: VecDeque<InFlight>,
        // Error reported by `poll_ready`, surfaced on the next `start_send`.
        error: Option<RedisError>,
        // Receives out-of-band push messages and the disconnect notification.
        push_sender: Option<Arc<dyn AsyncPushSender>>,
    }
}
137
138fn send_push(push_sender: &Option<Arc<dyn AsyncPushSender>>, info: PushInfo) {
139 if let Some(sender) = push_sender {
140 let _ = sender.send(info);
141 };
142}
143
144pub(crate) fn send_disconnect(push_sender: &Option<Arc<dyn AsyncPushSender>>) {
145 send_push(push_sender, PushInfo::disconnect());
146}
147
148impl<T> PipelineSink<T>
149where
150 T: Stream<Item = RedisResult<Value>> + 'static,
151{
152 fn new(
153 sink_stream: T,
154 push_sender: Option<Arc<dyn AsyncPushSender>>,
155 #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
156 ) -> Self
157 where
158 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
159 {
160 PipelineSink {
161 sink_stream,
162 in_flight: VecDeque::new(),
163 error: None,
164 push_sender,
165 #[cfg(feature = "cache-aio")]
166 cache_manager,
167 }
168 }
169
170 fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
172 loop {
173 let item = ready!(self.as_mut().project().sink_stream.poll_next(cx));
174 let item = match item {
175 Some(result) => result,
176 None => Err(closed_connection_error()),
178 };
179
180 let is_unrecoverable = item.as_ref().is_err_and(|err| err.is_unrecoverable_error());
181 self.as_mut().send_result(item);
182 if is_unrecoverable {
183 let self_ = self.project();
184 send_disconnect(self_.push_sender);
185 return Poll::Ready(Err(()));
186 }
187 }
188 }
189
190 fn send_result(self: Pin<&mut Self>, result: RedisResult<Value>) {
191 let self_ = self.project();
192 let result = match result {
193 Ok(Value::Push { kind, data }) if !kind.has_reply() => {
195 #[cfg(feature = "cache-aio")]
196 if let Some(cache_manager) = &self_.cache_manager {
197 cache_manager.handle_push_value(&kind, &data);
198 }
199 send_push(self_.push_sender, PushInfo { kind, data });
200
201 return;
202 }
203 Ok(Value::Push { kind, data }) if kind.has_reply() => {
205 send_push(
206 self_.push_sender,
207 PushInfo {
208 kind: kind.clone(),
209 data: data.clone(),
210 },
211 );
212 Ok(Value::Push { kind, data })
213 }
214 _ => result,
215 };
216
217 let mut entry = match self_.in_flight.pop_front() {
218 Some(entry) => entry,
219 None => return,
220 };
221
222 match &mut entry.response_aggregate {
223 ResponseAggregate::SingleCommand => {
224 if let Some(output) = entry.output.take() {
225 _ = output.send(result);
226 }
227 }
228 ResponseAggregate::Pipeline {
229 buffer,
230 error_or_errors,
231 expectation:
232 PipelineResponseExpectation {
233 expected_response_count,
234 skipped_response_count,
235 is_transaction,
236 seen_responses,
237 },
238 } => {
239 *seen_responses += 1;
240 if *skipped_response_count > 0 {
241 if *is_transaction {
244 if let ErrorOrErrors::Errors(errs) = error_or_errors {
245 match result {
246 Ok(Value::ServerError(err)) => {
247 errs.push((*seen_responses - 2, err)); }
249 Err(err) => *error_or_errors = ErrorOrErrors::FirstError(err),
250 _ => {}
251 }
252 }
253 }
254
255 *skipped_response_count -= 1;
256 self_.in_flight.push_front(entry);
257 return;
258 }
259
260 match result {
261 Ok(item) => {
262 buffer.push(item);
263 }
264 Err(err) => {
265 if matches!(error_or_errors, ErrorOrErrors::Errors(_)) {
266 *error_or_errors = ErrorOrErrors::FirstError(err)
267 }
268 }
269 }
270
271 if buffer.len() < *expected_response_count {
272 self_.in_flight.push_front(entry);
274 return;
275 }
276
277 let response =
278 match std::mem::replace(error_or_errors, ErrorOrErrors::Errors(Vec::new())) {
279 ErrorOrErrors::Errors(errors) => {
280 if errors.is_empty() {
281 Ok(Value::Array(std::mem::take(buffer)))
282 } else {
283 Err(RedisError::make_aborted_transaction(errors))
284 }
285 }
286 ErrorOrErrors::FirstError(redis_error) => Err(redis_error),
287 };
288
289 if let Some(output) = entry.output.take() {
293 _ = output.send(response);
294 }
295 }
296 }
297 }
298}
299
/// Driver-side sink: each accepted `PipelineMessage` is written to the connection and
/// recorded in `in_flight`; flushing also reads and routes replies.
///
/// The sink's error type is `()`: callers are informed through their own reply channels,
/// and an `Err(())` only tells the forwarding future to stop.
impl<T> Sink<PipelineMessage> for PipelineSink<T>
where
    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
{
    type Error = ();

    /// Waits for the underlying sink to be ready.
    ///
    /// An error is not returned here. It is stored in `error` and reported to the next
    /// message's caller in `start_send`.
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
            Ok(()) => Ok(()).into(),
            Err(err) => {
                *self.project().error = Some(err);
                Ok(()).into()
            }
        }
    }

    /// Writes one request and, on success, queues it as in flight.
    ///
    /// A message whose reply receiver is already closed is dropped without being written.
    /// A pending error from `poll_ready`, or a write error, is sent to this message's caller
    /// and returned as `Err(())`.
    fn start_send(
        mut self: Pin<&mut Self>,
        PipelineMessage {
            input,
            mut output,
            expectation,
        }: PipelineMessage,
    ) -> Result<(), Self::Error> {
        // Nobody is waiting for this reply, so skip the write entirely.
        if output.as_ref().is_some_and(|output| output.is_closed()) {
            return Ok(());
        }

        let self_ = self.as_mut().project();

        if let Some(err) = self_.error.take() {
            if let Some(output) = output.take() {
                _ = output.send(Err(err));
            }
            return Err(());
        }

        match self_.sink_stream.start_send(input) {
            Ok(()) => {
                let response_aggregate = ResponseAggregate::new(expectation);
                let entry = InFlight {
                    output,
                    response_aggregate,
                };

                self_.in_flight.push_back(entry);
                Ok(())
            }
            Err(err) => {
                if let Some(output) = output.take() {
                    _ = output.send(Err(err));
                }
                Err(())
            }
        }
    }

    /// Flushes pending writes, then keeps reading replies via `poll_read`.
    ///
    /// A flush error is routed to the oldest in-flight request. Since `poll_read` never
    /// returns `Ready(Ok)`, this stays pending while the connection is healthy, which keeps
    /// the read side polled.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        ready!(
            self.as_mut()
                .project()
                .sink_stream
                .poll_flush(cx)
                .map_err(|err| {
                    self.as_mut().send_result(Err(err));
                })
        )?;
        self.poll_read(cx)
    }

    /// Closes the underlying sink.
    ///
    /// While requests are still in flight it keeps flushing and reading first, so their
    /// replies can still be delivered. A close error is routed to the oldest in-flight request.
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        if !self.in_flight.is_empty() {
            ready!(self.as_mut().poll_flush(cx))?;
        }
        let this = self.as_mut().project();
        this.sink_stream.poll_close(cx).map_err(|err| {
            self.send_result(Err(err));
        })
    }
}
395
396impl Pipeline {
397 const DEFAULT_BUFFER_SIZE: usize = 50;
398
399 fn resolve_buffer_size(size: Option<usize>) -> usize {
400 size.unwrap_or(Self::DEFAULT_BUFFER_SIZE)
401 }
402
403 fn new<T>(
404 sink_stream: T,
405 push_sender: Option<Arc<dyn AsyncPushSender>>,
406 #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
407 buffer_size: usize,
408 ) -> (Self, impl Future<Output = ()>)
409 where
410 T: Sink<Vec<u8>, Error = RedisError>,
411 T: Stream<Item = RedisResult<Value>>,
412 T: Unpin + Send + 'static,
413 {
414 let (sender, mut receiver) = mpsc::channel(buffer_size);
415
416 let sink = PipelineSink::new(
417 sink_stream,
418 push_sender,
419 #[cfg(feature = "cache-aio")]
420 cache_manager,
421 );
422 let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
423 .map(Ok)
424 .forward(sink)
425 .map(|_| ());
426 (Pipeline { sender }, f)
427 }
428
429 async fn send_recv(
430 &mut self,
431 input: Vec<u8>,
432 expectation: Option<PipelineResponseExpectation>,
435 timeout: Option<Duration>,
436 skip_response: bool,
437 ) -> Result<Value, RedisError> {
438 if input.is_empty() {
439 return Err(RedisError::make_empty_command());
440 }
441
442 let request = async {
443 if skip_response {
444 self.sender
445 .send(PipelineMessage {
446 input,
447 expectation,
448 output: None,
449 })
450 .await
451 .map_err(|_| None)?;
452
453 return Ok(Value::Nil);
454 }
455
456 let (sender, receiver) = oneshot::channel();
457
458 self.sender
459 .send(PipelineMessage {
460 input,
461 expectation,
462 output: Some(sender),
463 })
464 .await
465 .map_err(|_| None)?;
466
467 receiver.await
468 .map_err(|_| None)
471 .and_then(|res| res.map_err(Some))
472 };
473
474 match timeout {
475 Some(timeout) => match Runtime::locate().timeout(timeout, request).await {
476 Ok(res) => res,
477 Err(elapsed) => Err(Some(elapsed.into())),
478 },
479 None => request.await,
480 }
481 .map_err(|err| err.unwrap_or_else(closed_connection_error))
482 }
483}
484
/// An async connection that can be cloned cheaply; all clones share one underlying
/// connection through a request channel to a single driver task, so requests from
/// different clones can be in flight concurrently.
#[derive(Clone)]
pub struct MultiplexedConnection {
    // Handle used to queue requests for the driver.
    pipeline: Pipeline,
    // Database index from the connection info; returned by `get_db`.
    db: i64,
    // Passed to every `send_recv` call; `None` means wait indefinitely.
    response_timeout: Option<Duration>,
    // Protocol from the connection info; checked by the pub/sub methods.
    protocol: ProtocolVersion,
    // Optional cap on concurrently outstanding requests.
    concurrency_limiter: Option<Arc<async_lock::Semaphore>>,
    // Handle of the driver task, if one was registered via `set_task_handle`.
    _task_handle: Option<SharedHandleContainer>,
    #[cfg(feature = "cache-aio")]
    pub(crate) cache_manager: Option<CacheManager>,
    // Background task consuming the credentials stream and re-authenticating.
    #[cfg(feature = "token-based-authentication")]
    _credentials_subscription_task_handle: Option<SharedHandleContainer>,
    // Set by that task once re-authentication has failed or its stream ended; requests
    // are rejected afterwards.
    #[cfg(feature = "token-based-authentication")]
    re_authentication_failed: Arc<AtomicBool>,
}
519
520impl Debug for MultiplexedConnection {
521 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
522 let MultiplexedConnection {
523 pipeline,
524 db,
525 response_timeout,
526 protocol,
527 concurrency_limiter: _,
528 _task_handle,
529 #[cfg(feature = "cache-aio")]
530 cache_manager: _,
531 #[cfg(feature = "token-based-authentication")]
532 _credentials_subscription_task_handle: _,
533 #[cfg(feature = "token-based-authentication")]
534 re_authentication_failed: _,
535 } = self;
536
537 f.debug_struct("MultiplexedConnection")
538 .field("pipeline", &pipeline)
539 .field("db", &db)
540 .field("response_timeout", &response_timeout)
541 .field("protocol", &protocol)
542 .finish()
543 }
544}
545
546impl MultiplexedConnection {
547 pub async fn new<C>(
550 connection_info: &RedisConnectionInfo,
551 stream: C,
552 ) -> RedisResult<(Self, impl Future<Output = ()>)>
553 where
554 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
555 {
556 Self::new_with_config(connection_info, stream, AsyncConnectionConfig::default()).await
557 }
558
559 pub async fn new_with_config<C>(
562 connection_info: &RedisConnectionInfo,
563 stream: C,
564 config: AsyncConnectionConfig,
565 ) -> RedisResult<(Self, impl Future<Output = ()> + 'static)>
566 where
567 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
568 {
569 let mut codec = ValueCodec::default().framed(stream);
570 if config.push_sender.is_some() {
571 check_resp3!(
572 connection_info.protocol,
573 "Can only pass push sender to a connection using RESP3"
574 );
575 }
576
577 #[cfg(feature = "cache-aio")]
578 let cache_config = config.cache.as_ref().map(|cache| match cache {
579 crate::client::Cache::Config(cache_config) => *cache_config,
580 #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
581 crate::client::Cache::Manager(cache_manager) => cache_manager.cache_config,
582 });
583 #[cfg(feature = "cache-aio")]
584 let cache_manager_opt = config
585 .cache
586 .map(|cache| {
587 check_resp3!(
588 connection_info.protocol,
589 "Can only enable client side caching in a connection using RESP3"
590 );
591 match cache {
592 crate::client::Cache::Config(cache_config) => {
593 Ok(CacheManager::new(cache_config))
594 }
595 #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
596 crate::client::Cache::Manager(cache_manager) => Ok(cache_manager),
597 }
598 })
599 .transpose()?;
600
601 #[cfg(feature = "token-based-authentication")]
602 let mut connection_info = connection_info.clone();
603 #[cfg(not(feature = "token-based-authentication"))]
604 let connection_info = connection_info.clone();
605
606 #[cfg(feature = "token-based-authentication")]
607 if let Some(ref credentials_provider) = config.credentials_provider {
608 match credentials_provider.subscribe().next().await {
610 Some(Ok(credentials)) => {
611 connection_info.username = Some(ArcStr::from(credentials.username));
612 connection_info.password = Some(ArcStr::from(credentials.password));
613 }
614 Some(Err(err)) => {
615 error!("Error while receiving credentials from stream: {err}");
616 return Err(err);
617 }
618 None => {
619 let err = RedisError::from((
620 ErrorKind::AuthenticationFailed,
621 "Credentials stream closed unexpectedly before yielding credentials!",
622 ));
623 error!("{err}");
624 return Err(err);
625 }
626 }
627 }
628
629 setup_connection(
630 &mut codec,
631 &connection_info,
632 #[cfg(feature = "cache-aio")]
633 cache_config,
634 )
635 .await?;
636 if config.push_sender.is_some() {
637 check_resp3!(
638 connection_info.protocol,
639 "Can only pass push sender to a connection using RESP3"
640 );
641 }
642
643 let (pipeline, driver) = Pipeline::new(
644 codec,
645 config.push_sender,
646 #[cfg(feature = "cache-aio")]
647 cache_manager_opt.clone(),
648 Pipeline::resolve_buffer_size(config.pipeline_buffer_size),
649 );
650
651 let concurrency_limiter = config
652 .concurrency_limit
653 .map(|n| Arc::new(async_lock::Semaphore::new(n)));
654
655 let con = MultiplexedConnection {
656 pipeline,
657 db: connection_info.db,
658 response_timeout: config.response_timeout,
659 protocol: connection_info.protocol,
660 concurrency_limiter,
661 _task_handle: None,
662 #[cfg(feature = "cache-aio")]
663 cache_manager: cache_manager_opt,
664 #[cfg(feature = "token-based-authentication")]
665 _credentials_subscription_task_handle: None,
666 #[cfg(feature = "token-based-authentication")]
667 re_authentication_failed: Arc::new(AtomicBool::new(false)),
668 };
669
670 #[cfg(feature = "token-based-authentication")]
672 if let Some(streaming_provider) = config.credentials_provider {
673 let mut inner_connection = con.clone();
674 let re_authentication_failed_arc = Arc::clone(&con.re_authentication_failed);
675 let mut stream = streaming_provider.subscribe();
676
677 let subscription_task_handle = Runtime::locate().spawn(async move {
678 let mut error_cause_logged = false;
679 while let Some(result) = stream.next().await {
680 match result {
681 Ok(credentials) => {
682 if let Err(err) = inner_connection
683 .re_authenticate_with_credentials(&credentials)
684 .await
685 {
686 error!("Failed to re-authenticate async connection: {err}.");
687 error_cause_logged = true;
688 re_authentication_failed_arc.store(true, Ordering::Relaxed);
689 break;
690 } else {
691 debug!("Re-authenticated async connection");
692 }
693 }
694 Err(err) => {
695 error!("Credentials stream error for async connection: {err}.");
696 error_cause_logged = true;
697 }
698 }
699 }
700 if !re_authentication_failed_arc.load(Ordering::Relaxed) {
701 if !error_cause_logged {
702 error!("Re-authentication stream ended unexpectedly.");
703 }
704 re_authentication_failed_arc.store(true, Ordering::Relaxed);
705 }
706 });
707 return Ok((
708 Self {
709 _credentials_subscription_task_handle: Some(SharedHandleContainer::new(
710 subscription_task_handle,
711 )),
712 ..con
713 },
714 driver,
715 ));
716 }
717
718 Ok((con, driver))
719 }
720
721 pub(crate) fn set_task_handle(&mut self, handle: TaskHandle) {
724 self._task_handle = Some(SharedHandleContainer::new(handle));
725 }
726
727 pub fn set_response_timeout(&mut self, timeout: std::time::Duration) {
729 self.response_timeout = Some(timeout);
730 }
731
732 pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
735 let _permit = if cmd.skip_concurrency_limit {
736 None
737 } else if let Some(limiter) = &self.concurrency_limiter {
738 Some(limiter.acquire().await)
739 } else {
740 None
741 };
742 #[cfg(feature = "token-based-authentication")]
743 if self.re_authentication_failed.load(Ordering::Relaxed) {
744 return Err(RedisError::from((
745 ErrorKind::AuthenticationFailed,
746 "Connection is no longer usable due to re-authentication failure",
747 )));
748 }
749 #[cfg(feature = "cache-aio")]
750 if let Some(cache_manager) = &self.cache_manager {
751 match cache_manager.get_cached_cmd(cmd) {
752 PrepareCacheResult::Cached(value) => return Ok(value),
753 PrepareCacheResult::NotCached(cacheable_command) => {
754 let mut pipeline = crate::Pipeline::new();
755 cacheable_command.pack_command(cache_manager, &mut pipeline);
756
757 let result = self
758 .pipeline
759 .send_recv(
760 pipeline.get_packed_pipeline(),
761 Some(PipelineResponseExpectation {
762 skipped_response_count: 0,
763 expected_response_count: pipeline.commands.len(),
764 is_transaction: false,
765 seen_responses: 0,
766 }),
767 self.response_timeout,
768 cmd.is_no_response(),
769 )
770 .await?;
771 let replies: Vec<Value> = crate::types::from_redis_value(result)?;
772 return cacheable_command.resolve(cache_manager, replies.into_iter());
773 }
774 _ => (),
775 }
776 }
777 self.pipeline
778 .send_recv(
779 cmd.get_packed_command(),
780 None,
781 self.response_timeout,
782 cmd.is_no_response(),
783 )
784 .await
785 }
786
787 pub async fn send_packed_commands(
791 &mut self,
792 cmd: &crate::Pipeline,
793 offset: usize,
794 count: usize,
795 ) -> RedisResult<Vec<Value>> {
796 let _permits = if let Some(limiter) = &self.concurrency_limiter {
801 let mut permits = Vec::with_capacity(count.max(1));
802 permits.push(limiter.acquire().await);
803 for _ in 1..count {
804 match limiter.try_acquire() {
805 Some(permit) => permits.push(permit),
806 None => break,
807 }
808 }
809 permits
810 } else {
811 Vec::new()
812 };
813 #[cfg(feature = "token-based-authentication")]
814 if self.re_authentication_failed.load(Ordering::Relaxed) {
815 return Err(RedisError::from((
816 ErrorKind::AuthenticationFailed,
817 "Connection is no longer usable due to re-authentication failure",
818 )));
819 }
820 #[cfg(feature = "cache-aio")]
821 if let Some(cache_manager) = &self.cache_manager {
822 let (cacheable_pipeline, pipeline, (skipped_response_count, expected_response_count)) =
823 cache_manager.get_cached_pipeline(cmd);
824 if pipeline.is_empty() {
825 return cacheable_pipeline.resolve(cache_manager, Value::Array(Vec::new()));
826 }
827 let result = self
828 .pipeline
829 .send_recv(
830 pipeline.get_packed_pipeline(),
831 Some(PipelineResponseExpectation {
832 skipped_response_count,
833 expected_response_count,
834 is_transaction: cacheable_pipeline.transaction_mode,
835 seen_responses: 0,
836 }),
837 self.response_timeout,
838 false,
839 )
840 .await?;
841
842 return cacheable_pipeline.resolve(cache_manager, result);
843 }
844 let value = self
845 .pipeline
846 .send_recv(
847 cmd.get_packed_pipeline(),
848 Some(PipelineResponseExpectation {
849 skipped_response_count: offset,
850 expected_response_count: count,
851 is_transaction: cmd.is_transaction(),
852 seen_responses: 0,
853 }),
854 self.response_timeout,
855 false,
856 )
857 .await?;
858 match value {
859 Value::Array(values) => Ok(values),
860 _ => Ok(vec![value]),
861 }
862 }
863
864 #[cfg(feature = "cache-aio")]
866 #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
867 pub fn get_cache_statistics(&self) -> Option<CacheStatistics> {
868 self.cache_manager.as_ref().map(|cm| cm.statistics())
869 }
870}
871
872impl ConnectionLike for MultiplexedConnection {
873 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
874 (async move { self.send_packed_command(cmd).await }).boxed()
875 }
876
877 fn req_packed_commands<'a>(
878 &'a mut self,
879 cmd: &'a crate::Pipeline,
880 offset: usize,
881 count: usize,
882 ) -> RedisFuture<'a, Vec<Value>> {
883 (async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
884 }
885
886 fn get_db(&self) -> i64 {
887 self.db
888 }
889}
890
891impl MultiplexedConnection {
892 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
909 check_resp3!(self.protocol);
910 let mut cmd = cmd("SUBSCRIBE");
911 cmd.arg(channel_name);
912 cmd.exec_async(self).await?;
913 Ok(())
914 }
915
916 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
931 check_resp3!(self.protocol);
932 let mut cmd = cmd("UNSUBSCRIBE");
933 cmd.arg(channel_name);
934 cmd.exec_async(self).await?;
935 Ok(())
936 }
937
938 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
957 check_resp3!(self.protocol);
958 let mut cmd = cmd("PSUBSCRIBE");
959 cmd.arg(channel_pattern);
960 cmd.exec_async(self).await?;
961 Ok(())
962 }
963
964 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
968 check_resp3!(self.protocol);
969 let mut cmd = cmd("PUNSUBSCRIBE");
970 cmd.arg(channel_pattern);
971 cmd.exec_async(self).await?;
972 Ok(())
973 }
974}
975
#[cfg(feature = "token-based-authentication")]
impl MultiplexedConnection {
    /// Re-authenticates this connection with `credentials`.
    ///
    /// The AUTH command bypasses the concurrency limiter, so it is never queued behind user
    /// requests. Both transport errors and a server error reply are returned as `Err`.
    async fn re_authenticate_with_credentials(
        &mut self,
        credentials: &crate::auth::BasicAuth,
    ) -> RedisResult<()> {
        let mut auth_cmd =
            crate::connection::authenticate_cmd(Some(&credentials.username), &credentials.password);
        auth_cmd.skip_concurrency_limit = true;
        let reply = self.send_packed_command(&auth_cmd).await?;
        reply.extract_error()?;
        Ok(())
    }
}
995
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pipeline_resolve_buffer_size_default() {
        assert_eq!(Pipeline::resolve_buffer_size(None), 50);
    }

    #[test]
    fn test_pipeline_resolve_buffer_size_custom() {
        assert_eq!(Pipeline::resolve_buffer_size(Some(100)), 100);
    }

    /// Connection info for the mock server. `skip_set_lib_name` is presumably set so that
    /// the handshake sends no commands the mock would have to answer.
    fn mock_conn_info() -> RedisConnectionInfo {
        RedisConnectionInfo {
            skip_set_lib_name: true,
            ..Default::default()
        }
    }

    /// Builds a connection to an in-memory mock server with the given concurrency limit.
    ///
    /// Returns the connection, a receiver that gets one `()` per command the server
    /// decodes, and a sender where each `()` makes the server write one `+OK` reply.
    async fn create_mock_connection(
        concurrency_limit: usize,
    ) -> (
        MultiplexedConnection,
        tokio::sync::mpsc::Receiver<()>,
        tokio::sync::mpsc::Sender<()>,
    ) {
        use futures_util::StreamExt;
        use tokio::io::AsyncWriteExt;
        use tokio_util::codec::FramedRead;

        let (client_half, server_half) = tokio::io::duplex(4096);
        let (cmd_received_tx, cmd_received_rx) = tokio::sync::mpsc::channel::<()>(10);
        let (send_response_tx, mut send_response_rx) = tokio::sync::mpsc::channel::<()>(10);

        let (server_read, mut server_write) = tokio::io::split(server_half);

        // Server read side: signal every decoded command.
        tokio::spawn(async move {
            let mut reader = FramedRead::new(server_read, ValueCodec::default());
            while let Some(Ok(_)) = reader.next().await {
                let _ = cmd_received_tx.send(()).await;
            }
        });

        // Server write side: reply only when the test asks for it.
        tokio::spawn(async move {
            while send_response_rx.recv().await.is_some() {
                let _ = server_write.write_all(b"+OK\r\n").await;
                let _ = server_write.flush().await;
            }
        });

        let config = AsyncConnectionConfig::new()
            .set_concurrency_limit(concurrency_limit)
            .set_response_timeout(None)
            .set_connection_timeout(None);

        let (conn, driver) =
            MultiplexedConnection::new_with_config(&mock_conn_info(), client_half, config)
                .await
                .unwrap();
        tokio::spawn(driver);

        (conn, cmd_received_rx, send_response_tx)
    }

    // With a limit of 2, a third concurrent request is not written until a reply frees a permit.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrency_limit_enforced() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(2).await;

        let h1 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        let h2 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        let h3 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });

        cmd_received_rx.recv().await.unwrap();
        cmd_received_rx.recv().await.unwrap();

        let third = tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            third.is_err(),
            "3rd request should be blocked by concurrency limit"
        );

        send_response_tx.send(()).await.unwrap();

        cmd_received_rx.recv().await.unwrap();

        send_response_tx.send(()).await.unwrap();
        send_response_tx.send(()).await.unwrap();

        h1.await.unwrap().unwrap();
        h2.await.unwrap().unwrap();
        h3.await.unwrap().unwrap();
    }

    // A command with `skip_concurrency_limit` is written even while the only permit is held.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_no_limit_bypasses_concurrency_limit() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(1).await;

        let h1 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });

        cmd_received_rx.recv().await.unwrap();

        let h2 = tokio::spawn({
            let mut c = conn.clone();
            async move {
                let mut ping = cmd("PING");
                ping.skip_concurrency_limit = true;
                c.send_packed_command(&ping).await
            }
        });

        let received =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            received.is_ok(),
            "no_limit request should bypass concurrency limit"
        );

        send_response_tx.send(()).await.unwrap();
        send_response_tx.send(()).await.unwrap();

        h1.await.unwrap().unwrap();
        h2.await.unwrap().unwrap();
    }

    // A 3-command pipeline takes all 3 permits, so a single command waits until it finishes.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_pipeline_acquires_multiple_permits() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(3).await;

        let pipeline_handle = tokio::spawn({
            let mut c = conn.clone();
            async move {
                let mut pipe = crate::Pipeline::new();
                pipe.cmd("SET").arg("a").arg("1");
                pipe.cmd("SET").arg("b").arg("2");
                pipe.cmd("SET").arg("c").arg("3");
                c.send_packed_commands(&pipe, 0, 3).await
            }
        });

        for _ in 0..3 {
            cmd_received_rx.recv().await.unwrap();
        }

        let single_handle = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });

        let blocked =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            blocked.is_err(),
            "single command should be blocked while pipeline holds all permits"
        );

        for _ in 0..3 {
            send_response_tx.send(()).await.unwrap();
        }

        cmd_received_rx.recv().await.unwrap();
        send_response_tx.send(()).await.unwrap();

        pipeline_handle.await.unwrap().unwrap();
        single_handle.await.unwrap().unwrap();
    }

    // A 5-command pipeline proceeds with whatever permits are free (here 1 of 2).
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_pipeline_proceeds_with_partial_permits() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(2).await;

        let single_handle = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        cmd_received_rx.recv().await.unwrap();

        let pipeline_handle = tokio::spawn({
            let mut c = conn.clone();
            async move {
                let mut pipe = crate::Pipeline::new();
                for i in 0..5 {
                    pipe.cmd("SET").arg(format!("k{i}")).arg(i);
                }
                c.send_packed_commands(&pipe, 0, 5).await
            }
        });

        let received =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            received.is_ok(),
            "pipeline should proceed even with only partial permits"
        );

        for _ in 1..5 {
            cmd_received_rx.recv().await.unwrap();
        }

        for _ in 0..6 {
            send_response_tx.send(()).await.unwrap();
        }

        single_handle.await.unwrap().unwrap();
        pipeline_handle.await.unwrap().unwrap();
    }

    // Aborting a task that is waiting for a permit must not leak the permit.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_permit_released_on_cancellation() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(1).await;

        let h1 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        cmd_received_rx.recv().await.unwrap();

        let h2 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        tokio::time::sleep(Duration::from_millis(50)).await;
        h2.abort();
        let _ = h2.await;

        send_response_tx.send(()).await.unwrap();
        h1.await.unwrap().unwrap();

        let h3 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });

        let received =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            received.is_ok(),
            "request after cancellation should acquire the permit"
        );

        send_response_tx.send(()).await.unwrap();
        h3.await.unwrap().unwrap();
    }

    // With a server that never replies, a timed-out request must release its permit so the
    // next request can be written.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_permit_released_on_response_timeout() {
        use futures_util::StreamExt;
        use tokio::io::AsyncWriteExt;
        use tokio_util::codec::FramedRead;

        let (client_half, server_half) = tokio::io::duplex(4096);
        let (cmd_received_tx, mut cmd_received_rx) = tokio::sync::mpsc::channel::<()>(10);

        let (server_read, mut server_write) = tokio::io::split(server_half);

        tokio::spawn(async move {
            let mut reader = FramedRead::new(server_read, ValueCodec::default());
            while let Some(Ok(_)) = reader.next().await {
                let _ = cmd_received_tx.send(()).await;
            }
        });

        // Keeps the write half alive (so the connection stays open) without ever replying.
        tokio::spawn(async move {
            futures_util::future::pending::<()>().await;
            let _ = server_write.write_all(b"").await;
        });

        let config = AsyncConnectionConfig::new()
            .set_concurrency_limit(1)
            .set_response_timeout(Some(Duration::from_millis(100)))
            .set_connection_timeout(None);

        let (conn, driver) =
            MultiplexedConnection::new_with_config(&mock_conn_info(), client_half, config)
                .await
                .unwrap();
        tokio::spawn(driver);

        let mut c1 = conn.clone();
        let err = c1.send_packed_command(&cmd("PING")).await.unwrap_err();
        assert!(err.is_io_error(), "expected IO error from timeout");
        cmd_received_rx.recv().await.unwrap();

        let mut c2 = conn.clone();
        let err = c2.send_packed_command(&cmd("PING")).await.unwrap_err();
        assert!(err.is_io_error(), "expected IO error from timeout");
        cmd_received_rx.recv().await.unwrap();
    }
}