// redis/aio/multiplexed_connection.rs
1use super::{AsyncPushSender, ConnectionLike, Runtime, SharedHandleContainer, TaskHandle};
2#[cfg(feature = "cache-aio")]
3use crate::caching::{CacheManager, CacheStatistics, PrepareCacheResult};
4use crate::{
5    AsyncConnectionConfig, ProtocolVersion, PushInfo, RedisConnectionInfo, ServerError,
6    ToRedisArgs,
7    aio::setup_connection,
8    check_resp3, cmd,
9    cmd::Cmd,
10    errors::{RedisError, closed_connection_error},
11    parser::ValueCodec,
12    types::{RedisFuture, RedisResult, Value},
13};
14use ::tokio::{
15    io::{AsyncRead, AsyncWrite},
16    sync::{mpsc, oneshot},
17};
18#[cfg(feature = "token-based-authentication")]
19use {
20    crate::errors::ErrorKind,
21    arcstr::ArcStr,
22    log::{debug, error},
23    std::sync::atomic::{AtomicBool, Ordering},
24};
25
26use futures_util::{
27    future::{Future, FutureExt},
28    ready,
29    sink::Sink,
30    stream::{self, Stream, StreamExt},
31};
32use pin_project_lite::pin_project;
33use std::collections::VecDeque;
34use std::fmt;
35use std::fmt::Debug;
36use std::pin::Pin;
37use std::sync::Arc;
38use std::task::{self, Poll};
39use std::time::Duration;
40use tokio_util::codec::Decoder;
41
// Sender through which the result of a single request is delivered back to the caller.
type PipelineOutput = oneshot::Sender<RedisResult<Value>>;
44
// Failure state accumulated while collecting a pipeline's responses.
enum ErrorOrErrors {
    // Server-side errors, each paired with the index of the command that caused it.
    Errors(Vec<(usize, ServerError)>),
    // only set if we receive a transmission error
    FirstError(RedisError),
}
50
// Aggregation state for the response(s) to one in-flight request.
enum ResponseAggregate {
    // A single command: the first response is forwarded to the caller as-is.
    SingleCommand,
    // A pipeline (possibly a transaction): responses are gathered into `buffer`
    // until `expectation` is satisfied, then returned as one `Value::Array`.
    Pipeline {
        buffer: Vec<Value>,
        error_or_errors: ErrorOrErrors,
        expectation: PipelineResponseExpectation,
    },
}
59
// TODO - this is a really bad name.
// Describes how many responses a pipelined request expects and how they map back
// to the caller's commands.
struct PipelineResponseExpectation {
    // The number of responses to skip before starting to save responses in the buffer.
    skipped_response_count: usize,
    // The number of responses to keep in the buffer
    expected_response_count: usize,
    // whether the pipelined request is a transaction
    is_transaction: bool,
    // Running count of responses received so far; used to compute the index of a
    // failing command inside a transaction.
    seen_responses: usize,
}
70
71impl ResponseAggregate {
72    fn new(expectation: Option<PipelineResponseExpectation>) -> Self {
73        match expectation {
74            Some(expectation) => ResponseAggregate::Pipeline {
75                buffer: Vec::new(),
76                error_or_errors: ErrorOrErrors::Errors(Vec::new()),
77                expectation,
78            },
79            None => ResponseAggregate::SingleCommand,
80        }
81    }
82}
83
// A request that has been written to the server and is awaiting its response(s).
struct InFlight {
    // `None` for fire-and-forget requests whose caller did not ask for an answer.
    output: Option<PipelineOutput>,
    response_aggregate: ResponseAggregate,
}
88
// A single message sent through the pipeline
struct PipelineMessage {
    // The already-encoded request bytes to write to the server.
    input: Vec<u8>,
    // If `output` is None, then the caller doesn't expect to receive an answer.
    output: Option<PipelineOutput>,
    // If `None`, this is a single request, not a pipeline of multiple requests.
    // If `Some`, the contained expectation describes how many responses to skip,
    // how many to keep, and whether the pipeline is a transaction.
    expectation: Option<PipelineResponseExpectation>,
}
99
/// Wrapper around a `Stream + Sink` where each item sent through the `Sink` results in one or more
/// items being output by the `Stream` (the number is specified at time of sending). `Pipeline`
/// provides an easy request-to-response interface, hiding the underlying `Stream`
/// and `Sink`.
#[derive(Clone)]
struct Pipeline {
    // Channel to the background driver task that owns the actual connection.
    sender: mpsc::Sender<PipelineMessage>,
}
108
109impl Debug for Pipeline {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        f.debug_tuple("Pipeline").field(&self.sender).finish()
112    }
113}
114
#[cfg(feature = "cache-aio")]
pin_project! {
    // Drives the underlying sink/stream: writes outgoing requests and routes
    // incoming responses to the matching in-flight request in FIFO order.
    struct PipelineSink<T> {
        #[pin]
        sink_stream: T,
        // Requests written but not yet fully answered, in send order.
        in_flight: VecDeque<InFlight>,
        // A write error observed in `poll_ready`, deferred until the next `start_send`.
        error: Option<RedisError>,
        // Destination for out-of-band push messages.
        push_sender: Option<Arc<dyn AsyncPushSender>>,
        // Client-side cache bookkeeping, fed from push invalidation messages.
        cache_manager: Option<CacheManager>,
    }
}
126
#[cfg(not(feature = "cache-aio"))]
pin_project! {
    // Drives the underlying sink/stream: writes outgoing requests and routes
    // incoming responses to the matching in-flight request in FIFO order.
    struct PipelineSink<T> {
        #[pin]
        sink_stream: T,
        // Requests written but not yet fully answered, in send order.
        in_flight: VecDeque<InFlight>,
        // A write error observed in `poll_ready`, deferred until the next `start_send`.
        error: Option<RedisError>,
        // Destination for out-of-band push messages.
        push_sender: Option<Arc<dyn AsyncPushSender>>,
    }
}
137
138fn send_push(push_sender: &Option<Arc<dyn AsyncPushSender>>, info: PushInfo) {
139    if let Some(sender) = push_sender {
140        let _ = sender.send(info);
141    };
142}
143
// Notifies the push sender (if any) that the connection has disconnected.
pub(crate) fn send_disconnect(push_sender: &Option<Arc<dyn AsyncPushSender>>) {
    send_push(push_sender, PushInfo::disconnect());
}
147
impl<T> PipelineSink<T>
where
    T: Stream<Item = RedisResult<Value>> + 'static,
{
    /// Creates a sink with no in-flight requests and no pending error.
    fn new(
        sink_stream: T,
        push_sender: Option<Arc<dyn AsyncPushSender>>,
        #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
    ) -> Self
    where
        T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
    {
        PipelineSink {
            sink_stream,
            in_flight: VecDeque::new(),
            error: None,
            push_sender,
            #[cfg(feature = "cache-aio")]
            cache_manager,
        }
    }

    // Read messages from the stream and send them back to the caller.
    // Returns `Ready(Err(()))` once the stream ends or yields an unrecoverable
    // error, which stops the surrounding driver.
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
        loop {
            let item = ready!(self.as_mut().project().sink_stream.poll_next(cx));
            let item = match item {
                Some(result) => result,
                // The redis response stream is not going to produce any more items so we simulate a disconnection error to break out of the loop.
                None => Err(closed_connection_error()),
            };

            // Check before `item` is moved into `send_result`.
            let is_unrecoverable = item.as_ref().is_err_and(|err| err.is_unrecoverable_error());
            self.as_mut().send_result(item);
            if is_unrecoverable {
                let self_ = self.project();
                send_disconnect(self_.push_sender);
                return Poll::Ready(Err(()));
            }
        }
    }

    /// Routes one value received from the server: out-of-band push messages go to
    /// the push sender; everything else is accumulated into (or completes) the
    /// oldest in-flight request.
    fn send_result(self: Pin<&mut Self>, result: RedisResult<Value>) {
        let self_ = self.project();
        let result = match result {
            // If this push message isn't a reply, we'll pass it as-is to the push manager and stop iterating
            Ok(Value::Push { kind, data }) if !kind.has_reply() => {
                #[cfg(feature = "cache-aio")]
                if let Some(cache_manager) = &self_.cache_manager {
                    cache_manager.handle_push_value(&kind, &data);
                }
                send_push(self_.push_sender, PushInfo { kind, data });

                return;
            }
            // If this push message is a reply to a query, we'll clone it to the push manager and continue with sending the reply
            Ok(Value::Push { kind, data }) if kind.has_reply() => {
                send_push(
                    self_.push_sender,
                    PushInfo {
                        kind: kind.clone(),
                        data: data.clone(),
                    },
                );
                Ok(Value::Push { kind, data })
            }
            _ => result,
        };

        // Responses are matched to requests in FIFO order; with nothing in flight
        // there is nowhere to deliver the value.
        let mut entry = match self_.in_flight.pop_front() {
            Some(entry) => entry,
            None => return,
        };

        match &mut entry.response_aggregate {
            ResponseAggregate::SingleCommand => {
                // A send error means the receiver was dropped; the value is simply discarded.
                if let Some(output) = entry.output.take() {
                    _ = output.send(result);
                }
            }
            ResponseAggregate::Pipeline {
                buffer,
                error_or_errors,
                expectation:
                    PipelineResponseExpectation {
                        expected_response_count,
                        skipped_response_count,
                        is_transaction,
                        seen_responses,
                    },
            } => {
                *seen_responses += 1;
                if *skipped_response_count > 0 {
                    // server errors in skipped values are still counted for errors in transactions, since they're errors that will cause the transaction to fail,
                    // and we only skip values in transaction.
                    if *is_transaction {
                        // Once a transmission error was recorded (`FirstError`), further
                        // server errors are no longer collected.
                        if let ErrorOrErrors::Errors(errs) = error_or_errors {
                            match result {
                                Ok(Value::ServerError(err)) => {
                                    errs.push((*seen_responses - 2, err)); // - 1 to offset the early increment, and -1 to offset the added MULTI call.
                                }
                                Err(err) => *error_or_errors = ErrorOrErrors::FirstError(err),
                                _ => {}
                            }
                        }
                    }

                    // Still skipping: requeue the entry at the front so subsequent
                    // responses keep accumulating into this same request.
                    *skipped_response_count -= 1;
                    self_.in_flight.push_front(entry);
                    return;
                }

                match result {
                    Ok(item) => {
                        buffer.push(item);
                    }
                    Err(err) => {
                        // Only the first transmission error is kept.
                        if matches!(error_or_errors, ErrorOrErrors::Errors(_)) {
                            *error_or_errors = ErrorOrErrors::FirstError(err)
                        }
                    }
                }

                if buffer.len() < *expected_response_count {
                    // Need to gather more response values
                    self_.in_flight.push_front(entry);
                    return;
                }

                // All expected responses arrived: build the final result, resetting
                // the error state in the process.
                let response =
                    match std::mem::replace(error_or_errors, ErrorOrErrors::Errors(Vec::new())) {
                        ErrorOrErrors::Errors(errors) => {
                            if errors.is_empty() {
                                Ok(Value::Array(std::mem::take(buffer)))
                            } else {
                                Err(RedisError::make_aborted_transaction(errors))
                            }
                        }
                        ErrorOrErrors::FirstError(redis_error) => Err(redis_error),
                    };

                // `Err` means that the receiver was dropped in which case it does not
                // care about the output and we can continue by just dropping the value
                // and sender
                if let Some(output) = entry.output.take() {
                    _ = output.send(response);
                }
            }
        }
    }
}
299
impl<T> Sink<PipelineMessage> for PipelineSink<T>
where
    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
{
    // Errors carry no payload here: the caller-visible `RedisError` is delivered
    // through the request's oneshot channel instead.
    type Error = ();

    // Retrieve incoming messages and write them to the sink
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
            Ok(()) => Ok(()).into(),
            Err(err) => {
                // Report readiness anyway and stash the error so the next
                // `start_send` can route it to that request's output channel.
                *self.project().error = Some(err);
                Ok(()).into()
            }
        }
    }

    fn start_send(
        mut self: Pin<&mut Self>,
        PipelineMessage {
            input,
            mut output,
            expectation,
        }: PipelineMessage,
    ) -> Result<(), Self::Error> {
        // If initially a receiver was created, but then dropped, there is nothing to receive our output we do not need to send the message as it is
        // ambiguous whether the message will be sent anyway. Helps shed some load on the
        // connection.
        if output.as_ref().is_some_and(|output| output.is_closed()) {
            return Ok(());
        }

        let self_ = self.as_mut().project();

        // Deliver an error deferred from `poll_ready`, if any.
        if let Some(err) = self_.error.take() {
            if let Some(output) = output.take() {
                _ = output.send(Err(err));
            }
            return Err(());
        }

        match self_.sink_stream.start_send(input) {
            Ok(()) => {
                // Handed to the underlying sink: record the request so incoming
                // responses can be matched to it in FIFO order.
                let response_aggregate = ResponseAggregate::new(expectation);
                let entry = InFlight {
                    output,
                    response_aggregate,
                };

                self_.in_flight.push_back(entry);
                Ok(())
            }
            Err(err) => {
                if let Some(output) = output.take() {
                    _ = output.send(Err(err));
                }
                Err(())
            }
        }
    }

    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        // A flush failure is reported to the oldest in-flight request via `send_result`.
        ready!(
            self.as_mut()
                .project()
                .sink_stream
                .poll_flush(cx)
                .map_err(|err| {
                    self.as_mut().send_result(Err(err));
                })
        )?;
        // Flushing also drives reads, so responses are processed while the driver
        // waits for more outgoing messages.
        self.poll_read(cx)
    }

    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        // No new requests will come in after the first call to `close` but we need to complete any
        // in progress requests before closing
        if !self.in_flight.is_empty() {
            ready!(self.as_mut().poll_flush(cx))?;
        }
        let this = self.as_mut().project();
        this.sink_stream.poll_close(cx).map_err(|err| {
            self.send_result(Err(err));
        })
    }
}
395
impl Pipeline {
    // Default capacity of the request channel between connection handles and the driver.
    const DEFAULT_BUFFER_SIZE: usize = 50;

