redis/aio/pubsub.rs

1use crate::aio::Runtime;
2use crate::connection::{
3    check_connection_setup, connection_setup_pipeline, AuthResult, ConnectionSetupComponents,
4};
5#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
6use crate::parser::ValueCodec;
7use crate::types::{closed_connection_error, RedisError, RedisResult, Value};
8use crate::{cmd, from_owned_redis_value, FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs};
9use ::tokio::{
10    io::{AsyncRead, AsyncWrite},
11    sync::oneshot,
12};
13use futures_util::{
14    future::{Future, FutureExt},
15    ready,
16    sink::{Sink, SinkExt},
17    stream::{self, Stream, StreamExt},
18};
19use pin_project_lite::pin_project;
20use std::collections::VecDeque;
21use std::pin::Pin;
22use std::task::{self, Poll};
23use tokio::sync::mpsc::unbounded_channel;
24use tokio::sync::mpsc::UnboundedSender;
25#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
26use tokio_util::codec::Decoder;
27
28use super::SharedHandleContainer;
29
// Completion signal for a single request sent through the pipeline (the
// subscribe/unsubscribe/ping calls in this file). It carries the reply, or an
// error, back to the caller waiting in `send_recv`.
type RequestResultSender = oneshot::Sender<RedisResult<Value>>;

