Skip to main content

redis/aio/
pubsub.rs

1use crate::types::{RedisResult, Value};
2use crate::{
3    FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs, aio::Runtime, cmd, errors::RedisError,
4    errors::closed_connection_error, from_redis_value, parser::ValueCodec,
5};
6use ::tokio::{
7    io::{AsyncRead, AsyncWrite},
8    sync::oneshot,
9};
10use futures_util::{
11    future::{Future, FutureExt},
12    ready,
13    sink::Sink,
14    stream::{self, Stream, StreamExt},
15};
16use pin_project_lite::pin_project;
17use std::collections::VecDeque;
18use std::pin::Pin;
19use std::task::{self, Poll};
20use tokio::sync::mpsc::UnboundedSender;
21use tokio::sync::mpsc::unbounded_channel;
22use tokio_util::codec::Decoder;
23
24use super::{SharedHandleContainer, setup_connection};
25
// One-shot channel half that hands the server's reply (or an error) for a single
// pipelined request - SUBSCRIBE / UNSUBSCRIBE / PSUBSCRIBE / PUNSUBSCRIBE or PING -
// back to the caller awaiting it in `PubSubSink::send_recv`.
type RequestResultSender = oneshot::Sender<RedisResult<Value>>;
28
// A single request travelling from a `PubSubSink` handle to the connection's driver task.
struct PipelineMessage {
    // The already-packed command bytes to write to the connection.
    input: Vec<u8>,
    // Where the matching reply is delivered once the server answers.
    output: RequestResultSender,
}
34
/// The sink part of a split async Pubsub.
///
/// The sink is used to subscribe to and unsubscribe from channels and
/// patterns, and to ping the server. Clones share the same underlying
/// request channel.
///
/// The stream part is independent from the sink: dropping the sink doesn't
/// cause the stream part to stop working.
/// The sink isn't independent from the stream - dropping the stream will
/// cause the sink to return errors on requests.
#[derive(Clone)]
pub struct PubSubSink {
    // Requests queued for the driver future created in `PubSubSink::new`.
    sender: UnboundedSender<PipelineMessage>,
}
48
pin_project! {
    /// The stream part of a split async Pubsub.
    ///
    /// The stream yields the [`Msg`]s received on the channels and patterns the
    /// connection is subscribed to; the subscriptions themselves are managed
    /// through the [`PubSubSink`].
    /// The stream part is independent from the sink,
    /// and dropping the sink doesn't cause the stream part to
    /// stop working.
    /// The sink isn't independent from the stream - dropping
    /// the stream will cause the sink to return errors on requests.
    pub struct PubSubStream {
        // Receives the messages forwarded by the driver task's `PipelineSink`.
        #[pin]
        receiver: tokio::sync::mpsc::UnboundedReceiver<Msg>,
        // This handle ensures that once the stream will be dropped, the underlying task will stop.
        _task_handle: Option<SharedHandleContainer>,
    }
}
66
pin_project! {
    // Adapter owned by the driver future: it writes queued requests to the connection,
    // completes waiting requests with their replies in FIFO order (`push_back` on send,
    // `pop_front` on reply), and forwards published messages to the `PubSubStream`.
    struct PipelineSink<T> {
        // The `Sink + Stream` that sends requests and receives values from the server.
        #[pin]
        sink_stream: T,
        // The requests that were sent and are awaiting a response.
        in_flight: VecDeque<RequestResultSender>,
        // A sender for the push messages received from the server.
        sender: UnboundedSender<Msg>,
    }
}
78
79impl<T> PipelineSink<T>
80where
81    T: Stream<Item = RedisResult<Value>> + 'static,
82{
83    fn new(sink_stream: T, sender: UnboundedSender<Msg>) -> Self
84    where
85        T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
86    {
87        PipelineSink {
88            sink_stream,
89            in_flight: VecDeque::new(),
90            sender,
91        }
92    }
93
94    // Read messages from the stream and handle them.
95    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
96        loop {
97            let self_ = self.as_mut().project();
98            if self_.sender.is_closed() {
99                return Poll::Ready(Err(()));
100            }
101
102            let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
103                Some(result) => result,
104                // The redis response stream is not going to produce any more items so we `Err`
105                // to break out of the `forward` combinator and stop handling requests
106                None => return Poll::Ready(Err(())),
107            };
108            self.as_mut().handle_message(item)?;
109        }
110    }
111
112    fn handle_message(self: Pin<&mut Self>, result: RedisResult<Value>) -> Result<(), ()> {
113        let self_ = self.project();
114
115        match result {
116            Ok(Value::Array(value)) => {
117                if let Some(Value::BulkString(kind)) = value.first() {
118                    if matches!(
119                        kind.as_slice(),
120                        b"subscribe" | b"psubscribe" | b"unsubscribe" | b"punsubscribe" | b"pong"
121                    ) {
122                        if let Some(entry) = self_.in_flight.pop_front() {
123                            let _ = entry.send(Ok(Value::Array(value)));
124                        };
125                        return Ok(());
126                    }
127                }
128
129                if let Some(msg) = Msg::from_owned_value(Value::Array(value)) {
130                    let _ = self_.sender.send(msg);
131                    Ok(())
132                } else {
133                    Err(())
134                }
135            }
136
137            Ok(Value::Push { kind, data }) => {
138                if kind.has_reply() {
139                    if let Some(entry) = self_.in_flight.pop_front() {
140                        let _ = entry.send(Ok(Value::Push { kind, data }));
141                    };
142                    return Ok(());
143                }
144
145                if let Some(msg) = Msg::from_push_info(crate::PushInfo { kind, data }) {
146                    let _ = self_.sender.send(msg);
147                    Ok(())
148                } else {
149                    Err(())
150                }
151            }
152
153            Err(err) if err.is_unrecoverable_error() => Err(()),
154
155            _ => {
156                if let Some(entry) = self_.in_flight.pop_front() {
157                    let _ = entry.send(result);
158                    Ok(())
159                } else {
160                    Err(())
161                }
162            }
163        }
164    }
165}
166
167impl<T> Sink<PipelineMessage> for PipelineSink<T>
168where
169    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
170{
171    type Error = ();
172
173    // Retrieve incoming messages and write them to the sink
174    fn poll_ready(
175        mut self: Pin<&mut Self>,
176        cx: &mut task::Context,
177    ) -> Poll<Result<(), Self::Error>> {
178        self.as_mut()
179            .project()
180            .sink_stream
181            .poll_ready(cx)
182            .map_err(|_| ())
183    }
184
185    fn start_send(
186        mut self: Pin<&mut Self>,
187        PipelineMessage { input, output }: PipelineMessage,
188    ) -> Result<(), Self::Error> {
189        let self_ = self.as_mut().project();
190
191        match self_.sink_stream.start_send(input) {
192            Ok(()) => {
193                self_.in_flight.push_back(output);
194                Ok(())
195            }
196            Err(err) => {
197                let _ = output.send(Err(err));
198                Err(())
199            }
200        }
201    }
202
203    fn poll_flush(
204        mut self: Pin<&mut Self>,
205        cx: &mut task::Context,
206    ) -> Poll<Result<(), Self::Error>> {
207        ready!(
208            self.as_mut()
209                .project()
210                .sink_stream
211                .poll_flush(cx)
212                .map_err(|err| {
213                    let _ = self.as_mut().handle_message(Err(err));
214                })
215        )?;
216        self.poll_read(cx)
217    }
218
219    fn poll_close(
220        mut self: Pin<&mut Self>,
221        cx: &mut task::Context,
222    ) -> Poll<Result<(), Self::Error>> {
223        // No new requests will come in after the first call to `close` but we need to complete any
224        // in progress requests before closing
225        if !self.in_flight.is_empty() {
226            ready!(self.as_mut().poll_flush(cx))?;
227        }
228        let this = self.as_mut().project();
229
230        if this.sender.is_closed() {
231            return Poll::Ready(Ok(()));
232        }
233
234        match ready!(this.sink_stream.poll_next(cx)) {
235            Some(result) => {
236                let _ = self.handle_message(result);
237                Poll::Pending
238            }
239            None => Poll::Ready(Ok(())),
240        }
241    }
242}
243
244impl PubSubSink {
245    fn new<T>(
246        sink_stream: T,
247        messages_sender: UnboundedSender<Msg>,
248    ) -> (Self, impl Future<Output = ()>)
249    where
250        T: Sink<Vec<u8>, Error = RedisError>,
251        T: Stream<Item = RedisResult<Value>>,
252        T: Unpin + Send + 'static,
253    {
254        let (sender, mut receiver) = unbounded_channel();
255        let sink = PipelineSink::new(sink_stream, messages_sender);
256        let f = stream::poll_fn(move |cx| {
257            let res = receiver.poll_recv(cx);
258            match res {
259                // We don't want to stop the backing task for the stream, even if the sink was closed.
260                Poll::Ready(None) => Poll::Pending,
261                _ => res,
262            }
263        })
264        .map(Ok)
265        .forward(sink)
266        .map(|_| ());
267        (PubSubSink { sender }, f)
268    }
269
270    async fn send_recv(&mut self, input: Vec<u8>) -> Result<Value, RedisError> {
271        let (sender, receiver) = oneshot::channel();
272
273        self.sender
274            .send(PipelineMessage {
275                input,
276                output: sender,
277            })
278            .map_err(|_| closed_connection_error())?;
279        match receiver.await {
280            Ok(result) => result,
281            Err(_) => Err(closed_connection_error()),
282        }
283    }
284
285    /// Subscribes to a new channel(s).
286    ///
287    /// ```rust,no_run
288    /// # #[cfg(feature = "aio")]
289    /// # async fn do_something() -> redis::RedisResult<()> {
290    /// let client = redis::Client::open("redis://127.0.0.1/")?;
291    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
292    /// sink.subscribe("channel_1").await?;
293    /// sink.subscribe(&["channel_2", "channel_3"]).await?;
294    /// # Ok(())
295    /// # }
296    /// ```
297    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
298        let cmd = cmd("SUBSCRIBE").arg(channel_name).get_packed_command();
299        self.send_recv(cmd).await.map(|_| ())
300    }
301
302    /// Unsubscribes from channel(s).
303    ///
304    /// ```rust,no_run
305    /// # #[cfg(feature = "aio")]
306    /// # async fn do_something() -> redis::RedisResult<()> {
307    /// let client = redis::Client::open("redis://127.0.0.1/")?;
308    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
309    /// sink.subscribe(&["channel_1", "channel_2"]).await?;
310    /// sink.unsubscribe(&["channel_1", "channel_2"]).await?;
311    /// # Ok(())
312    /// # }
313    /// ```
314    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
315        let cmd = cmd("UNSUBSCRIBE").arg(channel_name).get_packed_command();
316        self.send_recv(cmd).await.map(|_| ())
317    }
318
319    /// Subscribes to new channel(s) with pattern(s).
320    ///
321    /// ```rust,no_run
322    /// # #[cfg(feature = "aio")]
323    /// # async fn do_something() -> redis::RedisResult<()> {
324    /// let client = redis::Client::open("redis://127.0.0.1/")?;
325    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
326    /// sink.psubscribe("channel*_1").await?;
327    /// sink.psubscribe(&["channel*_2", "channel*_3"]).await?;
328    /// # Ok(())
329    /// # }
330    /// ```
331    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
332        let cmd = cmd("PSUBSCRIBE").arg(channel_pattern).get_packed_command();
333        self.send_recv(cmd).await.map(|_| ())
334    }
335
336    /// Unsubscribes from channel pattern(s).
337    ///
338    /// ```rust,no_run
339    /// # #[cfg(feature = "aio")]
340    /// # async fn do_something() -> redis::RedisResult<()> {
341    /// let client = redis::Client::open("redis://127.0.0.1/")?;
342    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
343    /// sink.psubscribe(&["channel_1", "channel_2"]).await?;
344    /// sink.punsubscribe(&["channel_1", "channel_2"]).await?;
345    /// # Ok(())
346    /// # }
347    /// ```
348    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
349        let cmd = cmd("PUNSUBSCRIBE")
350            .arg(channel_pattern)
351            .get_packed_command();
352        self.send_recv(cmd).await.map(|_| ())
353    }
354
355    /// Sends a ping with a message to the server
356    pub async fn ping_message<T: FromRedisValue>(
357        &mut self,
358        message: impl ToRedisArgs,
359    ) -> RedisResult<T> {
360        let cmd = cmd("PING").arg(message).get_packed_command();
361        let response = self.send_recv(cmd).await?;
362        Ok(from_redis_value(response)?)
363    }
364
365    /// Sends a ping to the server
366    pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
367        let cmd = cmd("PING").get_packed_command();
368        let response = self.send_recv(cmd).await?;
369        Ok(from_redis_value(response)?)
370    }
371}
372
/// A connection dedicated to RESP2 pubsub messages.
///
/// Requests are made through the methods on this type (or through the [`PubSubSink`]
/// obtained from [`PubSub::split`]), and received messages are read from its
/// [`PubSubStream`].
///
/// If you're using a DB that supports RESP3, consider using a regular connection and setting a [crate::aio::AsyncPushSender] on it using [crate::client::AsyncConnectionConfig::set_push_sender].
pub struct PubSub {
    // Request side: subscriptions, unsubscriptions and pings.
    sink: PubSubSink,
    // Message side; it also holds the handle whose drop stops the driver task.
    stream: PubSubStream,
}
380
381impl PubSub {
382    /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
383    /// and a `ConnectionInfo`
384    pub async fn new<C>(connection_info: &RedisConnectionInfo, stream: C) -> RedisResult<Self>
385    where
386        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
387    {
388        let mut codec = ValueCodec::default().framed(stream);
389        setup_connection(
390            &mut codec,
391            connection_info,
392            #[cfg(feature = "cache-aio")]
393            None,
394        )
395        .await?;
396        let (sender, receiver) = unbounded_channel();
397        let (sink, driver) = PubSubSink::new(codec, sender);
398        let handle = Runtime::locate().spawn(driver);
399        let _task_handle = Some(SharedHandleContainer::new(handle));
400        let stream = PubSubStream {
401            receiver,
402            _task_handle,
403        };
404        let con = PubSub { sink, stream };
405        Ok(con)
406    }
407
408    /// Subscribes to a new channel(s).
409    ///
410    /// ```rust,no_run
411    /// # #[cfg(feature = "aio")]
412    /// # #[cfg(feature = "aio")]
413    /// # async fn do_something() -> redis::RedisResult<()> {
414    /// let client = redis::Client::open("redis://127.0.0.1/")?;
415    /// let mut pubsub = client.get_async_pubsub().await?;
416    /// pubsub.subscribe("channel_1").await?;
417    /// pubsub.subscribe(&["channel_2", "channel_3"]).await?;
418    /// # Ok(())
419    /// # }
420    /// ```
421    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
422        self.sink.subscribe(channel_name).await
423    }
424
425    /// Unsubscribes from channel(s).
426    ///
427    /// ```rust,no_run
428    /// # #[cfg(feature = "aio")]
429    /// # #[cfg(feature = "aio")]
430    /// # async fn do_something() -> redis::RedisResult<()> {
431    /// let client = redis::Client::open("redis://127.0.0.1/")?;
432    /// let mut pubsub = client.get_async_pubsub().await?;
433    /// pubsub.subscribe(&["channel_1", "channel_2"]).await?;
434    /// pubsub.unsubscribe(&["channel_1", "channel_2"]).await?;
435    /// # Ok(())
436    /// # }
437    /// ```
438    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
439        self.sink.unsubscribe(channel_name).await
440    }
441
442    /// Subscribes to new channel(s) with pattern(s).
443    ///
444    /// ```rust,no_run
445    /// # #[cfg(feature = "aio")]
446    /// # async fn do_something() -> redis::RedisResult<()> {
447    /// let client = redis::Client::open("redis://127.0.0.1/")?;
448    /// let mut pubsub = client.get_async_pubsub().await?;
449    /// pubsub.psubscribe("channel*_1").await?;
450    /// pubsub.psubscribe(&["channel*_2", "channel*_3"]).await?;
451    /// # Ok(())
452    /// # }
453    /// ```
454    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
455        self.sink.psubscribe(channel_pattern).await
456    }
457
458    /// Unsubscribes from channel pattern(s).
459    ///
460    /// ```rust,no_run
461    /// # #[cfg(feature = "aio")]
462    /// # async fn do_something() -> redis::RedisResult<()> {
463    /// let client = redis::Client::open("redis://127.0.0.1/")?;
464    /// let mut pubsub = client.get_async_pubsub().await?;
465    /// pubsub.psubscribe(&["channel_1", "channel_2"]).await?;
466    /// pubsub.punsubscribe(&["channel_1", "channel_2"]).await?;
467    /// # Ok(())
468    /// # }
469    /// ```
470    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
471        self.sink.punsubscribe(channel_pattern).await
472    }
473
474    /// Sends a ping to the server
475    pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
476        self.sink.ping().await
477    }
478
479    /// Sends a ping with a message to the server
480    pub async fn ping_message<T: FromRedisValue>(
481        &mut self,
482        message: impl ToRedisArgs,
483    ) -> RedisResult<T> {
484        self.sink.ping_message(message).await
485    }
486
487    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions.
488    ///
489    /// The message itself is still generic and can be converted into an appropriate type through
490    /// the helper methods on it.
491    pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
492        &mut self.stream
493    }
494
495    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions consuming it.
496    ///
497    /// The message itself is still generic and can be converted into an appropriate type through
498    /// the helper methods on it.
499    /// This can be useful in cases where the stream needs to be returned or held by something other
500    /// than the [`PubSub`].
501    pub fn into_on_message(self) -> PubSubStream {
502        self.stream
503    }
504
505    /// Splits the PubSub into separate sink and stream components, so that subscriptions could be
506    /// updated through the `Sink` while concurrently waiting for new messages on the `Stream`.
507    pub fn split(self) -> (PubSubSink, PubSubStream) {
508        (self.sink, self.stream)
509    }
510}
511
512impl Stream for PubSubStream {
513    type Item = Msg;
514
515    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
516        self.project().receiver.poll_recv(cx)
517    }
518}