    /// Returns the requested buffer size, or the default when none was configured.
    fn resolve_buffer_size(size: Option<usize>) -> usize {
        size.unwrap_or(Self::DEFAULT_BUFFER_SIZE)
    }

    /// Creates a `Pipeline` handle plus the driver future that shuttles requests
    /// from the channel into the sink and responses back to callers. The driver
    /// must be polled (typically on a spawned task) for the connection to work.
    fn new<T>(
        sink_stream: T,
        push_sender: Option<Arc<dyn AsyncPushSender>>,
        #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
        buffer_size: usize,
    ) -> (Self, impl Future<Output = ()>)
    where
        T: Sink<Vec<u8>, Error = RedisError>,
        T: Stream<Item = RedisResult<Value>>,
        T: Unpin + Send + 'static,
    {
        let (sender, mut receiver) = mpsc::channel(buffer_size);

        let sink = PipelineSink::new(
            sink_stream,
            push_sender,
            #[cfg(feature = "cache-aio")]
            cache_manager,
        );
        // Forward every queued `PipelineMessage` into the sink; `map(Ok)` adapts the
        // channel stream to the fallible item type `forward` expects.
        let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
            .map(Ok)
            .forward(sink)
            .map(|_| ());
        (Pipeline { sender }, f)
    }

