redis/aio/pubsub.rs
1use crate::types::{RedisResult, Value};
2use crate::{
3 FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs, aio::Runtime, cmd, errors::RedisError,
4 errors::closed_connection_error, from_redis_value, parser::ValueCodec,
5};
6use ::tokio::{
7 io::{AsyncRead, AsyncWrite},
8 sync::oneshot,
9};
10use futures_util::{
11 future::{Future, FutureExt},
12 ready,
13 sink::Sink,
14 stream::{self, Stream, StreamExt},
15};
16use pin_project_lite::pin_project;
17use std::collections::VecDeque;
18use std::pin::Pin;
19use std::task::{self, Poll};
20use tokio::sync::mpsc::UnboundedSender;
21use tokio::sync::mpsc::unbounded_channel;
22use tokio_util::codec::Decoder;
23
24use super::{SharedHandleContainer, setup_connection};
25
// One-shot channel end used to hand the outcome of a single request back to the caller
// awaiting it: either the server's reply or the error the request failed with.
// It is used for every request sent through the sink (un/subscribe as well as ping).
type RequestResultSender = oneshot::Sender<RedisResult<Value>>;
28
// A single request sent through the pipeline, from a `PubSubSink` to the `PipelineSink`.
struct PipelineMessage {
    // The bytes handed to the connection sink (callers in this file build them with
    // `get_packed_command`).
    input: Vec<u8>,
    // Where the reply to this request, or the error it failed with, is delivered.
    output: RequestResultSender,
}
34
/// The sink part of a split async Pubsub.
///
/// The sink is used to subscribe to and unsubscribe from
/// channels.
/// The stream part is independent from the sink,
/// and dropping the sink doesn't cause the stream part to
/// stop working.
/// The sink isn't independent from the stream - dropping
/// the stream will cause the sink to return errors on requests.
#[derive(Clone)]
pub struct PubSubSink {
    // Queue of outgoing requests; each `PipelineMessage` carries its own reply channel
    // (see `send_recv`).
    sender: UnboundedSender<PipelineMessage>,
}
48
pin_project! {
    /// The stream part of a split async Pubsub.
    ///
    /// The stream yields the [`Msg`]s received from the server, while the
    /// sink is used to subscribe to and unsubscribe from channels.
    /// The stream part is independent from the sink,
    /// and dropping the sink doesn't cause the stream part to
    /// stop working.
    /// The sink isn't independent from the stream - dropping
    /// the stream will cause the sink to return errors on requests.
    pub struct PubSubStream {
        // Receiving end of the channel on which the `PipelineSink` publishes messages.
        #[pin]
        receiver: tokio::sync::mpsc::UnboundedReceiver<Msg>,
        // This handle ensures that once the stream is dropped, the underlying task will stop.
        _task_handle: Option<SharedHandleContainer>,
    }
}
66
pin_project! {
    // Sits between the request queue and the connection: requests are written to
    // `sink_stream`, and every value read back is routed either to the request waiting
    // for it or to the channel of pubsub messages.
    struct PipelineSink<T> {
        // The `Sink + Stream` that sends requests and receives values from the server.
        #[pin]
        sink_stream: T,
        // The requests that were sent and are awaiting a response, oldest first
        // (pushed at the back on send, popped from the front on reply).
        in_flight: VecDeque<RequestResultSender>,
        // A sender for the push messages received from the server.
        sender: UnboundedSender<Msg>,
    }
}
78
79impl<T> PipelineSink<T>
80where
81 T: Stream<Item = RedisResult<Value>> + 'static,
82{
83 fn new(sink_stream: T, sender: UnboundedSender<Msg>) -> Self
84 where
85 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
86 {
87 PipelineSink {
88 sink_stream,
89 in_flight: VecDeque::new(),
90 sender,
91 }
92 }
93
94 // Read messages from the stream and handle them.
95 fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
96 loop {
97 let self_ = self.as_mut().project();
98 if self_.sender.is_closed() {
99 return Poll::Ready(Err(()));
100 }
101
102 let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
103 Some(result) => result,
104 // The redis response stream is not going to produce any more items so we `Err`
105 // to break out of the `forward` combinator and stop handling requests
106 None => return Poll::Ready(Err(())),
107 };
108 self.as_mut().handle_message(item)?;
109 }
110 }
111
112 fn handle_message(self: Pin<&mut Self>, result: RedisResult<Value>) -> Result<(), ()> {
113 let self_ = self.project();
114
115 match result {
116 Ok(Value::Array(value)) => {
117 if let Some(Value::BulkString(kind)) = value.first() {
118 if matches!(
119 kind.as_slice(),
120 b"subscribe" | b"psubscribe" | b"unsubscribe" | b"punsubscribe" | b"pong"
121 ) {
122 if let Some(entry) = self_.in_flight.pop_front() {
123 let _ = entry.send(Ok(Value::Array(value)));
124 };
125 return Ok(());
126 }
127 }
128
129 if let Some(msg) = Msg::from_owned_value(Value::Array(value)) {
130 let _ = self_.sender.send(msg);
131 Ok(())
132 } else {
133 Err(())
134 }
135 }
136
137 Ok(Value::Push { kind, data }) => {
138 if kind.has_reply() {
139 if let Some(entry) = self_.in_flight.pop_front() {
140 let _ = entry.send(Ok(Value::Push { kind, data }));
141 };
142 return Ok(());
143 }
144
145 if let Some(msg) = Msg::from_push_info(crate::PushInfo { kind, data }) {
146 let _ = self_.sender.send(msg);
147 Ok(())
148 } else {
149 Err(())
150 }
151 }
152
153 Err(err) if err.is_unrecoverable_error() => Err(()),
154
155 _ => {
156 if let Some(entry) = self_.in_flight.pop_front() {
157 let _ = entry.send(result);
158 Ok(())
159 } else {
160 Err(())
161 }
162 }
163 }
164 }
165}
166
167impl<T> Sink<PipelineMessage> for PipelineSink<T>
168where
169 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
170{
171 type Error = ();
172
173 // Retrieve incoming messages and write them to the sink
174 fn poll_ready(
175 mut self: Pin<&mut Self>,
176 cx: &mut task::Context,
177 ) -> Poll<Result<(), Self::Error>> {
178 self.as_mut()
179 .project()
180 .sink_stream
181 .poll_ready(cx)
182 .map_err(|_| ())
183 }
184
185 fn start_send(
186 mut self: Pin<&mut Self>,
187 PipelineMessage { input, output }: PipelineMessage,
188 ) -> Result<(), Self::Error> {
189 let self_ = self.as_mut().project();
190
191 match self_.sink_stream.start_send(input) {
192 Ok(()) => {
193 self_.in_flight.push_back(output);
194 Ok(())
195 }
196 Err(err) => {
197 let _ = output.send(Err(err));
198 Err(())
199 }
200 }
201 }
202
203 fn poll_flush(
204 mut self: Pin<&mut Self>,
205 cx: &mut task::Context,
206 ) -> Poll<Result<(), Self::Error>> {
207 ready!(
208 self.as_mut()
209 .project()
210 .sink_stream
211 .poll_flush(cx)
212 .map_err(|err| {
213 let _ = self.as_mut().handle_message(Err(err));
214 })
215 )?;
216 self.poll_read(cx)
217 }
218
219 fn poll_close(
220 mut self: Pin<&mut Self>,
221 cx: &mut task::Context,
222 ) -> Poll<Result<(), Self::Error>> {
223 // No new requests will come in after the first call to `close` but we need to complete any
224 // in progress requests before closing
225 if !self.in_flight.is_empty() {
226 ready!(self.as_mut().poll_flush(cx))?;
227 }
228 let this = self.as_mut().project();
229
230 if this.sender.is_closed() {
231 return Poll::Ready(Ok(()));
232 }
233
234 match ready!(this.sink_stream.poll_next(cx)) {
235 Some(result) => {
236 let _ = self.handle_message(result);
237 Poll::Pending
238 }
239 None => Poll::Ready(Ok(())),
240 }
241 }
242}
243
244impl PubSubSink {
245 fn new<T>(
246 sink_stream: T,
247 messages_sender: UnboundedSender<Msg>,
248 ) -> (Self, impl Future<Output = ()>)
249 where
250 T: Sink<Vec<u8>, Error = RedisError>,
251 T: Stream<Item = RedisResult<Value>>,
252 T: Unpin + Send + 'static,
253 {
254 let (sender, mut receiver) = unbounded_channel();
255 let sink = PipelineSink::new(sink_stream, messages_sender);
256 let f = stream::poll_fn(move |cx| {
257 let res = receiver.poll_recv(cx);
258 match res {
259 // We don't want to stop the backing task for the stream, even if the sink was closed.
260 Poll::Ready(None) => Poll::Pending,
261 _ => res,
262 }
263 })
264 .map(Ok)
265 .forward(sink)
266 .map(|_| ());
267 (PubSubSink { sender }, f)
268 }
269
270 async fn send_recv(&mut self, input: Vec<u8>) -> Result<Value, RedisError> {
271 let (sender, receiver) = oneshot::channel();
272
273 self.sender
274 .send(PipelineMessage {
275 input,
276 output: sender,
277 })
278 .map_err(|_| closed_connection_error())?;
279 match receiver.await {
280 Ok(result) => result,
281 Err(_) => Err(closed_connection_error()),
282 }
283 }
284
285 /// Subscribes to a new channel(s).
286 ///
287 /// ```rust,no_run
288 /// # #[cfg(feature = "aio")]
289 /// # async fn do_something() -> redis::RedisResult<()> {
290 /// let client = redis::Client::open("redis://127.0.0.1/")?;
291 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
292 /// sink.subscribe("channel_1").await?;
293 /// sink.subscribe(&["channel_2", "channel_3"]).await?;
294 /// # Ok(())
295 /// # }
296 /// ```
297 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
298 let cmd = cmd("SUBSCRIBE").arg(channel_name).get_packed_command();
299 self.send_recv(cmd).await.map(|_| ())
300 }
301
302 /// Unsubscribes from channel(s).
303 ///
304 /// ```rust,no_run
305 /// # #[cfg(feature = "aio")]
306 /// # async fn do_something() -> redis::RedisResult<()> {
307 /// let client = redis::Client::open("redis://127.0.0.1/")?;
308 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
309 /// sink.subscribe(&["channel_1", "channel_2"]).await?;
310 /// sink.unsubscribe(&["channel_1", "channel_2"]).await?;
311 /// # Ok(())
312 /// # }
313 /// ```
314 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
315 let cmd = cmd("UNSUBSCRIBE").arg(channel_name).get_packed_command();
316 self.send_recv(cmd).await.map(|_| ())
317 }
318
319 /// Subscribes to new channel(s) with pattern(s).
320 ///
321 /// ```rust,no_run
322 /// # #[cfg(feature = "aio")]
323 /// # async fn do_something() -> redis::RedisResult<()> {
324 /// let client = redis::Client::open("redis://127.0.0.1/")?;
325 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
326 /// sink.psubscribe("channel*_1").await?;
327 /// sink.psubscribe(&["channel*_2", "channel*_3"]).await?;
328 /// # Ok(())
329 /// # }
330 /// ```
331 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
332 let cmd = cmd("PSUBSCRIBE").arg(channel_pattern).get_packed_command();
333 self.send_recv(cmd).await.map(|_| ())
334 }
335
336 /// Unsubscribes from channel pattern(s).
337 ///
338 /// ```rust,no_run
339 /// # #[cfg(feature = "aio")]
340 /// # async fn do_something() -> redis::RedisResult<()> {
341 /// let client = redis::Client::open("redis://127.0.0.1/")?;
342 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
343 /// sink.psubscribe(&["channel_1", "channel_2"]).await?;
344 /// sink.punsubscribe(&["channel_1", "channel_2"]).await?;
345 /// # Ok(())
346 /// # }
347 /// ```
348 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
349 let cmd = cmd("PUNSUBSCRIBE")
350 .arg(channel_pattern)
351 .get_packed_command();
352 self.send_recv(cmd).await.map(|_| ())
353 }
354
355 /// Sends a ping with a message to the server
356 pub async fn ping_message<T: FromRedisValue>(
357 &mut self,
358 message: impl ToRedisArgs,
359 ) -> RedisResult<T> {
360 let cmd = cmd("PING").arg(message).get_packed_command();
361 let response = self.send_recv(cmd).await?;
362 Ok(from_redis_value(response)?)
363 }
364
365 /// Sends a ping to the server
366 pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
367 let cmd = cmd("PING").get_packed_command();
368 let response = self.send_recv(cmd).await?;
369 Ok(from_redis_value(response)?)
370 }
371}
372
/// A connection dedicated to RESP2 pubsub messages.
///
/// If you're using a DB that supports RESP3, consider using a regular connection and setting a
/// [crate::aio::AsyncPushSender] on it using [crate::client::AsyncConnectionConfig::set_push_sender].
pub struct PubSub {
    // Used to issue subscribe / unsubscribe / ping requests.
    sink: PubSubSink,
    // Yields the messages received for the current subscriptions.
    stream: PubSubStream,
}
380
impl PubSub {
    /// Constructs a new `PubSub` out of an `AsyncRead + AsyncWrite` object
    /// and a `RedisConnectionInfo`.
    ///
    /// The connection is first set up using `connection_info`; afterwards the task
    /// driving it is spawned on the located runtime.
    pub async fn new<C>(connection_info: &RedisConnectionInfo, stream: C) -> RedisResult<Self>
    where
        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
    {
        let mut codec = ValueCodec::default().framed(stream);
        setup_connection(
            &mut codec,
            connection_info,
            #[cfg(feature = "cache-aio")]
            None,
        )
        .await?;
        // Channel carrying the received messages from the driver task to the stream.
        let (sender, receiver) = unbounded_channel();
        let (sink, driver) = PubSubSink::new(codec, sender);
        let handle = Runtime::locate().spawn(driver);
        // The handle is stored in the stream so that the task is tied to the stream's lifetime.
        let _task_handle = Some(SharedHandleContainer::new(handle));
        let stream = PubSubStream {
            receiver,
            _task_handle,
        };
        let con = PubSub { sink, stream };
        Ok(con)
    }

