1use super::{AsyncPushSender, ConnectionLike, Runtime, SharedHandleContainer, TaskHandle};
2use crate::aio::{check_resp3, setup_connection};
3#[cfg(feature = "cache-aio")]
4use crate::caching::{CacheManager, CacheStatistics, PrepareCacheResult};
5use crate::cmd::Cmd;
6#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
7use crate::parser::ValueCodec;
8use crate::types::{closed_connection_error, RedisError, RedisFuture, RedisResult, Value};
9use crate::{
10 cmd, AsyncConnectionConfig, ProtocolVersion, PushInfo, RedisConnectionInfo, ToRedisArgs,
11};
12use ::tokio::{
13 io::{AsyncRead, AsyncWrite},
14 sync::{mpsc, oneshot},
15};
16use futures_util::{
17 future::{Future, FutureExt},
18 ready,
19 sink::Sink,
20 stream::{self, Stream, StreamExt},
21};
22use pin_project_lite::pin_project;
23use std::collections::VecDeque;
24use std::fmt;
25use std::fmt::Debug;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::{self, Poll};
29use std::time::Duration;
30#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
31use tokio_util::codec::Decoder;
32
/// Channel end through which the driver delivers the final result of one request back to the
/// caller awaiting it in `Pipeline::send_recv`.
type PipelineOutput = oneshot::Sender<RedisResult<Value>>;
35
/// How the responses belonging to one in-flight request are gathered.
enum ResponseAggregate {
    /// The next non-push response (or reply push) completes the request.
    SingleCommand,
    /// A pipeline: several responses are read before the request completes.
    Pipeline {
        /// Responses kept so far (those after the skipped prefix).
        buffer: Vec<Value>,
        /// The first error observed; if set, the request completes with it instead of `buffer`.
        first_err: Option<RedisError>,
        /// How many responses to skip and how many to keep.
        expectation: PipelineResponseExpectation,
    },
}
44
/// Describes the response layout expected for a pipelined request.
struct PipelineResponseExpectation {
    /// Number of leading responses that are consumed but not returned to the caller.
    skipped_response_count: usize,
    /// Number of responses, after the skipped prefix, that are collected and returned.
    expected_response_count: usize,
    /// Whether the pipeline is a transaction; if so, an error found in a skipped response is
    /// still reported as the request's error.
    is_transaction: bool,
}
54
55impl ResponseAggregate {
56 fn new(expectation: Option<PipelineResponseExpectation>) -> Self {
57 match expectation {
58 Some(expectation) => ResponseAggregate::Pipeline {
59 buffer: Vec::new(),
60 first_err: None,
61 expectation,
62 },
63 None => ResponseAggregate::SingleCommand,
64 }
65 }
66}
67
/// A request that has been handed to the underlying sink and is waiting for its response(s).
struct InFlight {
    /// Where the final result is delivered.
    output: PipelineOutput,
    /// Progress of collecting this request's responses.
    response_aggregate: ResponseAggregate,
}
72
/// A request sent from a connection handle to the driver.
struct PipelineMessage {
    /// The packed command bytes to write to the connection.
    input: Vec<u8>,
    /// Where the result of the request is delivered.
    output: PipelineOutput,
    /// `None` for a single command, `Some` for a pipeline (see `ResponseAggregate::new`).
    expectation: Option<PipelineResponseExpectation>,
}
82
/// Clonable handle used to submit requests to the connection driver.
///
/// Every clone shares the same bounded channel; the driver future created by `Pipeline::new`
/// receives from it and performs the actual I/O.
#[derive(Clone)]
struct Pipeline {
    sender: mpsc::Sender<PipelineMessage>,
}
91
92impl Debug for Pipeline {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 f.debug_tuple("Pipeline").field(&self.sender).finish()
95 }
96}
97
// `PipelineSink` variant used when client-side caching (`cache-aio`) is enabled.
#[cfg(feature = "cache-aio")]
pin_project! {
    struct PipelineSink<T> {
        // Framed connection: a sink of packed commands and a stream of parsed responses.
        #[pin]
        sink_stream: T,
        // Requests awaiting responses, oldest first; responses are matched in this order.
        in_flight: VecDeque<InFlight>,
        // Error from the underlying sink's readiness, reported on the next `start_send`.
        error: Option<RedisError>,
        // Receives push messages and the disconnect notification, if configured.
        push_sender: Option<Arc<dyn AsyncPushSender>>,
        // Notified of non-reply push messages before they are forwarded.
        cache_manager: Option<CacheManager>,
    }
}
109
// `PipelineSink` variant used when client-side caching (`cache-aio`) is disabled.
#[cfg(not(feature = "cache-aio"))]
pin_project! {
    struct PipelineSink<T> {
        // Framed connection: a sink of packed commands and a stream of parsed responses.
        #[pin]
        sink_stream: T,
        // Requests awaiting responses, oldest first; responses are matched in this order.
        in_flight: VecDeque<InFlight>,
        // Error from the underlying sink's readiness, reported on the next `start_send`.
        error: Option<RedisError>,
        // Receives push messages and the disconnect notification, if configured.
        push_sender: Option<Arc<dyn AsyncPushSender>>,
    }
}
120
121fn send_push(push_sender: &Option<Arc<dyn AsyncPushSender>>, info: PushInfo) {
122 if let Some(sender) = push_sender {
123 let _ = sender.send(info);
124 };
125}
126
127pub(crate) fn send_disconnect(push_sender: &Option<Arc<dyn AsyncPushSender>>) {
128 send_push(push_sender, PushInfo::disconnect());
129}
130
131impl<T> PipelineSink<T>
132where
133 T: Stream<Item = RedisResult<Value>> + 'static,
134{
135 fn new(
136 sink_stream: T,
137 push_sender: Option<Arc<dyn AsyncPushSender>>,
138 #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
139 ) -> Self
140 where
141 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
142 {
143 PipelineSink {
144 sink_stream,
145 in_flight: VecDeque::new(),
146 error: None,
147 push_sender,
148 #[cfg(feature = "cache-aio")]
149 cache_manager,
150 }
151 }
152
153 fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
155 loop {
156 let item = ready!(self.as_mut().project().sink_stream.poll_next(cx));
157 let item = match item {
158 Some(result) => result,
159 None => Err(closed_connection_error()),
161 };
162
163 let is_unrecoverable = item.as_ref().is_err_and(|err| err.is_unrecoverable_error());
164 self.as_mut().send_result(item);
165 if is_unrecoverable {
166 let self_ = self.project();
167 send_disconnect(self_.push_sender);
168 return Poll::Ready(Err(()));
169 }
170 }
171 }
172
173 fn send_result(self: Pin<&mut Self>, result: RedisResult<Value>) {
174 let self_ = self.project();
175 let result = match result {
176 Ok(Value::Push { kind, data }) if !kind.has_reply() => {
178 #[cfg(feature = "cache-aio")]
179 if let Some(cache_manager) = &self_.cache_manager {
180 cache_manager.handle_push_value(&kind, &data);
181 }
182 send_push(self_.push_sender, PushInfo { kind, data });
183
184 return;
185 }
186 Ok(Value::Push { kind, data }) if kind.has_reply() => {
188 send_push(
189 self_.push_sender,
190 PushInfo {
191 kind: kind.clone(),
192 data: data.clone(),
193 },
194 );
195 Ok(Value::Push { kind, data })
196 }
197 _ => result,
198 };
199
200 let mut entry = match self_.in_flight.pop_front() {
201 Some(entry) => entry,
202 None => return,
203 };
204
205 match &mut entry.response_aggregate {
206 ResponseAggregate::SingleCommand => {
207 entry.output.send(result).ok();
208 }
209 ResponseAggregate::Pipeline {
210 buffer,
211 first_err,
212 expectation:
213 PipelineResponseExpectation {
214 expected_response_count,
215 skipped_response_count,
216 is_transaction,
217 },
218 } => {
219 if *skipped_response_count > 0 {
220 if first_err.is_none() && *is_transaction {
223 *first_err = result.and_then(Value::extract_error).err();
224 }
225
226 *skipped_response_count -= 1;
227 self_.in_flight.push_front(entry);
228 return;
229 }
230
231 match result {
232 Ok(item) => {
233 buffer.push(item);
234 }
235 Err(err) => {
236 if first_err.is_none() {
237 *first_err = Some(err);
238 }
239 }
240 }
241
242 if buffer.len() < *expected_response_count {
243 self_.in_flight.push_front(entry);
245 return;
246 }
247
248 let response = match first_err.take() {
249 Some(err) => Err(err),
250 None => Ok(Value::Array(std::mem::take(buffer))),
251 };
252
253 entry.output.send(response).ok();
257 }
258 }
259 }
260}
261
/// Driver-side sink: accepts `PipelineMessage`s, writes them to the connection, and records
/// them as in-flight so responses can be matched to them in order.
///
/// Errors are delivered to the affected request's output; the sink's own error type is `()`
/// and an `Err(())` simply signals that forwarding should stop.
impl<T> Sink<PipelineMessage> for PipelineSink<T>
where
    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
{
    type Error = ();

    // Readiness errors from the underlying sink are not returned here; they are stashed in
    // `error` and reported to the next request in `start_send`, so this always resolves to Ok.
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
            Ok(()) => Ok(()).into(),
            Err(err) => {
                *self.project().error = Some(err);
                Ok(()).into()
            }
        }
    }

    // Writes one request and queues it as in-flight.
    fn start_send(
        mut self: Pin<&mut Self>,
        PipelineMessage {
            input,
            output,
            expectation,
        }: PipelineMessage,
    ) -> Result<(), Self::Error> {
        // The caller has already dropped its receiver, so nobody wants the result; skip the
        // request entirely instead of writing it.
        if output.is_closed() {
            return Ok(());
        }

        let self_ = self.as_mut().project();

        // A readiness error stashed by `poll_ready` fails this request and stops the driver.
        if let Some(err) = self_.error.take() {
            let _ = output.send(Err(err));
            return Err(());
        }

        match self_.sink_stream.start_send(input) {
            Ok(()) => {
                let response_aggregate = ResponseAggregate::new(expectation);
                let entry = InFlight {
                    output,
                    response_aggregate,
                };

                self_.in_flight.push_back(entry);
                Ok(())
            }
            Err(err) => {
                let _ = output.send(Err(err));
                Err(())
            }
        }
    }

    // Flushes written requests, then reads responses. A flush error is routed to the oldest
    // in-flight request and stops the driver. Because `poll_read` only ever returns Pending or
    // an error, this never resolves to `Ready(Ok(()))`.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        ready!(self
            .as_mut()
            .project()
            .sink_stream
            .poll_flush(cx)
            .map_err(|err| {
                self.as_mut().send_result(Err(err));
            }))?;
        self.poll_read(cx)
    }

    // While requests are still outstanding, keep flushing/reading; then close the underlying
    // sink, routing a close error to the oldest in-flight request (if any).
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        if !self.in_flight.is_empty() {
            ready!(self.as_mut().poll_flush(cx))?;
        }
        let this = self.as_mut().project();
        this.sink_stream.poll_close(cx).map_err(|err| {
            self.send_result(Err(err));
        })
    }
}
352
353impl Pipeline {
354 fn new<T>(
355 sink_stream: T,
356 push_sender: Option<Arc<dyn AsyncPushSender>>,
357 #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
358 ) -> (Self, impl Future<Output = ()>)
359 where
360 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
361 T: Send + 'static,
362 T::Item: Send,
363 T::Error: Send,
364 T::Error: ::std::fmt::Debug,
365 {
366 const BUFFER_SIZE: usize = 50;
367 let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE);
368
369 let sink = PipelineSink::new(
370 sink_stream,
371 push_sender,
372 #[cfg(feature = "cache-aio")]
373 cache_manager,
374 );
375 let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
376 .map(Ok)
377 .forward(sink)
378 .map(|_| ());
379 (Pipeline { sender }, f)
380 }
381
382 async fn send_recv(
383 &mut self,
384 input: Vec<u8>,
385 expectation: Option<PipelineResponseExpectation>,
388 timeout: Option<Duration>,
389 ) -> Result<Value, RedisError> {
390 let (sender, receiver) = oneshot::channel();
391
392 let request = async {
393 self.sender
394 .send(PipelineMessage {
395 input,
396 expectation,
397 output: sender,
398 })
399 .await
400 .map_err(|_| None)?;
401
402 receiver.await
403 .map_err(|_| None)
406 .and_then(|res| res.map_err(Some))
407 };
408
409 match timeout {
410 Some(timeout) => match Runtime::locate().timeout(timeout, request).await {
411 Ok(res) => res,
412 Err(elapsed) => Err(Some(elapsed.into())),
413 },
414 None => request.await,
415 }
416 .map_err(|err| err.unwrap_or_else(closed_connection_error))
417 }
418}
419
/// A connection handle that can be cloned, allowing requests to be sent concurrently over the
/// same underlying connection.
///
/// Clones share one `Pipeline` (and therefore one driver); responses are routed back to the
/// clone that issued each request.
#[derive(Clone)]
pub struct MultiplexedConnection {
    // Handle used to submit requests to the driver.
    pipeline: Pipeline,
    // Database index taken from the connection info; reported by `get_db`.
    db: i64,
    // Per-request response timeout, if any.
    response_timeout: Option<Duration>,
    // Protocol in use; RESP3-only operations check this.
    protocol: ProtocolVersion,
    // Handle of the spawned driver task, shared between clones.
    // NOTE(review): presumably the task is stopped when the last clone drops — confirm against
    // `SharedHandleContainer`.
    _task_handle: Option<SharedHandleContainer>,
    // Client-side cache shared with the driver, when enabled.
    #[cfg(feature = "cache-aio")]
    pub(crate) cache_manager: Option<CacheManager>,
}
445
446impl Debug for MultiplexedConnection {
447 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
448 f.debug_struct("MultiplexedConnection")
449 .field("pipeline", &self.pipeline)
450 .field("db", &self.db)
451 .finish()
452 }
453}
454
impl MultiplexedConnection {
    /// Creates a multiplexed connection over `stream` with no response timeout.
    ///
    /// Equivalent to `new_with_response_timeout(connection_info, stream, None)`. Returns the
    /// connection and its driver future; the driver forwards requests and routes responses,
    /// so it must be polled (e.g. spawned) for requests to make progress.
    pub async fn new<C>(
        connection_info: &RedisConnectionInfo,
        stream: C,
    ) -> RedisResult<(Self, impl Future<Output = ()>)>
    where
        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
    {
        Self::new_with_response_timeout(connection_info, stream, None).await
    }