    /// Sends one encoded request (single command or packed pipeline) and awaits
    /// its response.
    ///
    /// Internally errors are carried as `Option<RedisError>`, where `None` means
    /// "the driver/channel went away" and is converted into a closed-connection
    /// error at the end.
    async fn send_recv(
        &mut self,
        input: Vec<u8>,
        // If `None`, this is a single request, not a pipeline of multiple requests.
        // If `Some`, the value inside defines how the response should look like
        expectation: Option<PipelineResponseExpectation>,
        timeout: Option<Duration>,
        skip_response: bool,
    ) -> Result<Value, RedisError> {
        if input.is_empty() {
            return Err(RedisError::make_empty_command());
        }

        let request = async {
            // Fire-and-forget: enqueue without a response channel and report `Nil`.
            if skip_response {
                self.sender
                    .send(PipelineMessage {
                        input,
                        expectation,
                        output: None,
                    })
                    .await
                    .map_err(|_| None)?;

                return Ok(Value::Nil);
            }

            let (sender, receiver) = oneshot::channel();

            self.sender
                .send(PipelineMessage {
                    input,
                    expectation,
                    output: Some(sender),
                })
                .await
                .map_err(|_| None)?;

            receiver.await
            // The `sender` was dropped which likely means that the stream part
            // failed for one reason or another
            .map_err(|_| None)
            .and_then(|res| res.map_err(Some))
        };

        match timeout {
            Some(timeout) => match Runtime::locate().timeout(timeout, request).await {
                Ok(res) => res,
                Err(elapsed) => Err(Some(elapsed.into())),
            },
            None => request.await,
        }
        .map_err(|err| err.unwrap_or_else(closed_connection_error))
    }
}
484
/// A connection object which can be cloned, allowing requests to be sent concurrently
/// on the same underlying connection (tcp/unix socket).
///
/// This connection object is cancellation-safe, and the user can drop request future without polling them to completion,
/// but this doesn't mean that the actual request sent to the server is cancelled.
/// A side-effect of this is that the underlying connection won't be closed until all sent requests have been answered,
/// which means that in case of blocking commands, the underlying connection resource might not be released,
/// even when all clones of the multiplexed connection have been dropped (see <https://github.com/redis-rs/redis-rs/issues/1236>).
/// This isn't an issue in a connection that was created in a canonical way, which ensures that `_task_handle` is set, so that
/// once all of the connection's clones are dropped, the task will also be dropped. If the user creates the connection in
/// another way and `_task_handle` isn't set, they should manually spawn the returned driver function, keep the spawned task's
/// handle and abort the task whenever they want, at the risk of effectively closing the clones of the multiplexed connection.
#[derive(Clone)]
pub struct MultiplexedConnection {
    // Shared request channel to the background driver task.
    pipeline: Pipeline,
    // Database index taken from the connection info at creation time.
    db: i64,
    // Per-request timeout; `None` means wait indefinitely.
    response_timeout: Option<Duration>,
    // RESP protocol version taken from the connection info.
    protocol: ProtocolVersion,
    // Optional semaphore bounding the number of concurrently outstanding requests.
    concurrency_limiter: Option<Arc<async_lock::Semaphore>>,
    // This handle ensures that once all the clones of the connection will be dropped, the underlying task will stop.
    // This handle is only set for connection whose task was spawned by the crate, not for users who spawned their own
    // task.
    _task_handle: Option<SharedHandleContainer>,
    #[cfg(feature = "cache-aio")]
    pub(crate) cache_manager: Option<CacheManager>,
    #[cfg(feature = "token-based-authentication")]
    // This handle ensures that once all the clones of the connection will be dropped, the underlying task will stop.
    // It is only set for connections that use a credentials provider for token-based authentication.
    _credentials_subscription_task_handle: Option<SharedHandleContainer>,
    /// Flag indicating that re-authentication has failed and the connection is no longer usable.
    /// When set, all subsequent commands will fail immediately with an authentication error.
    #[cfg(feature = "token-based-authentication")]
    re_authentication_failed: Arc<AtomicBool>,
}
519
impl Debug for MultiplexedConnection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to the struct forces this impl
        // to be revisited. Fields bound to `_` are deliberately excluded from the
        // debug output.
        let MultiplexedConnection {
            pipeline,
            db,
            response_timeout,
            protocol,
            concurrency_limiter: _,
            _task_handle,
            #[cfg(feature = "cache-aio")]
                cache_manager: _,
            #[cfg(feature = "token-based-authentication")]
                _credentials_subscription_task_handle: _,
            #[cfg(feature = "token-based-authentication")]
                re_authentication_failed: _,
        } = self;

        f.debug_struct("MultiplexedConnection")
            .field("pipeline", &pipeline)
            .field("db", &db)
            .field("response_timeout", &response_timeout)
            .field("protocol", &protocol)
            .finish()
    }
}
545
546impl MultiplexedConnection {
    /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
    /// and a `RedisConnectionInfo`
    ///
    /// Equivalent to calling [`MultiplexedConnection::new_with_config`] with the
    /// default [`AsyncConnectionConfig`].
    pub async fn new<C>(
        connection_info: &RedisConnectionInfo,
        stream: C,
    ) -> RedisResult<(Self, impl Future<Output = ()>)>
    where
        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
    {
        Self::new_with_config(connection_info, stream, AsyncConnectionConfig::default()).await
    }
558
559    /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
560    /// , a `RedisConnectionInfo` and a `AsyncConnectionConfig`.
561    pub async fn new_with_config<C>(
562        connection_info: &RedisConnectionInfo,
563        stream: C,
564        config: AsyncConnectionConfig,
565    ) -> RedisResult<(Self, impl Future<Output = ()> + 'static)>
566    where
567        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
568    {
569        let mut codec = ValueCodec::default().framed(stream);
570        if config.push_sender.is_some() {
571            check_resp3!(
572                connection_info.protocol,
573                "Can only pass push sender to a connection using RESP3"
574            );
575        }
576
577        #[cfg(feature = "cache-aio")]
578        let cache_config = config.cache.as_ref().map(|cache| match cache {
579            crate::client::Cache::Config(cache_config) => *cache_config,
580            #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
581            crate::client::Cache::Manager(cache_manager) => cache_manager.cache_config,
582        });
583        #[cfg(feature = "cache-aio")]
584        let cache_manager_opt = config
585            .cache
586            .map(|cache| {
587                check_resp3!(
588                    connection_info.protocol,
589                    "Can only enable client side caching in a connection using RESP3"
590                );
591                match cache {
592                    crate::client::Cache::Config(cache_config) => {
593                        Ok(CacheManager::new(cache_config))
594                    }
595                    #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
596                    crate::client::Cache::Manager(cache_manager) => Ok(cache_manager),
597                }
598            })
599            .transpose()?;
600
601        #[cfg(feature = "token-based-authentication")]
602        let mut connection_info = connection_info.clone();
603        #[cfg(not(feature = "token-based-authentication"))]
604        let connection_info = connection_info.clone();
605
606        #[cfg(feature = "token-based-authentication")]
607        if let Some(ref credentials_provider) = config.credentials_provider {
608            // Retrieve the initial credentials from the provider and apply them to the connection info
609            match credentials_provider.subscribe().next().await {
610                Some(Ok(credentials)) => {
611                    connection_info.username = Some(ArcStr::from(credentials.username));
612                    connection_info.password = Some(ArcStr::from(credentials.password));
613                }
614                Some(Err(err)) => {
615                    error!("Error while receiving credentials from stream: {err}");
616                    return Err(err);
617                }
618                None => {
619                    let err = RedisError::from((
620                        ErrorKind::AuthenticationFailed,
621                        "Credentials stream closed unexpectedly before yielding credentials!",
622                    ));
623                    error!("{err}");
624                    return Err(err);
625                }
626            }
627        }
628
629        setup_connection(
630            &mut codec,
631            &connection_info,
632            #[cfg(feature = "cache-aio")]
633            cache_config,
634        )
635        .await?;
636        if config.push_sender.is_some() {
637            check_resp3!(
638                connection_info.protocol,
639                "Can only pass push sender to a connection using RESP3"
640            );
641        }
642
643        let (pipeline, driver) = Pipeline::new(
644            codec,
645            config.push_sender,
646            #[cfg(feature = "cache-aio")]
647            cache_manager_opt.clone(),
648            Pipeline::resolve_buffer_size(config.pipeline_buffer_size),
649        );
650
651        let concurrency_limiter = config
652            .concurrency_limit
653            .map(|n| Arc::new(async_lock::Semaphore::new(n)));
654
655        let con = MultiplexedConnection {
656            pipeline,
657            db: connection_info.db,
658            response_timeout: config.response_timeout,
659            protocol: connection_info.protocol,
660            concurrency_limiter,
661            _task_handle: None,
662            #[cfg(feature = "cache-aio")]
663            cache_manager: cache_manager_opt,
664            #[cfg(feature = "token-based-authentication")]
665            _credentials_subscription_task_handle: None,
666            #[cfg(feature = "token-based-authentication")]
667            re_authentication_failed: Arc::new(AtomicBool::new(false)),
668        };
669
670        // Set up streaming credentials subscription if provider is available
671        #[cfg(feature = "token-based-authentication")]
672        if let Some(streaming_provider) = config.credentials_provider {
673            let mut inner_connection = con.clone();
674            let re_authentication_failed_arc = Arc::clone(&con.re_authentication_failed);
675            let mut stream = streaming_provider.subscribe();
676
677            let subscription_task_handle = Runtime::locate().spawn(async move {
678                let mut error_cause_logged = false;
679                while let Some(result) = stream.next().await {
680                    match result {
681                        Ok(credentials) => {
682                            if let Err(err) = inner_connection
683                                .re_authenticate_with_credentials(&credentials)
684                                .await
685                            {
686                                error!("Failed to re-authenticate async connection: {err}.");
687                                error_cause_logged = true;
688                                re_authentication_failed_arc.store(true, Ordering::Relaxed);
689                                break;
690                            } else {
691                                debug!("Re-authenticated async connection");
692                            }
693                        }
694                        Err(err) => {
695                            error!("Credentials stream error for async connection: {err}.");
696                            error_cause_logged = true;
697                        }
698                    }
699                }
700                if !re_authentication_failed_arc.load(Ordering::Relaxed) {
701                    if !error_cause_logged {
702                        error!("Re-authentication stream ended unexpectedly.");
703                    }
704                    re_authentication_failed_arc.store(true, Ordering::Relaxed);
705                }
706            });
707            return Ok((
708                Self {
709                    _credentials_subscription_task_handle: Some(SharedHandleContainer::new(
710                        subscription_task_handle,
711                    )),
712                    ..con
713                },
714                driver,
715            ));
716        }
717
718        Ok((con, driver))
719    }
720
    /// This should be called strictly before the multiplexed connection is cloned - that is, before it is returned to the user.
    /// Otherwise some clones will be able to kill the backing task, while other clones are still alive.
    pub(crate) fn set_task_handle(&mut self, handle: TaskHandle) {
        // Shared container so the task only stops once every clone is dropped.
        self._task_handle = Some(SharedHandleContainer::new(handle));
    }
726
727    /// Sets the time that the multiplexer will wait for responses on operations before failing.
728    pub fn set_response_timeout(&mut self, timeout: std::time::Duration) {
729        self.response_timeout = Some(timeout);
730    }
731
    /// Sends an already encoded (packed) command into the TCP socket and
    /// reads the single response from it.
    pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
        // Acquire a concurrency permit unless the command opted out; the permit
        // is held (via `_permit`) for the duration of the request.
        let _permit = if cmd.skip_concurrency_limit {
            None
        } else if let Some(limiter) = &self.concurrency_limiter {
            Some(limiter.acquire().await)
        } else {
            None
        };
        // Fail fast if a prior re-authentication failure marked the connection unusable.
        #[cfg(feature = "token-based-authentication")]
        if self.re_authentication_failed.load(Ordering::Relaxed) {
            return Err(RedisError::from((
                ErrorKind::AuthenticationFailed,
                "Connection is no longer usable due to re-authentication failure",
            )));
        }
        #[cfg(feature = "cache-aio")]
        if let Some(cache_manager) = &self.cache_manager {
            match cache_manager.get_cached_cmd(cmd) {
                // Answered entirely from the client-side cache; no network round trip.
                PrepareCacheResult::Cached(value) => return Ok(value),
                PrepareCacheResult::NotCached(cacheable_command) => {
                    // The cache layer packs the command into a pipeline, so send it
                    // with a pipeline expectation and resolve the replies through it.
                    let mut pipeline = crate::Pipeline::new();
                    cacheable_command.pack_command(cache_manager, &mut pipeline);

                    let result = self
                        .pipeline
                        .send_recv(
                            pipeline.get_packed_pipeline(),
                            Some(PipelineResponseExpectation {
                                skipped_response_count: 0,
                                expected_response_count: pipeline.commands.len(),
                                is_transaction: false,
                                seen_responses: 0,
                            }),
                            self.response_timeout,
                            cmd.is_no_response(),
                        )
                        .await?;
                    let replies: Vec<Value> = crate::types::from_redis_value(result)?;
                    return cacheable_command.resolve(cache_manager, replies.into_iter());
                }
                _ => (),
            }
        }
        // Non-cached path: a plain single-command request (expectation `None`).
        self.pipeline
            .send_recv(
                cmd.get_packed_command(),
                None,
                self.response_timeout,
                cmd.is_no_response(),
            )
            .await
    }
786
    /// Sends multiple already encoded (packed) commands into the TCP socket
    /// and reads `count` responses from it.  This is used to implement
    /// pipelining.
    ///
    /// `offset` is the number of leading responses to discard before
    /// collecting (e.g. the `QUEUED` replies of a transaction), and `count`
    /// is the number of responses to keep and return.
    pub async fn send_packed_commands(
        &mut self,
        cmd: &crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisResult<Vec<Value>> {
        // Try to acquire 1 permit per command in the pipeline: block on the first to guarantee
        // progress, then grab as many more as are immediately available without blocking.
        // This roughly reflects the pipeline's load on the server while avoiding deadlock --
        // a large pipeline can always proceed even if it can't acquire all permits.
        let _permits = if let Some(limiter) = &self.concurrency_limiter {
            // `count.max(1)` so an empty pipeline still reserves one slot.
            let mut permits = Vec::with_capacity(count.max(1));
            permits.push(limiter.acquire().await);
            for _ in 1..count {
                match limiter.try_acquire() {
                    Some(permit) => permits.push(permit),
                    None => break,
                }
            }
            // Permits are held for the duration of the call and released on drop.
            permits
        } else {
            Vec::new()
        };
        // A failed re-authentication permanently poisons the connection; fail
        // fast instead of sending commands the server would reject.
        #[cfg(feature = "token-based-authentication")]
        if self.re_authentication_failed.load(Ordering::Relaxed) {
            return Err(RedisError::from((
                ErrorKind::AuthenticationFailed,
                "Connection is no longer usable due to re-authentication failure",
            )));
        }
        // Client-side caching path: the cache manager rewrites the pipeline so
        // cached entries are served locally and only the remainder hits the
        // server. Note that `offset`/`count` are superseded here by the
        // skip/keep counts the cache manager computes for the rewritten pipeline.
        #[cfg(feature = "cache-aio")]
        if let Some(cache_manager) = &self.cache_manager {
            let (cacheable_pipeline, pipeline, (skipped_response_count, expected_response_count)) =
                cache_manager.get_cached_pipeline(cmd);
            // Everything was served from cache — nothing to send at all.
            if pipeline.is_empty() {
                return cacheable_pipeline.resolve(cache_manager, Value::Array(Vec::new()));
            }
            let result = self
                .pipeline
                .send_recv(
                    pipeline.get_packed_pipeline(),
                    Some(PipelineResponseExpectation {
                        skipped_response_count,
                        expected_response_count,
                        is_transaction: cacheable_pipeline.transaction_mode,
                        seen_responses: 0,
                    }),
                    self.response_timeout,
                    false,
                )
                .await?;

            // Merge the server replies with the cached values.
            return cacheable_pipeline.resolve(cache_manager, result);
        }
        // Plain (non-cached) path: send the pipeline as-is and collect `count`
        // responses after skipping the first `offset`.
        let value = self
            .pipeline
            .send_recv(
                cmd.get_packed_pipeline(),
                Some(PipelineResponseExpectation {
                    skipped_response_count: offset,
                    expected_response_count: count,
                    is_transaction: cmd.is_transaction(),
                    seen_responses: 0,
                }),
                self.response_timeout,
                false,
            )
            .await?;
        // The aggregated pipeline reply is normally an array; wrap any other
        // shape in a single-element Vec so the return type is uniform.
        match value {
            Value::Array(values) => Ok(values),
            _ => Ok(vec![value]),
        }
    }
863
864    /// Gets [`CacheStatistics`] for current connection if caching is enabled.
865    #[cfg(feature = "cache-aio")]
866    #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
867    pub fn get_cache_statistics(&self) -> Option<CacheStatistics> {
868        self.cache_manager.as_ref().map(|cm| cm.statistics())
869    }
870}
871
872impl ConnectionLike for MultiplexedConnection {
873    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
874        (async move { self.send_packed_command(cmd).await }).boxed()
875    }
876
877    fn req_packed_commands<'a>(
878        &'a mut self,
879        cmd: &'a crate::Pipeline,
880        offset: usize,
881        count: usize,
882    ) -> RedisFuture<'a, Vec<Value>> {
883        (async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
884    }
885
886    fn get_db(&self) -> i64 {
887        self.db
888    }
889}
890
891impl MultiplexedConnection {
892    /// Subscribes to a new channel(s).    
893    ///
894    /// Updates from the sender will be sent on the push sender that was passed to the connection.
895    /// If the connection was configured without a push sender, the connection won't be able to pass messages back to the user.
896    ///
897    /// This method is only available when the connection is using RESP3 protocol, and will return an error otherwise.
898    ///
899    /// ```rust,no_run
900    /// # async fn func() -> redis::RedisResult<()> {
901    /// let client = redis::Client::open("redis://127.0.0.1/?protocol=resp3").unwrap();
902    /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
903    /// let config = redis::AsyncConnectionConfig::new().set_push_sender(tx);
904    /// let mut con = client.get_multiplexed_async_connection_with_config(&config).await?;
905    /// con.subscribe(&["channel_1", "channel_2"]).await?;
906    /// # Ok(()) }
907    /// ```
908    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
909        check_resp3!(self.protocol);
910        let mut cmd = cmd("SUBSCRIBE");
911        cmd.arg(channel_name);
912        cmd.exec_async(self).await?;
913        Ok(())
914    }
915
916    /// Unsubscribes from channel(s).
917    ///
918    /// This method is only available when the connection is using RESP3 protocol, and will return an error otherwise.
919    ///
920    /// ```rust,no_run
921    /// # async fn func() -> redis::RedisResult<()> {
922    /// let client = redis::Client::open("redis://127.0.0.1/?protocol=resp3").unwrap();
923    /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
924    /// let config = redis::AsyncConnectionConfig::new().set_push_sender(tx);
925    /// let mut con = client.get_multiplexed_async_connection_with_config(&config).await?;
926    /// con.subscribe(&["channel_1", "channel_2"]).await?;
927    /// con.unsubscribe(&["channel_1", "channel_2"]).await?;
928    /// # Ok(()) }
929    /// ```
930    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
931        check_resp3!(self.protocol);
932        let mut cmd = cmd("UNSUBSCRIBE");
933        cmd.arg(channel_name);
934        cmd.exec_async(self).await?;
935        Ok(())
936    }
937
938    /// Subscribes to new channel(s) with pattern(s).
939    ///
940    /// Updates from the sender will be sent on the push sender that was passed to the connection.
941    /// If the connection was configured without a push sender, the connection won't be able to pass messages back to the user.
942    ///
943    /// This method is only available when the connection is using RESP3 protocol, and will return an error otherwise.
944    ///
945    /// ```rust,no_run
946    /// # async fn func() -> redis::RedisResult<()> {
947    /// let client = redis::Client::open("redis://127.0.0.1/?protocol=resp3").unwrap();
948    /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
949    /// let config = redis::AsyncConnectionConfig::new().set_push_sender(tx);
950    /// let mut con = client.get_multiplexed_async_connection_with_config(&config).await?;
951    /// con.psubscribe("channel*_1").await?;
952    /// con.psubscribe(&["channel*_2", "channel*_3"]).await?;
953    /// # Ok(())
954    /// # }
955    /// ```
956    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
957        check_resp3!(self.protocol);
958        let mut cmd = cmd("PSUBSCRIBE");
959        cmd.arg(channel_pattern);
960        cmd.exec_async(self).await?;
961        Ok(())
962    }
963
964    /// Unsubscribes from channel pattern(s).
965    ///
966    /// This method is only available when the connection is using RESP3 protocol, and will return an error otherwise.
967    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
968        check_resp3!(self.protocol);
969        let mut cmd = cmd("PUNSUBSCRIBE");
970        cmd.arg(channel_pattern);
971        cmd.exec_async(self).await?;
972        Ok(())
973    }
974}
975
#[cfg(feature = "token-based-authentication")]
impl MultiplexedConnection {
    /// Re-authenticate the connection with new credentials.
    ///
    /// Lets an existing async connection pick up refreshed tokens without
    /// reconnecting, enabling streaming credential updates.
    async fn re_authenticate_with_credentials(
        &mut self,
        credentials: &crate::auth::BasicAuth,
    ) -> RedisResult<()> {
        let mut auth_cmd =
            crate::connection::authenticate_cmd(Some(&credentials.username), &credentials.password);
        // The AUTH command must not queue behind user traffic, otherwise a
        // token refresh could stall on a saturated connection.
        auth_cmd.skip_concurrency_limit = true;
        let reply = self.send_packed_command(&auth_cmd).await?;
        // Surface server-level errors carried inside the reply value.
        reply.extract_error()?;
        Ok(())
    }
}
995
#[cfg(test)]
mod tests {
    use super::*;