    /// Subscribes to a new channel(s).
    ///
    /// ```rust,no_run
    /// # #[cfg(feature = "aio")]
    /// # async fn do_something() -> redis::RedisResult<()> {
    /// let client = redis::Client::open("redis://127.0.0.1/")?;
    /// let mut pubsub = client.get_async_pubsub().await?;
    /// pubsub.subscribe("channel_1").await?;
    /// pubsub.subscribe(&["channel_2", "channel_3"]).await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
        self.sink.subscribe(channel_name).await
    }

    /// Unsubscribes from channel(s).
    ///
    /// ```rust,no_run
    /// # #[cfg(feature = "aio")]
    /// # async fn do_something() -> redis::RedisResult<()> {
    /// let client = redis::Client::open("redis://127.0.0.1/")?;
    /// let mut pubsub = client.get_async_pubsub().await?;
    /// pubsub.subscribe(&["channel_1", "channel_2"]).await?;
    /// pubsub.unsubscribe(&["channel_1", "channel_2"]).await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
        self.sink.unsubscribe(channel_name).await
    }

    /// Subscribes to new channel(s) with pattern(s).
    ///
    /// ```rust,no_run
    /// # #[cfg(feature = "aio")]
    /// # async fn do_something() -> redis::RedisResult<()> {
    /// let client = redis::Client::open("redis://127.0.0.1/")?;
    /// let mut pubsub = client.get_async_pubsub().await?;
    /// pubsub.psubscribe("channel*_1").await?;
    /// pubsub.psubscribe(&["channel*_2", "channel*_3"]).await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
        self.sink.psubscribe(channel_pattern).await
    }

    /// Unsubscribes from channel pattern(s).
    ///
    /// ```rust,no_run
    /// # #[cfg(feature = "aio")]
    /// # async fn do_something() -> redis::RedisResult<()> {
    /// let client = redis::Client::open("redis://127.0.0.1/")?;
    /// let mut pubsub = client.get_async_pubsub().await?;
    /// pubsub.psubscribe(&["channel_1", "channel_2"]).await?;
    /// pubsub.punsubscribe(&["channel_1", "channel_2"]).await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
        self.sink.punsubscribe(channel_pattern).await
    }

    /// Sends a ping to the server
    pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
        self.sink.ping().await
    }

    /// Sends a ping with a message to the server
    pub async fn ping_message<T: FromRedisValue>(
        &mut self,
        message: impl ToRedisArgs,
    ) -> RedisResult<T> {
        self.sink.ping_message(message).await
    }

    /// Returns a [`Stream`] of [`Msg`]s from this [`PubSub`]'s subscriptions.
    ///
    /// The message itself is still generic and can be converted into an appropriate type through
    /// the helper methods on it.
    pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
        &mut self.stream
    }

    /// Returns a [`Stream`] of [`Msg`]s from this [`PubSub`]'s subscriptions, consuming it.
    ///
    /// The message itself is still generic and can be converted into an appropriate type through
    /// the helper methods on it.
    /// This can be useful in cases where the stream needs to be returned or held by something other
    /// than the [`PubSub`].
    pub fn into_on_message(self) -> PubSubStream {
        self.stream
    }

    /// Splits the PubSub into separate sink and stream components, so that subscriptions could be
    /// updated through the `Sink` while concurrently waiting for new messages on the `Stream`.
    pub fn split(self) -> (PubSubSink, PubSubStream) {
        (self.sink, self.stream)
    }
}
511
512impl Stream for PubSubStream {
513 type Item = Msg;
514
515 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
516 self.project().receiver.poll_recv(cx)
517 }
518}