    /// Creates a multiplexed connection over `stream` with the given per-request timeout,
    /// using defaults for every other `AsyncConnectionConfig` option.
    pub async fn new_with_response_timeout<C>(
        connection_info: &RedisConnectionInfo,
        stream: C,
        response_timeout: Option<std::time::Duration>,
    ) -> RedisResult<(Self, impl Future<Output = ()>)>
    where
        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
    {
        Self::new_with_config(
            connection_info,
            stream,
            AsyncConnectionConfig {
                response_timeout,
                ..Default::default()
            },
        )
        .await
    }

    /// Creates a multiplexed connection over `stream` using `config`.
    ///
    /// The connection setup (`setup_connection`) runs while the driver is polled alongside it,
    /// so setup commands can be answered. Returns the connection and the driver future.
    ///
    /// # Errors
    /// * a push sender (or, with `cache-aio`, a cache) is configured on a non-RESP3 connection;
    /// * connection setup fails;
    /// * the driver terminates before setup completes (reported as an `IoError`).
    pub async fn new_with_config<C>(
        connection_info: &RedisConnectionInfo,
        stream: C,
        config: AsyncConnectionConfig,
    ) -> RedisResult<(Self, impl Future<Output = ()>)>
    where
        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
    {
        #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
        compile_error!("tokio-comp or async-std-comp features required for aio feature");

        let codec = ValueCodec::default().framed(stream);
        if config.push_sender.is_some() {
            check_resp3!(
                connection_info.protocol,
                "Can only pass push sender to a connection using RESP3"
            );
        }

        // The cache configuration is extracted separately because it is needed by
        // `setup_connection`, while the manager itself is moved into the driver.
        #[cfg(feature = "cache-aio")]
        let cache_config = config.cache.as_ref().map(|cache| match cache {
            crate::client::Cache::Config(cache_config) => *cache_config,
            #[cfg(feature = "connection-manager")]
            crate::client::Cache::Manager(cache_manager) => cache_manager.cache_config,
        });
        #[cfg(feature = "cache-aio")]
        let cache_manager_opt = config
            .cache
            .map(|cache| {
                check_resp3!(
                    connection_info.protocol,
                    "Can only enable client side caching in a connection using RESP3"
                );
                match cache {
                    crate::client::Cache::Config(cache_config) => {
                        Ok(CacheManager::new(cache_config))
                    }
                    #[cfg(feature = "connection-manager")]
                    crate::client::Cache::Manager(cache_manager) => Ok(cache_manager),
                }
            })
            .transpose()?;

        let (pipeline, driver) = Pipeline::new(
            codec,
            config.push_sender,
            #[cfg(feature = "cache-aio")]
            cache_manager_opt.clone(),
        );
        let mut con = MultiplexedConnection {
            pipeline,
            db: connection_info.db,
            response_timeout: config.response_timeout,
            protocol: connection_info.protocol,
            _task_handle: None,
            #[cfg(feature = "cache-aio")]
            cache_manager: cache_manager_opt,
        };
        // Race setup against the driver: setup can only complete while the driver is being
        // polled, and if the driver finishes first the connection is unusable.
        let driver = {
            let auth = setup_connection(
                connection_info,
                &mut con,
                #[cfg(feature = "cache-aio")]
                cache_config,
            );

            futures_util::pin_mut!(auth);

            match futures_util::future::select(auth, driver).await {
                futures_util::future::Either::Left((result, driver)) => {
                    result?;
                    driver
                }
                futures_util::future::Either::Right(((), _)) => {
                    return Err(RedisError::from((
                        crate::ErrorKind::IoError,
                        "Multiplexed connection driver unexpectedly terminated",
                    )));
                }
            }
        };
        Ok((con, driver))
    }