    // `Pipeline::resolve_buffer_size` falls back to 50 when no explicit
    // buffer size is configured.
    #[test]
    fn test_pipeline_resolve_buffer_size_default() {
        assert_eq!(Pipeline::resolve_buffer_size(None), 50);
    }

    // An explicitly configured buffer size is used verbatim.
    #[test]
    fn test_pipeline_resolve_buffer_size_custom() {
        assert_eq!(Pipeline::resolve_buffer_size(Some(100)), 100);
    }

    // Connection info for tests; `skip_set_lib_name` avoids the lib-name setup
    // command so the mock server only sees commands sent by the tests themselves.
    fn mock_conn_info() -> RedisConnectionInfo {
        RedisConnectionInfo {
            skip_set_lib_name: true,
            ..Default::default()
        }
    }

    // Builds a `MultiplexedConnection` wired to an in-memory duplex stream that
    // plays the server role:
    // - every request frame parsed off the socket is signalled on the returned
    //   `Receiver<()>`;
    // - each `()` pushed into the returned `Sender<()>` makes the mock write
    //   one `+OK\r\n` reply.
    // This lets tests observe exactly when commands reach the "server" and
    // control exactly when responses come back.
    async fn create_mock_connection(
        concurrency_limit: usize,
    ) -> (
        MultiplexedConnection,
        tokio::sync::mpsc::Receiver<()>,
        tokio::sync::mpsc::Sender<()>,
    ) {
        use futures_util::StreamExt;
        use tokio::io::AsyncWriteExt;
        use tokio_util::codec::FramedRead;

        let (client_half, server_half) = tokio::io::duplex(4096);
        let (cmd_received_tx, cmd_received_rx) = tokio::sync::mpsc::channel::<()>(10);
        let (send_response_tx, mut send_response_rx) = tokio::sync::mpsc::channel::<()>(10);

        let (server_read, mut server_write) = tokio::io::split(server_half);

        // Reader task: decode incoming RESP frames and signal each one.
        tokio::spawn(async move {
            let mut reader = FramedRead::new(server_read, ValueCodec::default());
            while let Some(Ok(_)) = reader.next().await {
                let _ = cmd_received_tx.send(()).await;
            }
        });

        // Writer task: emit one `+OK` reply per trigger message.
        tokio::spawn(async move {
            while send_response_rx.recv().await.is_some() {
                let _ = server_write.write_all(b"+OK\r\n").await;
                let _ = server_write.flush().await;
            }
        });

        // No response/connection timeouts here — tests drive timing themselves.
        let config = AsyncConnectionConfig::new()
            .set_concurrency_limit(concurrency_limit)
            .set_response_timeout(None)
            .set_connection_timeout(None);

        let (conn, driver) =
            MultiplexedConnection::new_with_config(&mock_conn_info(), client_half, config)
                .await
                .unwrap();
        tokio::spawn(driver);

        (conn, cmd_received_rx, send_response_tx)
    }

