redis/aio/
pubsub.rs

1use crate::aio::Runtime;
2use crate::parser::ValueCodec;
3use crate::types::{closed_connection_error, RedisError, RedisResult, Value};
4use crate::{cmd, from_owned_redis_value, FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs};
5use ::tokio::{
6    io::{AsyncRead, AsyncWrite},
7    sync::oneshot,
8};
9use futures_util::{
10    future::{Future, FutureExt},
11    ready,
12    sink::Sink,
13    stream::{self, Stream, StreamExt},
14};
15use pin_project_lite::pin_project;
16use std::collections::VecDeque;
17use std::pin::Pin;
18use std::task::{self, Poll};
19use tokio::sync::mpsc::unbounded_channel;
20use tokio::sync::mpsc::UnboundedSender;
21use tokio_util::codec::Decoder;
22
23use super::{setup_connection, SharedHandleContainer};
24
25// A signal that a un/subscribe request has completed.
26type RequestResultSender = oneshot::Sender<RedisResult<Value>>;
27
28// A single message sent through the pipeline
29struct PipelineMessage {
30    input: Vec<u8>,
31    output: RequestResultSender,
32}
33
34/// The sink part of a split async Pubsub.
35///
36/// The sink is used to subscribe and unsubscribe from
37/// channels.
38/// The stream part is independent from the sink,
39/// and dropping the sink doesn't cause the stream part to
40/// stop working.
41/// The sink isn't independent from the stream - dropping
42/// the stream will cause the sink to return errors on requests.
43#[derive(Clone)]
44pub struct PubSubSink {
45    sender: UnboundedSender<PipelineMessage>,
46}
47
48pin_project! {
49    /// The stream part of a split async Pubsub.
50    ///
51    /// The sink is used to subscribe and unsubscribe from
52    /// channels.
53    /// The stream part is independent from the sink,
54    /// and dropping the sink doesn't cause the stream part to
55    /// stop working.
56    /// The sink isn't independent from the stream - dropping
57    /// the stream will cause the sink to return errors on requests.
58    pub struct PubSubStream {
59        #[pin]
60        receiver: tokio::sync::mpsc::UnboundedReceiver<Msg>,
61        // This handle ensures that once the stream will be dropped, the underlying task will stop.
62        _task_handle: Option<SharedHandleContainer>,
63    }
64}
65
66pin_project! {
67    struct PipelineSink<T> {
68        // The `Sink + Stream` that sends requests and receives values from the server.
69        #[pin]
70        sink_stream: T,
71        // The requests that were sent and are awaiting a response.
72        in_flight: VecDeque<RequestResultSender>,
73        // A sender for the push messages received from the server.
74        sender: UnboundedSender<Msg>,
75    }
76}
77
78impl<T> PipelineSink<T>
79where
80    T: Stream<Item = RedisResult<Value>> + 'static,
81{
82    fn new(sink_stream: T, sender: UnboundedSender<Msg>) -> Self
83    where
84        T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
85    {
86        PipelineSink {
87            sink_stream,
88            in_flight: VecDeque::new(),
89            sender,
90        }
91    }
92
93    // Read messages from the stream and handle them.
94    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
95        loop {
96            let self_ = self.as_mut().project();
97            if self_.sender.is_closed() {
98                return Poll::Ready(Err(()));
99            }
100
101            let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
102                Some(result) => result,
103                // The redis response stream is not going to produce any more items so we `Err`
104                // to break out of the `forward` combinator and stop handling requests
105                None => return Poll::Ready(Err(())),
106            };
107            self.as_mut().handle_message(item)?;
108        }
109    }
110
111    fn handle_message(self: Pin<&mut Self>, result: RedisResult<Value>) -> Result<(), ()> {
112        let self_ = self.project();
113
114        match result {
115            Ok(Value::Array(value)) => {
116                if let Some(Value::BulkString(kind)) = value.first() {
117                    if matches!(
118                        kind.as_slice(),
119                        b"subscribe" | b"psubscribe" | b"unsubscribe" | b"punsubscribe" | b"pong"
120                    ) {
121                        if let Some(entry) = self_.in_flight.pop_front() {
122                            let _ = entry.send(Ok(Value::Array(value)));
123                        };
124                        return Ok(());
125                    }
126                }
127
128                if let Some(msg) = Msg::from_owned_value(Value::Array(value)) {
129                    let _ = self_.sender.send(msg);
130                    Ok(())
131                } else {
132                    Err(())
133                }
134            }
135
136            Ok(Value::Push { kind, data }) => {
137                if kind.has_reply() {
138                    if let Some(entry) = self_.in_flight.pop_front() {
139                        let _ = entry.send(Ok(Value::Push { kind, data }));
140                    };
141                    return Ok(());
142                }
143
144                if let Some(msg) = Msg::from_push_info(crate::PushInfo { kind, data }) {
145                    let _ = self_.sender.send(msg);
146                    Ok(())
147                } else {
148                    Err(())
149                }
150            }
151
152            Err(err) if err.is_unrecoverable_error() => Err(()),
153
154            _ => {
155                if let Some(entry) = self_.in_flight.pop_front() {
156                    let _ = entry.send(result);
157                    Ok(())
158                } else {
159                    Err(())
160                }
161            }
162        }
163    }
164}
165
166impl<T> Sink<PipelineMessage> for PipelineSink<T>
167where
168    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
169{
170    type Error = ();
171
172    // Retrieve incoming messages and write them to the sink
173    fn poll_ready(
174        mut self: Pin<&mut Self>,
175        cx: &mut task::Context,
176    ) -> Poll<Result<(), Self::Error>> {
177        self.as_mut()
178            .project()
179            .sink_stream
180            .poll_ready(cx)
181            .map_err(|_| ())
182    }
183
184    fn start_send(
185        mut self: Pin<&mut Self>,
186        PipelineMessage { input, output }: PipelineMessage,
187    ) -> Result<(), Self::Error> {
188        let self_ = self.as_mut().project();
189
190        match self_.sink_stream.start_send(input) {
191            Ok(()) => {
192                self_.in_flight.push_back(output);
193                Ok(())
194            }
195            Err(err) => {
196                let _ = output.send(Err(err));
197                Err(())
198            }
199        }
200    }
201
202    fn poll_flush(
203        mut self: Pin<&mut Self>,
204        cx: &mut task::Context,
205    ) -> Poll<Result<(), Self::Error>> {
206        ready!(self
207            .as_mut()
208            .project()
209            .sink_stream
210            .poll_flush(cx)
211            .map_err(|err| {
212                let _ = self.as_mut().handle_message(Err(err));
213            }))?;
214        self.poll_read(cx)
215    }
216
217    fn poll_close(
218        mut self: Pin<&mut Self>,
219        cx: &mut task::Context,
220    ) -> Poll<Result<(), Self::Error>> {
221        // No new requests will come in after the first call to `close` but we need to complete any
222        // in progress requests before closing
223        if !self.in_flight.is_empty() {
224            ready!(self.as_mut().poll_flush(cx))?;
225        }
226        let this = self.as_mut().project();
227
228        if this.sender.is_closed() {
229            return Poll::Ready(Ok(()));
230        }
231
232        match ready!(this.sink_stream.poll_next(cx)) {
233            Some(result) => {
234                let _ = self.handle_message(result);
235                Poll::Pending
236            }
237            None => Poll::Ready(Ok(())),
238        }
239    }
240}
241
242impl PubSubSink {
243    fn new<T>(
244        sink_stream: T,
245        messages_sender: UnboundedSender<Msg>,
246    ) -> (Self, impl Future<Output = ()>)
247    where
248        T: Sink<Vec<u8>, Error = RedisError>,
249        T: Stream<Item = RedisResult<Value>>,
250        T: Unpin + Send + 'static,
251    {
252        let (sender, mut receiver) = unbounded_channel();
253        let sink = PipelineSink::new(sink_stream, messages_sender);
254        let f = stream::poll_fn(move |cx| {
255            let res = receiver.poll_recv(cx);
256            match res {
257                // We don't want to stop the backing task for the stream, even if the sink was closed.
258                Poll::Ready(None) => Poll::Pending,
259                _ => res,
260            }
261        })
262        .map(Ok)
263        .forward(sink)
264        .map(|_| ());
265        (PubSubSink { sender }, f)
266    }
267
268    async fn send_recv(&mut self, input: Vec<u8>) -> Result<Value, RedisError> {
269        let (sender, receiver) = oneshot::channel();
270
271        self.sender
272            .send(PipelineMessage {
273                input,
274                output: sender,
275            })
276            .map_err(|_| closed_connection_error())?;
277        match receiver.await {
278            Ok(result) => result,
279            Err(_) => Err(closed_connection_error()),
280        }
281    }
282
283    /// Subscribes to a new channel(s).
284    ///
285    /// ```rust,no_run
286    /// # #[cfg(feature = "aio")]
287    /// # async fn do_something() -> redis::RedisResult<()> {
288    /// let client = redis::Client::open("redis://127.0.0.1/")?;
289    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
290    /// sink.subscribe("channel_1").await?;
291    /// sink.subscribe(&["channel_2", "channel_3"]).await?;
292    /// # Ok(())
293    /// # }
294    /// ```
295    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
296        let cmd = cmd("SUBSCRIBE").arg(channel_name).get_packed_command();
297        self.send_recv(cmd).await.map(|_| ())
298    }
299
300    /// Unsubscribes from channel(s).
301    ///
302    /// ```rust,no_run
303    /// # #[cfg(feature = "aio")]
304    /// # async fn do_something() -> redis::RedisResult<()> {
305    /// let client = redis::Client::open("redis://127.0.0.1/")?;
306    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
307    /// sink.subscribe(&["channel_1", "channel_2"]).await?;
308    /// sink.unsubscribe(&["channel_1", "channel_2"]).await?;
309    /// # Ok(())
310    /// # }
311    /// ```
312    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
313        let cmd = cmd("UNSUBSCRIBE").arg(channel_name).get_packed_command();
314        self.send_recv(cmd).await.map(|_| ())
315    }
316
317    /// Subscribes to new channel(s) with pattern(s).
318    ///
319    /// ```rust,no_run
320    /// # #[cfg(feature = "aio")]
321    /// # async fn do_something() -> redis::RedisResult<()> {
322    /// let client = redis::Client::open("redis://127.0.0.1/")?;
323    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
324    /// sink.psubscribe("channel*_1").await?;
325    /// sink.psubscribe(&["channel*_2", "channel*_3"]).await?;
326    /// # Ok(())
327    /// # }
328    /// ```
329    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
330        let cmd = cmd("PSUBSCRIBE").arg(channel_pattern).get_packed_command();
331        self.send_recv(cmd).await.map(|_| ())
332    }
333
334    /// Unsubscribes from channel pattern(s).
335    ///
336    /// ```rust,no_run
337    /// # #[cfg(feature = "aio")]
338    /// # async fn do_something() -> redis::RedisResult<()> {
339    /// let client = redis::Client::open("redis://127.0.0.1/")?;
340    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
341    /// sink.psubscribe(&["channel_1", "channel_2"]).await?;
342    /// sink.punsubscribe(&["channel_1", "channel_2"]).await?;
343    /// # Ok(())
344    /// # }
345    /// ```
346    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
347        let cmd = cmd("PUNSUBSCRIBE")
348            .arg(channel_pattern)
349            .get_packed_command();
350        self.send_recv(cmd).await.map(|_| ())
351    }
352
353    /// Sends a ping with a message to the server
354    pub async fn ping_message<T: FromRedisValue>(
355        &mut self,
356        message: impl ToRedisArgs,
357    ) -> RedisResult<T> {
358        let cmd = cmd("PING").arg(message).get_packed_command();
359        let response = self.send_recv(cmd).await?;
360        from_owned_redis_value(response)
361    }
362
363    /// Sends a ping to the server
364    pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
365        let cmd = cmd("PING").get_packed_command();
366        let response = self.send_recv(cmd).await?;
367        from_owned_redis_value(response)
368    }
369}
370
371/// A connection dedicated to RESP2 pubsub messages.
372///
373/// If you're using a DB that supports RESP3, consider using a regular connection and setting a [crate::aio::AsyncPushSender] on it using [crate::client::AsyncConnectionConfig::set_push_sender].
374pub struct PubSub {
375    sink: PubSubSink,
376    stream: PubSubStream,
377}
378
379impl PubSub {
380    /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
381    /// and a `ConnectionInfo`
382    pub async fn new<C>(connection_info: &RedisConnectionInfo, stream: C) -> RedisResult<Self>
383    where
384        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
385    {
386        let mut codec = ValueCodec::default().framed(stream);
387        setup_connection(
388            &mut codec,
389            connection_info,
390            #[cfg(feature = "cache-aio")]
391            None,
392        )
393        .await?;
394        let (sender, receiver) = unbounded_channel();
395        let (sink, driver) = PubSubSink::new(codec, sender);
396        let handle = Runtime::locate().spawn(driver);
397        let _task_handle = Some(SharedHandleContainer::new(handle));
398        let stream = PubSubStream {
399            receiver,
400            _task_handle,
401        };
402        let con = PubSub { sink, stream };
403        Ok(con)
404    }
405
406    /// Subscribes to a new channel(s).
407    ///
408    /// ```rust,no_run
409    /// # #[cfg(feature = "aio")]
410    /// # #[cfg(feature = "aio")]
411    /// # async fn do_something() -> redis::RedisResult<()> {
412    /// let client = redis::Client::open("redis://127.0.0.1/")?;
413    /// let mut pubsub = client.get_async_pubsub().await?;
414    /// pubsub.subscribe("channel_1").await?;
415    /// pubsub.subscribe(&["channel_2", "channel_3"]).await?;
416    /// # Ok(())
417    /// # }
418    /// ```
419    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
420        self.sink.subscribe(channel_name).await
421    }
422
423    /// Unsubscribes from channel(s).
424    ///
425    /// ```rust,no_run
426    /// # #[cfg(feature = "aio")]
427    /// # #[cfg(feature = "aio")]
428    /// # async fn do_something() -> redis::RedisResult<()> {
429    /// let client = redis::Client::open("redis://127.0.0.1/")?;
430    /// let mut pubsub = client.get_async_pubsub().await?;
431    /// pubsub.subscribe(&["channel_1", "channel_2"]).await?;
432    /// pubsub.unsubscribe(&["channel_1", "channel_2"]).await?;
433    /// # Ok(())
434    /// # }
435    /// ```
436    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
437        self.sink.unsubscribe(channel_name).await
438    }
439
440    /// Subscribes to new channel(s) with pattern(s).
441    ///
442    /// ```rust,no_run
443    /// # #[cfg(feature = "aio")]
444    /// # async fn do_something() -> redis::RedisResult<()> {
445    /// let client = redis::Client::open("redis://127.0.0.1/")?;
446    /// let mut pubsub = client.get_async_pubsub().await?;
447    /// pubsub.psubscribe("channel*_1").await?;
448    /// pubsub.psubscribe(&["channel*_2", "channel*_3"]).await?;
449    /// # Ok(())
450    /// # }
451    /// ```
452    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
453        self.sink.psubscribe(channel_pattern).await
454    }
455
456    /// Unsubscribes from channel pattern(s).
457    ///
458    /// ```rust,no_run
459    /// # #[cfg(feature = "aio")]
460    /// # async fn do_something() -> redis::RedisResult<()> {
461    /// let client = redis::Client::open("redis://127.0.0.1/")?;
462    /// let mut pubsub = client.get_async_pubsub().await?;
463    /// pubsub.psubscribe(&["channel_1", "channel_2"]).await?;
464    /// pubsub.punsubscribe(&["channel_1", "channel_2"]).await?;
465    /// # Ok(())
466    /// # }
467    /// ```
468    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
469        self.sink.punsubscribe(channel_pattern).await
470    }
471
472    /// Sends a ping to the server
473    pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
474        self.sink.ping().await
475    }
476
477    /// Sends a ping with a message to the server
478    pub async fn ping_message<T: FromRedisValue>(
479        &mut self,
480        message: impl ToRedisArgs,
481    ) -> RedisResult<T> {
482        self.sink.ping_message(message).await
483    }
484
485    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions.
486    ///
487    /// The message itself is still generic and can be converted into an appropriate type through
488    /// the helper methods on it.
489    pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
490        &mut self.stream
491    }
492
493    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions consuming it.
494    ///
495    /// The message itself is still generic and can be converted into an appropriate type through
496    /// the helper methods on it.
497    /// This can be useful in cases where the stream needs to be returned or held by something other
498    /// than the [`PubSub`].
499    pub fn into_on_message(self) -> PubSubStream {
500        self.stream
501    }
502
503    /// Splits the PubSub into separate sink and stream components, so that subscriptions could be
504    /// updated through the `Sink` while concurrently waiting for new messages on the `Stream`.
505    pub fn split(self) -> (PubSubSink, PubSubStream) {
506        (self.sink, self.stream)
507    }
508}
509
510impl Stream for PubSubStream {
511    type Item = Msg;
512
513    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
514        self.project().receiver.poll_recv(cx)
515    }
516}