    /// Stores the handle of the task running the driver, shared between clones of this
    /// connection.
    pub(crate) fn set_task_handle(&mut self, handle: TaskHandle) {
        self._task_handle = Some(SharedHandleContainer::new(handle));
    }

    /// Sets the timeout applied to each subsequent request on this handle.
    pub fn set_response_timeout(&mut self, timeout: std::time::Duration) {
        self.response_timeout = Some(timeout);
    }

    /// Sends an already packed command and returns its response.
    ///
    /// With `cache-aio` and a cache configured, a cached value is returned without touching
    /// the network, and a cacheable miss is sent as a small pipeline whose replies are
    /// resolved through the cache manager.
    pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
        #[cfg(feature = "cache-aio")]
        if let Some(cache_manager) = &self.cache_manager {
            match cache_manager.get_cached_cmd(cmd) {
                PrepareCacheResult::Cached(value) => return Ok(value),
                PrepareCacheResult::NotCached(cacheable_command) => {
                    let mut pipeline = crate::Pipeline::new();
                    cacheable_command.pack_command(cache_manager, &mut pipeline);

                    let result = self
                        .pipeline
                        .send_recv(
                            pipeline.get_packed_pipeline(),
                            Some(PipelineResponseExpectation {
                                skipped_response_count: 0,
                                expected_response_count: pipeline.commands.len(),
                                is_transaction: false,
                            }),
                            self.response_timeout,
                        )
                        .await?;
                    let replies: Vec<Value> = crate::types::from_owned_redis_value(result)?;
                    return cacheable_command.resolve(cache_manager, replies.into_iter());
                }
                // Any other outcome falls through to a plain, uncached send.
                _ => (),
            }
        }
        self.pipeline
            .send_recv(cmd.get_packed_command(), None, self.response_timeout)
            .await
    }

    /// Sends an already packed pipeline; the first `offset` responses are skipped and the
    /// next `count` are returned.
    ///
    /// A non-array combined reply is returned as a single-element vector.
    pub async fn send_packed_commands(
        &mut self,
        cmd: &crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisResult<Vec<Value>> {
        #[cfg(feature = "cache-aio")]
        if let Some(cache_manager) = &self.cache_manager {
            let (cacheable_pipeline, pipeline, (skipped_response_count, expected_response_count)) =
                cache_manager.get_cached_pipeline(cmd);
            let result = self
                .pipeline
                .send_recv(
                    pipeline.get_packed_pipeline(),
                    Some(PipelineResponseExpectation {
                        skipped_response_count,
                        expected_response_count,
                        is_transaction: cacheable_pipeline.transaction_mode,
                    }),
                    self.response_timeout,
                )
                .await?;

            return cacheable_pipeline.resolve(cache_manager, result);
        }
        let value = self
            .pipeline
            .send_recv(
                cmd.get_packed_pipeline(),
                Some(PipelineResponseExpectation {
                    skipped_response_count: offset,
                    expected_response_count: count,
                    is_transaction: cmd.is_transaction(),
                }),
                self.response_timeout,
            )
            .await?;
        match value {
            Value::Array(values) => Ok(values),
            _ => Ok(vec![value]),
        }
    }

    /// Returns client-side cache statistics, or `None` if caching is not enabled.
    #[cfg(feature = "cache-aio")]
    #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
    pub fn get_cache_statistics(&self) -> Option<CacheStatistics> {
        self.cache_manager.as_ref().map(|cm| cm.statistics())
    }
}
673
674impl ConnectionLike for MultiplexedConnection {
675 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
676 (async move { self.send_packed_command(cmd).await }).boxed()
677 }
678
679 fn req_packed_commands<'a>(
680 &'a mut self,
681 cmd: &'a crate::Pipeline,
682 offset: usize,
683 count: usize,
684 ) -> RedisFuture<'a, Vec<Value>> {
685 (async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
686 }
687
688 fn get_db(&self) -> i64 {
689 self.db
690 }
691}
692
693impl MultiplexedConnection {
694 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
710 check_resp3!(self.protocol);
711 let mut cmd = cmd("SUBSCRIBE");
712 cmd.arg(channel_name);
713 cmd.exec_async(self).await?;
714 Ok(())
715 }
716
717 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
731 check_resp3!(self.protocol);
732 let mut cmd = cmd("UNSUBSCRIBE");
733 cmd.arg(channel_name);
734 cmd.exec_async(self).await?;
735 Ok(())
736 }
737
738 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
755 check_resp3!(self.protocol);
756 let mut cmd = cmd("PSUBSCRIBE");
757 cmd.arg(channel_pattern);
758 cmd.exec_async(self).await?;
759 Ok(())
760 }
761
762 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
766 check_resp3!(self.protocol);
767 let mut cmd = cmd("PUNSUBSCRIBE");
768 cmd.arg(channel_pattern);
769 cmd.exec_async(self).await?;
770 Ok(())
771 }
772}