    // With a limit of 2, a third concurrent request must not reach the server
    // until one of the first two completes and releases its permit.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_concurrency_limit_enforced() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(2).await;

        let h1 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        let h2 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        let h3 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });

        // First two requests pass the limiter and reach the mock server.
        cmd_received_rx.recv().await.unwrap();
        cmd_received_rx.recv().await.unwrap();

        // NOTE(review): timing-based negative check — assumes 100ms is enough
        // for an unblocked request to arrive.
        let third = tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            third.is_err(),
            "3rd request should be blocked by concurrency limit"
        );

        // Completing one request frees a permit for the third.
        send_response_tx.send(()).await.unwrap();

        cmd_received_rx.recv().await.unwrap();

        // Drain the remaining two responses so all tasks finish.
        send_response_tx.send(()).await.unwrap();
        send_response_tx.send(()).await.unwrap();

        h1.await.unwrap().unwrap();
        h2.await.unwrap().unwrap();
        h3.await.unwrap().unwrap();
    }

    // A command flagged `skip_concurrency_limit` must reach the server even
    // while the single available permit is held by another request.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_no_limit_bypasses_concurrency_limit() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(1).await;

        let h1 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });

        cmd_received_rx.recv().await.unwrap();

        let h2 = tokio::spawn({
            let mut c = conn.clone();
            async move {
                let mut ping = cmd("PING");
                ping.skip_concurrency_limit = true;
                c.send_packed_command(&ping).await
            }
        });

        let received =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            received.is_ok(),
            "no_limit request should bypass concurrency limit"
        );

        send_response_tx.send(()).await.unwrap();
        send_response_tx.send(()).await.unwrap();

        h1.await.unwrap().unwrap();
        h2.await.unwrap().unwrap();
    }

    // A 3-command pipeline on a limit of 3 should take all permits, leaving a
    // subsequent single command blocked until the pipeline finishes.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_pipeline_acquires_multiple_permits() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(3).await;

        let pipeline_handle = tokio::spawn({
            let mut c = conn.clone();
            async move {
                let mut pipe = crate::Pipeline::new();
                pipe.cmd("SET").arg("a").arg("1");
                pipe.cmd("SET").arg("b").arg("2");
                pipe.cmd("SET").arg("c").arg("3");
                c.send_packed_commands(&pipe, 0, 3).await
            }
        });

        // All three pipeline commands reach the server.
        for _ in 0..3 {
            cmd_received_rx.recv().await.unwrap();
        }

        let single_handle = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });

        let blocked =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            blocked.is_err(),
            "single command should be blocked while pipeline holds all permits"
        );

        // Answer the pipeline; its permits are released on completion.
        for _ in 0..3 {
            send_response_tx.send(()).await.unwrap();
        }

        // Now the single command can go through.
        cmd_received_rx.recv().await.unwrap();
        send_response_tx.send(()).await.unwrap();

        pipeline_handle.await.unwrap().unwrap();
        single_handle.await.unwrap().unwrap();
    }

    // A pipeline larger than the remaining permit budget must still make
    // progress: it blocks for one permit and proceeds with however many more
    // it can grab without waiting.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_pipeline_proceeds_with_partial_permits() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(2).await;

        // Occupy one of the two permits with a pending single command.
        let single_handle = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        cmd_received_rx.recv().await.unwrap();

        let pipeline_handle = tokio::spawn({
            let mut c = conn.clone();
            async move {
                let mut pipe = crate::Pipeline::new();
                for i in 0..5 {
                    pipe.cmd("SET").arg(format!("k{i}")).arg(i);
                }
                c.send_packed_commands(&pipe, 0, 5).await
            }
        });

        let received =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            received.is_ok(),
            "pipeline should proceed even with only partial permits"
        );

        // Remaining pipeline commands arrive (the whole pipeline is one write).
        for _ in 1..5 {
            cmd_received_rx.recv().await.unwrap();
        }

        // 6 responses total: 1 for the single command + 5 for the pipeline.
        for _ in 0..6 {
            send_response_tx.send(()).await.unwrap();
        }

        single_handle.await.unwrap().unwrap();
        pipeline_handle.await.unwrap().unwrap();
    }

    // Aborting a task that is waiting on the semaphore must not leak its
    // (not-yet-acquired) permit slot.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_permit_released_on_cancellation() {
        let (conn, mut cmd_received_rx, send_response_tx) = create_mock_connection(1).await;

        let h1 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        cmd_received_rx.recv().await.unwrap();

        // Start a second request that will block on the semaphore, then cancel it
        let h2 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });
        tokio::time::sleep(Duration::from_millis(50)).await;
        h2.abort();
        let _ = h2.await;

        // Complete the first request
        send_response_tx.send(()).await.unwrap();
        h1.await.unwrap().unwrap();

        // The permit from the cancelled request should have been released,
        // so a new request should proceed
        let h3 = tokio::spawn({
            let mut c = conn.clone();
            async move { c.send_packed_command(&cmd("PING")).await }
        });

        let received =
            tokio::time::timeout(Duration::from_millis(100), cmd_received_rx.recv()).await;
        assert!(
            received.is_ok(),
            "request after cancellation should acquire the permit"
        );

        send_response_tx.send(()).await.unwrap();
        h3.await.unwrap().unwrap();
    }

    // A request that times out waiting for a response must release its permit
    // so later requests can still reach the server.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_permit_released_on_response_timeout() {
        use futures_util::StreamExt;
        use tokio::io::AsyncWriteExt;
        use tokio_util::codec::FramedRead;

        let (client_half, server_half) = tokio::io::duplex(4096);
        let (cmd_received_tx, mut cmd_received_rx) = tokio::sync::mpsc::channel::<()>(10);

        let (server_read, mut server_write) = tokio::io::split(server_half);

        // Reader task: acknowledge frames but never answer them.
        tokio::spawn(async move {
            let mut reader = FramedRead::new(server_read, ValueCodec::default());
            while let Some(Ok(_)) = reader.next().await {
                let _ = cmd_received_tx.send(()).await;
            }
        });

        // This task never completes; it only keeps `server_write` alive so the
        // connection stays open without ever responding.
        tokio::spawn(async move {
            futures_util::future::pending::<()>().await;
            let _ = server_write.write_all(b"").await;
        });

        let config = AsyncConnectionConfig::new()
            .set_concurrency_limit(1)
            .set_response_timeout(Some(Duration::from_millis(100)))
            .set_connection_timeout(None);

        let (conn, driver) =
            MultiplexedConnection::new_with_config(&mock_conn_info(), client_half, config)
                .await
                .unwrap();
        tokio::spawn(driver);

        // First request times out since the mock never responds
        let mut c1 = conn.clone();
        let err = c1.send_packed_command(&cmd("PING")).await.unwrap_err();
        assert!(err.is_io_error(), "expected IO error from timeout");
        cmd_received_rx.recv().await.unwrap();

        // Second request should acquire the permit released by the first,
        // reach the server, and then also time out
        let mut c2 = conn.clone();
        let err = c2.send_packed_command(&cmd("PING")).await.unwrap_err();
        assert!(err.is_io_error(), "expected IO error from timeout");
        cmd_received_rx.recv().await.unwrap();
    }
}