// A single request sent through the pipeline: the packed command bytes plus
// the channel on which the matching reply will be delivered.
struct PipelineMessage {
    // Packed command bytes, written as-is to the underlying connection.
    input: Vec<u8>,
    // Receives the reply once it has been matched to this request.
    output: RequestResultSender,
}
38
/// The sink part of a split async Pubsub.
///
/// The sink is used to subscribe and unsubscribe from
/// channels (and to send pings). Each request is queued on an
/// unbounded channel that is consumed by the task driving the connection.
/// The stream part is independent from the sink,
/// and dropping the sink doesn't cause the stream part to
/// stop working.
/// The sink isn't independent from the stream - dropping
/// the stream will cause the sink to return errors on requests.
///
/// Cloning the sink produces another handle that feeds the same queue.
#[derive(Clone)]
pub struct PubSubSink {
    // Sending half of the request queue read by the driver future built in `PubSubSink::new`.
    sender: UnboundedSender<PipelineMessage>,
}
52
pin_project! {
    /// The stream part of a split async Pubsub.
    ///
    /// Yields the [`Msg`]s received on this connection's subscriptions, as
    /// forwarded by the task driving the connection.
    ///
    /// The sink is used to subscribe and unsubscribe from
    /// channels.
    /// The stream part is independent from the sink,
    /// and dropping the sink doesn't cause the stream part to
    /// stop working.
    /// The sink isn't independent from the stream - dropping
    /// the stream will cause the sink to return errors on requests.
    pub struct PubSubStream {
        // Receiving half of the channel the driver task pushes messages into.
        #[pin]
        receiver: tokio::sync::mpsc::UnboundedReceiver<Msg>,
        // This handle ensures that once the stream is dropped, the underlying task will stop.
        _task_handle: Option<SharedHandleContainer>,
    }
}
70
pin_project! {
    // Adapter that turns a framed Redis connection into a `Sink` of
    // `PipelineMessage`s. Each request's bytes are written to the connection
    // and its reply channel is queued. Everything read back is routed either
    // to the oldest waiting request (replies) or to `sender` (pub/sub messages).
    struct PipelineSink<T> {
        // The `Sink + Stream` that sends requests and receives values from the server.
        #[pin]
        sink_stream: T,
        // Reply channels of requests that were sent and are awaiting a response, oldest first.
        in_flight: VecDeque<RequestResultSender>,
        // A sender for the push messages received from the server.
        sender: UnboundedSender<Msg>,
    }
}
82
83impl<T> PipelineSink<T>
84where
85    T: Stream<Item = RedisResult<Value>> + 'static,
86{
87    fn new(sink_stream: T, sender: UnboundedSender<Msg>) -> Self
88    where
89        T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
90    {
91        PipelineSink {
92            sink_stream,
93            in_flight: VecDeque::new(),
94            sender,
95        }
96    }
97
98    // Read messages from the stream and handle them.
99    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
100        loop {
101            let self_ = self.as_mut().project();
102            if self_.sender.is_closed() {
103                return Poll::Ready(Err(()));
104            }
105
106            let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
107                Some(result) => result,
108                // The redis response stream is not going to produce any more items so we `Err`
109                // to break out of the `forward` combinator and stop handling requests
110                None => return Poll::Ready(Err(())),
111            };
112            self.as_mut().handle_message(item)?;
113        }
114    }
115
116    fn handle_message(self: Pin<&mut Self>, result: RedisResult<Value>) -> Result<(), ()> {
117        let self_ = self.project();
118
119        match result {
120            Ok(Value::Array(value)) => {
121                if let Some(Value::BulkString(kind)) = value.first() {
122                    if matches!(
123                        kind.as_slice(),
124                        b"subscribe" | b"psubscribe" | b"unsubscribe" | b"punsubscribe" | b"pong"
125                    ) {
126                        if let Some(entry) = self_.in_flight.pop_front() {
127                            let _ = entry.send(Ok(Value::Array(value)));
128                        };
129                        return Ok(());
130                    }
131                }
132
133                if let Some(msg) = Msg::from_owned_value(Value::Array(value)) {
134                    let _ = self_.sender.send(msg);
135                    Ok(())
136                } else {
137                    Err(())
138                }
139            }
140
141            Ok(Value::Push { kind, data }) => {
142                if kind.has_reply() {
143                    if let Some(entry) = self_.in_flight.pop_front() {
144                        let _ = entry.send(Ok(Value::Push { kind, data }));
145                    };
146                    return Ok(());
147                }
148
149                if let Some(msg) = Msg::from_push_info(crate::PushInfo { kind, data }) {
150                    let _ = self_.sender.send(msg);
151                    Ok(())
152                } else {
153                    Err(())
154                }
155            }
156
157            Err(err) if err.is_unrecoverable_error() => Err(()),
158
159            _ => {
160                if let Some(entry) = self_.in_flight.pop_front() {
161                    let _ = entry.send(result);
162                    Ok(())
163                } else {
164                    Err(())
165                }
166            }
167        }
168    }
169}
170
171impl<T> Sink<PipelineMessage> for PipelineSink<T>
172where
173    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
174{
175    type Error = ();
176
177    // Retrieve incoming messages and write them to the sink
178    fn poll_ready(
179        mut self: Pin<&mut Self>,
180        cx: &mut task::Context,
181    ) -> Poll<Result<(), Self::Error>> {
182        self.as_mut()
183            .project()
184            .sink_stream
185            .poll_ready(cx)
186            .map_err(|_| ())
187    }
188
189    fn start_send(
190        mut self: Pin<&mut Self>,
191        PipelineMessage { input, output }: PipelineMessage,
192    ) -> Result<(), Self::Error> {
193        let self_ = self.as_mut().project();
194
195        match self_.sink_stream.start_send(input) {
196            Ok(()) => {
197                self_.in_flight.push_back(output);
198                Ok(())
199            }
200            Err(err) => {
201                let _ = output.send(Err(err));
202                Err(())
203            }
204        }
205    }
206
207    fn poll_flush(
208        mut self: Pin<&mut Self>,
209        cx: &mut task::Context,
210    ) -> Poll<Result<(), Self::Error>> {
211        ready!(self
212            .as_mut()
213            .project()
214            .sink_stream
215            .poll_flush(cx)
216            .map_err(|err| {
217                let _ = self.as_mut().handle_message(Err(err));
218            }))?;
219        self.poll_read(cx)
220    }
221
222    fn poll_close(
223        mut self: Pin<&mut Self>,
224        cx: &mut task::Context,
225    ) -> Poll<Result<(), Self::Error>> {
226        // No new requests will come in after the first call to `close` but we need to complete any
227        // in progress requests before closing
228        if !self.in_flight.is_empty() {
229            ready!(self.as_mut().poll_flush(cx))?;
230        }
231        let this = self.as_mut().project();
232
233        if this.sender.is_closed() {
234            return Poll::Ready(Ok(()));
235        }
236
237        match ready!(this.sink_stream.poll_next(cx)) {
238            Some(result) => {
239                let _ = self.handle_message(result);
240                Poll::Pending
241            }
242            None => Poll::Ready(Ok(())),
243        }
244    }
245}
246
247impl PubSubSink {
248    fn new<T>(
249        sink_stream: T,
250        messages_sender: UnboundedSender<Msg>,
251    ) -> (Self, impl Future<Output = ()>)
252    where
253        T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + Send + 'static,
254        T::Item: Send,
255        T::Error: Send,
256        T::Error: ::std::fmt::Debug,
257    {
258        let (sender, mut receiver) = unbounded_channel();
259        let sink = PipelineSink::new(sink_stream, messages_sender);
260        let f = stream::poll_fn(move |cx| {
261            let res = receiver.poll_recv(cx);
262            match res {
263                // We don't want to stop the backing task for the stream, even if the sink was closed.
264                Poll::Ready(None) => Poll::Pending,
265                _ => res,
266            }
267        })
268        .map(Ok)
269        .forward(sink)
270        .map(|_| ());
271        (PubSubSink { sender }, f)
272    }
273
274    async fn send_recv(&mut self, input: Vec<u8>) -> Result<Value, RedisError> {
275        let (sender, receiver) = oneshot::channel();
276
277        self.sender
278            .send(PipelineMessage {
279                input,
280                output: sender,
281            })
282            .map_err(|_| closed_connection_error())?;
283        match receiver.await {
284            Ok(result) => result,
285            Err(_) => Err(closed_connection_error()),
286        }
287    }
288
289    /// Subscribes to a new channel(s).
290    ///
291    /// ```rust,no_run
292    /// # #[cfg(feature = "aio")]
293    /// # async fn do_something() -> redis::RedisResult<()> {
294    /// let client = redis::Client::open("redis://127.0.0.1/")?;
295    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
296    /// sink.subscribe("channel_1").await?;
297    /// sink.subscribe(&["channel_2", "channel_3"]).await?;
298    /// # Ok(())
299    /// # }
300    /// ```
301    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
302        let cmd = cmd("SUBSCRIBE").arg(channel_name).get_packed_command();
303        self.send_recv(cmd).await.map(|_| ())
304    }
305
306    /// Unsubscribes from channel(s).
307    ///
308    /// ```rust,no_run
309    /// # #[cfg(feature = "aio")]
310    /// # async fn do_something() -> redis::RedisResult<()> {
311    /// let client = redis::Client::open("redis://127.0.0.1/")?;
312    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
313    /// sink.subscribe(&["channel_1", "channel_2"]).await?;
314    /// sink.unsubscribe(&["channel_1", "channel_2"]).await?;
315    /// # Ok(())
316    /// # }
317    /// ```
318    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
319        let cmd = cmd("UNSUBSCRIBE").arg(channel_name).get_packed_command();
320        self.send_recv(cmd).await.map(|_| ())
321    }
322
323    /// Subscribes to new channel(s) with pattern(s).
324    ///
325    /// ```rust,no_run
326    /// # #[cfg(feature = "aio")]
327    /// # async fn do_something() -> redis::RedisResult<()> {
328    /// let client = redis::Client::open("redis://127.0.0.1/")?;
329    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
330    /// sink.psubscribe("channel*_1").await?;
331    /// sink.psubscribe(&["channel*_2", "channel*_3"]).await?;
332    /// # Ok(())
333    /// # }
334    /// ```
335    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
336        let cmd = cmd("PSUBSCRIBE").arg(channel_pattern).get_packed_command();
337        self.send_recv(cmd).await.map(|_| ())
338    }
339
340    /// Unsubscribes from channel pattern(s).
341    ///
342    /// ```rust,no_run
343    /// # #[cfg(feature = "aio")]
344    /// # async fn do_something() -> redis::RedisResult<()> {
345    /// let client = redis::Client::open("redis://127.0.0.1/")?;
346    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
347    /// sink.psubscribe(&["channel_1", "channel_2"]).await?;
348    /// sink.punsubscribe(&["channel_1", "channel_2"]).await?;
349    /// # Ok(())
350    /// # }
351    /// ```
352    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
353        let cmd = cmd("PUNSUBSCRIBE")
354            .arg(channel_pattern)
355            .get_packed_command();
356        self.send_recv(cmd).await.map(|_| ())
357    }
358
359    /// Sends a ping with a message to the server
360    pub async fn ping_message<T: FromRedisValue>(
361        &mut self,
362        message: impl ToRedisArgs,
363    ) -> RedisResult<T> {
364        let cmd = cmd("PING").arg(message).get_packed_command();
365        let response = self.send_recv(cmd).await?;
366        from_owned_redis_value(response)
367    }
368
369    /// Sends a ping to the server
370    pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
371        let cmd = cmd("PING").get_packed_command();
372        let response = self.send_recv(cmd).await?;
373        from_owned_redis_value(response)
374    }
375}
376
/// A connection dedicated to pubsub messages.
///
/// Combines the request half ([`PubSubSink`]) and the message half
/// ([`PubSubStream`]) of a single connection. [`PubSub::split`] separates them.
pub struct PubSub {
    // Sends (un)subscribe and ping requests to the driver task.
    sink: PubSubSink,
    // Receives published messages; also holds the handle of the driver task.
    stream: PubSubStream,
}
382
383async fn execute_connection_pipeline<T>(
384    codec: &mut T,
385    (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
386) -> RedisResult<AuthResult>
387where
388    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
389    T: Send + 'static,
390    T::Item: Send,
391    T::Error: Send,
392    T::Error: ::std::fmt::Debug,
393    T: Unpin,
394{
395    let count = pipeline.len();
396    if count == 0 {
397        return Ok(AuthResult::Succeeded);
398    }
399    codec.send(pipeline.get_packed_pipeline()).await?;
400
401    let mut results = Vec::with_capacity(count);
402    for _ in 0..count {
403        let value = codec.next().await;
404        match value {
405            Some(Ok(val)) => results.push(val),
406            _ => return Err(closed_connection_error()),
407        }
408    }
409
410    check_connection_setup(results, instructions)
411}
412
413async fn setup_connection<T>(
414    codec: &mut T,
415    connection_info: &RedisConnectionInfo,
416) -> RedisResult<()>
417where
418    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
419    T: Send + 'static,
420    T::Item: Send,
421    T::Error: Send,
422    T::Error: ::std::fmt::Debug,
423    T: Unpin,
424{
425    if execute_connection_pipeline(
426        codec,
427        connection_setup_pipeline(
428            connection_info,
429            true,
430            #[cfg(feature = "cache-aio")]
431            None,
432        ),
433    )
434    .await?
435        == AuthResult::ShouldRetryWithoutUsername
436    {
437        execute_connection_pipeline(
438            codec,
439            connection_setup_pipeline(
440                connection_info,
441                false,
442                #[cfg(feature = "cache-aio")]
443                None,
444            ),
445        )
446        .await?;
447    }
448
449    Ok(())
450}
451
452impl PubSub {
453    /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
454    /// and a `ConnectionInfo`
455    pub async fn new<C>(connection_info: &RedisConnectionInfo, stream: C) -> RedisResult<Self>
456    where
457        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
458    {
459        #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
460        compile_error!("tokio-comp or async-std-comp features required for aio feature");
461
462        let mut codec = ValueCodec::default().framed(stream);
463        setup_connection(&mut codec, connection_info).await?;
464        let (sender, receiver) = unbounded_channel();
465        let (sink, driver) = PubSubSink::new(codec, sender);
466        let handle = Runtime::locate().spawn(driver);
467        let _task_handle = Some(SharedHandleContainer::new(handle));
468        let stream = PubSubStream {
469            receiver,
470            _task_handle,
471        };
472        let con = PubSub { sink, stream };
473        Ok(con)
474    }
475
476    /// Subscribes to a new channel(s).
477    ///
478    /// ```rust,no_run
479    /// # #[cfg(feature = "aio")]
480    /// # #[cfg(feature = "aio")]
481    /// # async fn do_something() -> redis::RedisResult<()> {
482    /// let client = redis::Client::open("redis://127.0.0.1/")?;
483    /// let mut pubsub = client.get_async_pubsub().await?;
484    /// pubsub.subscribe("channel_1").await?;
485    /// pubsub.subscribe(&["channel_2", "channel_3"]).await?;
486    /// # Ok(())
487    /// # }
488    /// ```
489    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
490        self.sink.subscribe(channel_name).await
491    }
492
493    /// Unsubscribes from channel(s).
494    ///
495    /// ```rust,no_run
496    /// # #[cfg(feature = "aio")]
497    /// # #[cfg(feature = "aio")]
498    /// # async fn do_something() -> redis::RedisResult<()> {
499    /// let client = redis::Client::open("redis://127.0.0.1/")?;
500    /// let mut pubsub = client.get_async_pubsub().await?;
501    /// pubsub.subscribe(&["channel_1", "channel_2"]).await?;
502    /// pubsub.unsubscribe(&["channel_1", "channel_2"]).await?;
503    /// # Ok(())
504    /// # }
505    /// ```
506    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
507        self.sink.unsubscribe(channel_name).await
508    }
509
510    /// Subscribes to new channel(s) with pattern(s).
511    ///
512    /// ```rust,no_run
513    /// # #[cfg(feature = "aio")]
514    /// # async fn do_something() -> redis::RedisResult<()> {
515    /// let client = redis::Client::open("redis://127.0.0.1/")?;
516    /// let mut pubsub = client.get_async_pubsub().await?;
517    /// pubsub.psubscribe("channel*_1").await?;
518    /// pubsub.psubscribe(&["channel*_2", "channel*_3"]).await?;
519    /// # Ok(())
520    /// # }
521    /// ```
522    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
523        self.sink.psubscribe(channel_pattern).await
524    }
525
526    /// Unsubscribes from channel pattern(s).
527    ///
528    /// ```rust,no_run
529    /// # #[cfg(feature = "aio")]
530    /// # async fn do_something() -> redis::RedisResult<()> {
531    /// let client = redis::Client::open("redis://127.0.0.1/")?;
532    /// let mut pubsub = client.get_async_pubsub().await?;
533    /// pubsub.psubscribe(&["channel_1", "channel_2"]).await?;
534    /// pubsub.punsubscribe(&["channel_1", "channel_2"]).await?;
535    /// # Ok(())
536    /// # }
537    /// ```
538    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
539        self.sink.punsubscribe(channel_pattern).await
540    }
541
542    /// Sends a ping to the server
543    pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
544        self.sink.ping().await
545    }
546
547    /// Sends a ping with a message to the server
548    pub async fn ping_message<T: FromRedisValue>(
549        &mut self,
550        message: impl ToRedisArgs,
551    ) -> RedisResult<T> {
552        self.sink.ping_message(message).await
553    }
554
555    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions.
556    ///
557    /// The message itself is still generic and can be converted into an appropriate type through
558    /// the helper methods on it.
559    pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
560        &mut self.stream
561    }
562
563    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions consuming it.
564    ///
565    /// The message itself is still generic and can be converted into an appropriate type through
566    /// the helper methods on it.
567    /// This can be useful in cases where the stream needs to be returned or held by something other
568    /// than the [`PubSub`].
569    pub fn into_on_message(self) -> PubSubStream {
570        self.stream
571    }
572
573    /// Splits the PubSub into separate sink and stream components, so that subscriptions could be
574    /// updated through the `Sink` while concurrently waiting for new messages on the `Stream`.
575    pub fn split(self) -> (PubSubSink, PubSubStream) {
576        (self.sink, self.stream)
577    }
578}
579
580impl Stream for PubSubStream {
581    type Item = Msg;
582
583    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
584        self.project().receiver.poll_recv(cx)
585    }
586}