redis/aio/pubsub.rs
1use crate::aio::Runtime;
2use crate::parser::ValueCodec;
3use crate::types::{closed_connection_error, RedisError, RedisResult, Value};
4use crate::{cmd, from_owned_redis_value, FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs};
5use ::tokio::{
6 io::{AsyncRead, AsyncWrite},
7 sync::oneshot,
8};
9use futures_util::{
10 future::{Future, FutureExt},
11 ready,
12 sink::Sink,
13 stream::{self, Stream, StreamExt},
14};
15use pin_project_lite::pin_project;
16use std::collections::VecDeque;
17use std::pin::Pin;
18use std::task::{self, Poll};
19use tokio::sync::mpsc::unbounded_channel;
20use tokio::sync::mpsc::UnboundedSender;
21use tokio_util::codec::Decoder;
22
23use super::{setup_connection, SharedHandleContainer};
24
25// A signal that a un/subscribe request has completed.
26type RequestResultSender = oneshot::Sender<RedisResult<Value>>;
27
28// A single message sent through the pipeline
29struct PipelineMessage {
30 input: Vec<u8>,
31 output: RequestResultSender,
32}
33
34/// The sink part of a split async Pubsub.
35///
36/// The sink is used to subscribe and unsubscribe from
37/// channels.
38/// The stream part is independent from the sink,
39/// and dropping the sink doesn't cause the stream part to
40/// stop working.
41/// The sink isn't independent from the stream - dropping
42/// the stream will cause the sink to return errors on requests.
43#[derive(Clone)]
44pub struct PubSubSink {
45 sender: UnboundedSender<PipelineMessage>,
46}
47
48pin_project! {
49 /// The stream part of a split async Pubsub.
50 ///
51 /// The sink is used to subscribe and unsubscribe from
52 /// channels.
53 /// The stream part is independent from the sink,
54 /// and dropping the sink doesn't cause the stream part to
55 /// stop working.
56 /// The sink isn't independent from the stream - dropping
57 /// the stream will cause the sink to return errors on requests.
58 pub struct PubSubStream {
59 #[pin]
60 receiver: tokio::sync::mpsc::UnboundedReceiver<Msg>,
61 // This handle ensures that once the stream will be dropped, the underlying task will stop.
62 _task_handle: Option<SharedHandleContainer>,
63 }
64}
65
66pin_project! {
67 struct PipelineSink<T> {
68 // The `Sink + Stream` that sends requests and receives values from the server.
69 #[pin]
70 sink_stream: T,
71 // The requests that were sent and are awaiting a response.
72 in_flight: VecDeque<RequestResultSender>,
73 // A sender for the push messages received from the server.
74 sender: UnboundedSender<Msg>,
75 }
76}
77
78impl<T> PipelineSink<T>
79where
80 T: Stream<Item = RedisResult<Value>> + 'static,
81{
82 fn new(sink_stream: T, sender: UnboundedSender<Msg>) -> Self
83 where
84 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
85 {
86 PipelineSink {
87 sink_stream,
88 in_flight: VecDeque::new(),
89 sender,
90 }
91 }
92
93 // Read messages from the stream and handle them.
94 fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
95 loop {
96 let self_ = self.as_mut().project();
97 if self_.sender.is_closed() {
98 return Poll::Ready(Err(()));
99 }
100
101 let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
102 Some(result) => result,
103 // The redis response stream is not going to produce any more items so we `Err`
104 // to break out of the `forward` combinator and stop handling requests
105 None => return Poll::Ready(Err(())),
106 };
107 self.as_mut().handle_message(item)?;
108 }
109 }
110
111 fn handle_message(self: Pin<&mut Self>, result: RedisResult<Value>) -> Result<(), ()> {
112 let self_ = self.project();
113
114 match result {
115 Ok(Value::Array(value)) => {
116 if let Some(Value::BulkString(kind)) = value.first() {
117 if matches!(
118 kind.as_slice(),
119 b"subscribe" | b"psubscribe" | b"unsubscribe" | b"punsubscribe" | b"pong"
120 ) {
121 if let Some(entry) = self_.in_flight.pop_front() {
122 let _ = entry.send(Ok(Value::Array(value)));
123 };
124 return Ok(());
125 }
126 }
127
128 if let Some(msg) = Msg::from_owned_value(Value::Array(value)) {
129 let _ = self_.sender.send(msg);
130 Ok(())
131 } else {
132 Err(())
133 }
134 }
135
136 Ok(Value::Push { kind, data }) => {
137 if kind.has_reply() {
138 if let Some(entry) = self_.in_flight.pop_front() {
139 let _ = entry.send(Ok(Value::Push { kind, data }));
140 };
141 return Ok(());
142 }
143
144 if let Some(msg) = Msg::from_push_info(crate::PushInfo { kind, data }) {
145 let _ = self_.sender.send(msg);
146 Ok(())
147 } else {
148 Err(())
149 }
150 }
151
152 Err(err) if err.is_unrecoverable_error() => Err(()),
153
154 _ => {
155 if let Some(entry) = self_.in_flight.pop_front() {
156 let _ = entry.send(result);
157 Ok(())
158 } else {
159 Err(())
160 }
161 }
162 }
163 }
164}
165
166impl<T> Sink<PipelineMessage> for PipelineSink<T>
167where
168 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
169{
170 type Error = ();
171
172 // Retrieve incoming messages and write them to the sink
173 fn poll_ready(
174 mut self: Pin<&mut Self>,
175 cx: &mut task::Context,
176 ) -> Poll<Result<(), Self::Error>> {
177 self.as_mut()
178 .project()
179 .sink_stream
180 .poll_ready(cx)
181 .map_err(|_| ())
182 }
183
184 fn start_send(
185 mut self: Pin<&mut Self>,
186 PipelineMessage { input, output }: PipelineMessage,
187 ) -> Result<(), Self::Error> {
188 let self_ = self.as_mut().project();
189
190 match self_.sink_stream.start_send(input) {
191 Ok(()) => {
192 self_.in_flight.push_back(output);
193 Ok(())
194 }
195 Err(err) => {
196 let _ = output.send(Err(err));
197 Err(())
198 }
199 }
200 }
201
202 fn poll_flush(
203 mut self: Pin<&mut Self>,
204 cx: &mut task::Context,
205 ) -> Poll<Result<(), Self::Error>> {
206 ready!(self
207 .as_mut()
208 .project()
209 .sink_stream
210 .poll_flush(cx)
211 .map_err(|err| {
212 let _ = self.as_mut().handle_message(Err(err));
213 }))?;
214 self.poll_read(cx)
215 }
216
217 fn poll_close(
218 mut self: Pin<&mut Self>,
219 cx: &mut task::Context,
220 ) -> Poll<Result<(), Self::Error>> {
221 // No new requests will come in after the first call to `close` but we need to complete any
222 // in progress requests before closing
223 if !self.in_flight.is_empty() {
224 ready!(self.as_mut().poll_flush(cx))?;
225 }
226 let this = self.as_mut().project();
227
228 if this.sender.is_closed() {
229 return Poll::Ready(Ok(()));
230 }
231
232 match ready!(this.sink_stream.poll_next(cx)) {
233 Some(result) => {
234 let _ = self.handle_message(result);
235 Poll::Pending
236 }
237 None => Poll::Ready(Ok(())),
238 }
239 }
240}
241
242impl PubSubSink {
243 fn new<T>(
244 sink_stream: T,
245 messages_sender: UnboundedSender<Msg>,
246 ) -> (Self, impl Future<Output = ()>)
247 where
248 T: Sink<Vec<u8>, Error = RedisError>,
249 T: Stream<Item = RedisResult<Value>>,
250 T: Unpin + Send + 'static,
251 {
252 let (sender, mut receiver) = unbounded_channel();
253 let sink = PipelineSink::new(sink_stream, messages_sender);
254 let f = stream::poll_fn(move |cx| {
255 let res = receiver.poll_recv(cx);
256 match res {
257 // We don't want to stop the backing task for the stream, even if the sink was closed.
258 Poll::Ready(None) => Poll::Pending,
259 _ => res,
260 }
261 })
262 .map(Ok)
263 .forward(sink)
264 .map(|_| ());
265 (PubSubSink { sender }, f)
266 }
267
268 async fn send_recv(&mut self, input: Vec<u8>) -> Result<Value, RedisError> {
269 let (sender, receiver) = oneshot::channel();
270
271 self.sender
272 .send(PipelineMessage {
273 input,
274 output: sender,
275 })
276 .map_err(|_| closed_connection_error())?;
277 match receiver.await {
278 Ok(result) => result,
279 Err(_) => Err(closed_connection_error()),
280 }
281 }
282
283 /// Subscribes to a new channel(s).
284 ///
285 /// ```rust,no_run
286 /// # #[cfg(feature = "aio")]
287 /// # async fn do_something() -> redis::RedisResult<()> {
288 /// let client = redis::Client::open("redis://127.0.0.1/")?;
289 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
290 /// sink.subscribe("channel_1").await?;
291 /// sink.subscribe(&["channel_2", "channel_3"]).await?;
292 /// # Ok(())
293 /// # }
294 /// ```
295 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
296 let cmd = cmd("SUBSCRIBE").arg(channel_name).get_packed_command();
297 self.send_recv(cmd).await.map(|_| ())
298 }
299
300 /// Unsubscribes from channel(s).
301 ///
302 /// ```rust,no_run
303 /// # #[cfg(feature = "aio")]
304 /// # async fn do_something() -> redis::RedisResult<()> {
305 /// let client = redis::Client::open("redis://127.0.0.1/")?;
306 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
307 /// sink.subscribe(&["channel_1", "channel_2"]).await?;
308 /// sink.unsubscribe(&["channel_1", "channel_2"]).await?;
309 /// # Ok(())
310 /// # }
311 /// ```
312 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
313 let cmd = cmd("UNSUBSCRIBE").arg(channel_name).get_packed_command();
314 self.send_recv(cmd).await.map(|_| ())
315 }
316
317 /// Subscribes to new channel(s) with pattern(s).
318 ///
319 /// ```rust,no_run
320 /// # #[cfg(feature = "aio")]
321 /// # async fn do_something() -> redis::RedisResult<()> {
322 /// let client = redis::Client::open("redis://127.0.0.1/")?;
323 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
324 /// sink.psubscribe("channel*_1").await?;
325 /// sink.psubscribe(&["channel*_2", "channel*_3"]).await?;
326 /// # Ok(())
327 /// # }
328 /// ```
329 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
330 let cmd = cmd("PSUBSCRIBE").arg(channel_pattern).get_packed_command();
331 self.send_recv(cmd).await.map(|_| ())
332 }
333
334 /// Unsubscribes from channel pattern(s).
335 ///
336 /// ```rust,no_run
337 /// # #[cfg(feature = "aio")]
338 /// # async fn do_something() -> redis::RedisResult<()> {
339 /// let client = redis::Client::open("redis://127.0.0.1/")?;
340 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
341 /// sink.psubscribe(&["channel_1", "channel_2"]).await?;
342 /// sink.punsubscribe(&["channel_1", "channel_2"]).await?;
343 /// # Ok(())
344 /// # }
345 /// ```
346 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
347 let cmd = cmd("PUNSUBSCRIBE")
348 .arg(channel_pattern)
349 .get_packed_command();
350 self.send_recv(cmd).await.map(|_| ())
351 }
352
353 /// Sends a ping with a message to the server
354 pub async fn ping_message<T: FromRedisValue>(
355 &mut self,
356 message: impl ToRedisArgs,
357 ) -> RedisResult<T> {
358 let cmd = cmd("PING").arg(message).get_packed_command();
359 let response = self.send_recv(cmd).await?;
360 from_owned_redis_value(response)
361 }
362
363 /// Sends a ping to the server
364 pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
365 let cmd = cmd("PING").get_packed_command();
366 let response = self.send_recv(cmd).await?;
367 from_owned_redis_value(response)
368 }
369}
370
371/// A connection dedicated to RESP2 pubsub messages.
372///
373/// If you're using a DB that supports RESP3, consider using a regular connection and setting a [crate::aio::AsyncPushSender] on it using [crate::client::AsyncConnectionConfig::set_push_sender].
374pub struct PubSub {
375 sink: PubSubSink,
376 stream: PubSubStream,
377}
378
379impl PubSub {
380 /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
381 /// and a `ConnectionInfo`
382 pub async fn new<C>(connection_info: &RedisConnectionInfo, stream: C) -> RedisResult<Self>
383 where
384 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
385 {
386 let mut codec = ValueCodec::default().framed(stream);
387 setup_connection(
388 &mut codec,
389 connection_info,
390 #[cfg(feature = "cache-aio")]
391 None,
392 )
393 .await?;
394 let (sender, receiver) = unbounded_channel();
395 let (sink, driver) = PubSubSink::new(codec, sender);
396 let handle = Runtime::locate().spawn(driver);
397 let _task_handle = Some(SharedHandleContainer::new(handle));
398 let stream = PubSubStream {
399 receiver,
400 _task_handle,
401 };
402 let con = PubSub { sink, stream };
403 Ok(con)
404 }
405
406 /// Subscribes to a new channel(s).
407 ///
408 /// ```rust,no_run
409 /// # #[cfg(feature = "aio")]
410 /// # #[cfg(feature = "aio")]
411 /// # async fn do_something() -> redis::RedisResult<()> {
412 /// let client = redis::Client::open("redis://127.0.0.1/")?;
413 /// let mut pubsub = client.get_async_pubsub().await?;
414 /// pubsub.subscribe("channel_1").await?;
415 /// pubsub.subscribe(&["channel_2", "channel_3"]).await?;
416 /// # Ok(())
417 /// # }
418 /// ```
419 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
420 self.sink.subscribe(channel_name).await
421 }
422
423 /// Unsubscribes from channel(s).
424 ///
425 /// ```rust,no_run
426 /// # #[cfg(feature = "aio")]
427 /// # #[cfg(feature = "aio")]
428 /// # async fn do_something() -> redis::RedisResult<()> {
429 /// let client = redis::Client::open("redis://127.0.0.1/")?;
430 /// let mut pubsub = client.get_async_pubsub().await?;
431 /// pubsub.subscribe(&["channel_1", "channel_2"]).await?;
432 /// pubsub.unsubscribe(&["channel_1", "channel_2"]).await?;
433 /// # Ok(())
434 /// # }
435 /// ```
436 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
437 self.sink.unsubscribe(channel_name).await
438 }
439
440 /// Subscribes to new channel(s) with pattern(s).
441 ///
442 /// ```rust,no_run
443 /// # #[cfg(feature = "aio")]
444 /// # async fn do_something() -> redis::RedisResult<()> {
445 /// let client = redis::Client::open("redis://127.0.0.1/")?;
446 /// let mut pubsub = client.get_async_pubsub().await?;
447 /// pubsub.psubscribe("channel*_1").await?;
448 /// pubsub.psubscribe(&["channel*_2", "channel*_3"]).await?;
449 /// # Ok(())
450 /// # }
451 /// ```
452 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
453 self.sink.psubscribe(channel_pattern).await
454 }
455
456 /// Unsubscribes from channel pattern(s).
457 ///
458 /// ```rust,no_run
459 /// # #[cfg(feature = "aio")]
460 /// # async fn do_something() -> redis::RedisResult<()> {
461 /// let client = redis::Client::open("redis://127.0.0.1/")?;
462 /// let mut pubsub = client.get_async_pubsub().await?;
463 /// pubsub.psubscribe(&["channel_1", "channel_2"]).await?;
464 /// pubsub.punsubscribe(&["channel_1", "channel_2"]).await?;
465 /// # Ok(())
466 /// # }
467 /// ```
468 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
469 self.sink.punsubscribe(channel_pattern).await
470 }
471
472 /// Sends a ping to the server
473 pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
474 self.sink.ping().await
475 }
476
477 /// Sends a ping with a message to the server
478 pub async fn ping_message<T: FromRedisValue>(
479 &mut self,
480 message: impl ToRedisArgs,
481 ) -> RedisResult<T> {
482 self.sink.ping_message(message).await
483 }
484
485 /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions.
486 ///
487 /// The message itself is still generic and can be converted into an appropriate type through
488 /// the helper methods on it.
489 pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
490 &mut self.stream
491 }
492
493 /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions consuming it.
494 ///
495 /// The message itself is still generic and can be converted into an appropriate type through
496 /// the helper methods on it.
497 /// This can be useful in cases where the stream needs to be returned or held by something other
498 /// than the [`PubSub`].
499 pub fn into_on_message(self) -> PubSubStream {
500 self.stream
501 }
502
503 /// Splits the PubSub into separate sink and stream components, so that subscriptions could be
504 /// updated through the `Sink` while concurrently waiting for new messages on the `Stream`.
505 pub fn split(self) -> (PubSubSink, PubSubStream) {
506 (self.sink, self.stream)
507 }
508}
509
510impl Stream for PubSubStream {
511 type Item = Msg;
512
513 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
514 self.project().receiver.poll_recv(cx)
515 }
516}