1use crate::aio::Runtime;
2use crate::connection::{
3 check_connection_setup, connection_setup_pipeline, AuthResult, ConnectionSetupComponents,
4};
5#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
6use crate::parser::ValueCodec;
7use crate::types::{closed_connection_error, RedisError, RedisResult, Value};
8use crate::{cmd, from_owned_redis_value, FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs};
9use ::tokio::{
10 io::{AsyncRead, AsyncWrite},
11 sync::oneshot,
12};
13use futures_util::{
14 future::{Future, FutureExt},
15 ready,
16 sink::{Sink, SinkExt},
17 stream::{self, Stream, StreamExt},
18};
19use pin_project_lite::pin_project;
20use std::collections::VecDeque;
21use std::pin::Pin;
22use std::task::{self, Poll};
23use tokio::sync::mpsc::unbounded_channel;
24use tokio::sync::mpsc::UnboundedSender;
25#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
26use tokio_util::codec::Decoder;
27
28use super::SharedHandleContainer;
29
30type RequestResultSender = oneshot::Sender<RedisResult<Value>>;
32
33struct PipelineMessage {
35 input: Vec<u8>,
36 output: RequestResultSender,
37}
38
39#[derive(Clone)]
49pub struct PubSubSink {
50 sender: UnboundedSender<PipelineMessage>,
51}
52
53pin_project! {
54 pub struct PubSubStream {
64 #[pin]
65 receiver: tokio::sync::mpsc::UnboundedReceiver<Msg>,
66 _task_handle: Option<SharedHandleContainer>,
68 }
69}
70
71pin_project! {
72 struct PipelineSink<T> {
73 #[pin]
75 sink_stream: T,
76 in_flight: VecDeque<RequestResultSender>,
78 sender: UnboundedSender<Msg>,
80 }
81}
82
83impl<T> PipelineSink<T>
84where
85 T: Stream<Item = RedisResult<Value>> + 'static,
86{
87 fn new(sink_stream: T, sender: UnboundedSender<Msg>) -> Self
88 where
89 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
90 {
91 PipelineSink {
92 sink_stream,
93 in_flight: VecDeque::new(),
94 sender,
95 }
96 }
97
98 fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
100 loop {
101 let self_ = self.as_mut().project();
102 if self_.sender.is_closed() {
103 return Poll::Ready(Err(()));
104 }
105
106 let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
107 Some(result) => result,
108 None => return Poll::Ready(Err(())),
111 };
112 self.as_mut().handle_message(item)?;
113 }
114 }
115
116 fn handle_message(self: Pin<&mut Self>, result: RedisResult<Value>) -> Result<(), ()> {
117 let self_ = self.project();
118
119 match result {
120 Ok(Value::Array(value)) => {
121 if let Some(Value::BulkString(kind)) = value.first() {
122 if matches!(
123 kind.as_slice(),
124 b"subscribe" | b"psubscribe" | b"unsubscribe" | b"punsubscribe" | b"pong"
125 ) {
126 if let Some(entry) = self_.in_flight.pop_front() {
127 let _ = entry.send(Ok(Value::Array(value)));
128 };
129 return Ok(());
130 }
131 }
132
133 if let Some(msg) = Msg::from_owned_value(Value::Array(value)) {
134 let _ = self_.sender.send(msg);
135 Ok(())
136 } else {
137 Err(())
138 }
139 }
140
141 Ok(Value::Push { kind, data }) => {
142 if kind.has_reply() {
143 if let Some(entry) = self_.in_flight.pop_front() {
144 let _ = entry.send(Ok(Value::Push { kind, data }));
145 };
146 return Ok(());
147 }
148
149 if let Some(msg) = Msg::from_push_info(crate::PushInfo { kind, data }) {
150 let _ = self_.sender.send(msg);
151 Ok(())
152 } else {
153 Err(())
154 }
155 }
156
157 Err(err) if err.is_unrecoverable_error() => Err(()),
158
159 _ => {
160 if let Some(entry) = self_.in_flight.pop_front() {
161 let _ = entry.send(result);
162 Ok(())
163 } else {
164 Err(())
165 }
166 }
167 }
168 }
169}
170
171impl<T> Sink<PipelineMessage> for PipelineSink<T>
172where
173 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
174{
175 type Error = ();
176
177 fn poll_ready(
179 mut self: Pin<&mut Self>,
180 cx: &mut task::Context,
181 ) -> Poll<Result<(), Self::Error>> {
182 self.as_mut()
183 .project()
184 .sink_stream
185 .poll_ready(cx)
186 .map_err(|_| ())
187 }
188
189 fn start_send(
190 mut self: Pin<&mut Self>,
191 PipelineMessage { input, output }: PipelineMessage,
192 ) -> Result<(), Self::Error> {
193 let self_ = self.as_mut().project();
194
195 match self_.sink_stream.start_send(input) {
196 Ok(()) => {
197 self_.in_flight.push_back(output);
198 Ok(())
199 }
200 Err(err) => {
201 let _ = output.send(Err(err));
202 Err(())
203 }
204 }
205 }
206
207 fn poll_flush(
208 mut self: Pin<&mut Self>,
209 cx: &mut task::Context,
210 ) -> Poll<Result<(), Self::Error>> {
211 ready!(self
212 .as_mut()
213 .project()
214 .sink_stream
215 .poll_flush(cx)
216 .map_err(|err| {
217 let _ = self.as_mut().handle_message(Err(err));
218 }))?;
219 self.poll_read(cx)
220 }
221
222 fn poll_close(
223 mut self: Pin<&mut Self>,
224 cx: &mut task::Context,
225 ) -> Poll<Result<(), Self::Error>> {
226 if !self.in_flight.is_empty() {
229 ready!(self.as_mut().poll_flush(cx))?;
230 }
231 let this = self.as_mut().project();
232
233 if this.sender.is_closed() {
234 return Poll::Ready(Ok(()));
235 }
236
237 match ready!(this.sink_stream.poll_next(cx)) {
238 Some(result) => {
239 let _ = self.handle_message(result);
240 Poll::Pending
241 }
242 None => Poll::Ready(Ok(())),
243 }
244 }
245}
246
247impl PubSubSink {
248 fn new<T>(
249 sink_stream: T,
250 messages_sender: UnboundedSender<Msg>,
251 ) -> (Self, impl Future<Output = ()>)
252 where
253 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + Send + 'static,
254 T::Item: Send,
255 T::Error: Send,
256 T::Error: ::std::fmt::Debug,
257 {
258 let (sender, mut receiver) = unbounded_channel();
259 let sink = PipelineSink::new(sink_stream, messages_sender);
260 let f = stream::poll_fn(move |cx| {
261 let res = receiver.poll_recv(cx);
262 match res {
263 Poll::Ready(None) => Poll::Pending,
265 _ => res,
266 }
267 })
268 .map(Ok)
269 .forward(sink)
270 .map(|_| ());
271 (PubSubSink { sender }, f)
272 }
273
274 async fn send_recv(&mut self, input: Vec<u8>) -> Result<Value, RedisError> {
275 let (sender, receiver) = oneshot::channel();
276
277 self.sender
278 .send(PipelineMessage {
279 input,
280 output: sender,
281 })
282 .map_err(|_| closed_connection_error())?;
283 match receiver.await {
284 Ok(result) => result,
285 Err(_) => Err(closed_connection_error()),
286 }
287 }
288
289 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
302 let cmd = cmd("SUBSCRIBE").arg(channel_name).get_packed_command();
303 self.send_recv(cmd).await.map(|_| ())
304 }
305
306 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
319 let cmd = cmd("UNSUBSCRIBE").arg(channel_name).get_packed_command();
320 self.send_recv(cmd).await.map(|_| ())
321 }
322
323 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
336 let cmd = cmd("PSUBSCRIBE").arg(channel_pattern).get_packed_command();
337 self.send_recv(cmd).await.map(|_| ())
338 }
339
340 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
353 let cmd = cmd("PUNSUBSCRIBE")
354 .arg(channel_pattern)
355 .get_packed_command();
356 self.send_recv(cmd).await.map(|_| ())
357 }
358
359 pub async fn ping_message<T: FromRedisValue>(
361 &mut self,
362 message: impl ToRedisArgs,
363 ) -> RedisResult<T> {
364 let cmd = cmd("PING").arg(message).get_packed_command();
365 let response = self.send_recv(cmd).await?;
366 from_owned_redis_value(response)
367 }
368
369 pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
371 let cmd = cmd("PING").get_packed_command();
372 let response = self.send_recv(cmd).await?;
373 from_owned_redis_value(response)
374 }
375}
376
377pub struct PubSub {
379 sink: PubSubSink,
380 stream: PubSubStream,
381}
382
383async fn execute_connection_pipeline<T>(
384 codec: &mut T,
385 (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
386) -> RedisResult<AuthResult>
387where
388 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
389 T: Send + 'static,
390 T::Item: Send,
391 T::Error: Send,
392 T::Error: ::std::fmt::Debug,
393 T: Unpin,
394{
395 let count = pipeline.len();
396 if count == 0 {
397 return Ok(AuthResult::Succeeded);
398 }
399 codec.send(pipeline.get_packed_pipeline()).await?;
400
401 let mut results = Vec::with_capacity(count);
402 for _ in 0..count {
403 let value = codec.next().await;
404 match value {
405 Some(Ok(val)) => results.push(val),
406 _ => return Err(closed_connection_error()),
407 }
408 }
409
410 check_connection_setup(results, instructions)
411}
412
413async fn setup_connection<T>(
414 codec: &mut T,
415 connection_info: &RedisConnectionInfo,
416) -> RedisResult<()>
417where
418 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
419 T: Send + 'static,
420 T::Item: Send,
421 T::Error: Send,
422 T::Error: ::std::fmt::Debug,
423 T: Unpin,
424{
425 if execute_connection_pipeline(
426 codec,
427 connection_setup_pipeline(
428 connection_info,
429 true,
430 #[cfg(feature = "cache-aio")]
431 None,
432 ),
433 )
434 .await?
435 == AuthResult::ShouldRetryWithoutUsername
436 {
437 execute_connection_pipeline(
438 codec,
439 connection_setup_pipeline(
440 connection_info,
441 false,
442 #[cfg(feature = "cache-aio")]
443 None,
444 ),
445 )
446 .await?;
447 }
448
449 Ok(())
450}
451
452impl PubSub {
453 pub async fn new<C>(connection_info: &RedisConnectionInfo, stream: C) -> RedisResult<Self>
456 where
457 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
458 {
459 #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
460 compile_error!("tokio-comp or async-std-comp features required for aio feature");
461
462 let mut codec = ValueCodec::default().framed(stream);
463 setup_connection(&mut codec, connection_info).await?;
464 let (sender, receiver) = unbounded_channel();
465 let (sink, driver) = PubSubSink::new(codec, sender);
466 let handle = Runtime::locate().spawn(driver);
467 let _task_handle = Some(SharedHandleContainer::new(handle));
468 let stream = PubSubStream {
469 receiver,
470 _task_handle,
471 };
472 let con = PubSub { sink, stream };
473 Ok(con)
474 }
475
476 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
490 self.sink.subscribe(channel_name).await
491 }
492
493 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
507 self.sink.unsubscribe(channel_name).await
508 }
509
510 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
523 self.sink.psubscribe(channel_pattern).await
524 }
525
526 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
539 self.sink.punsubscribe(channel_pattern).await
540 }
541
542 pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
544 self.sink.ping().await
545 }
546
547 pub async fn ping_message<T: FromRedisValue>(
549 &mut self,
550 message: impl ToRedisArgs,
551 ) -> RedisResult<T> {
552 self.sink.ping_message(message).await
553 }
554
555 pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
560 &mut self.stream
561 }
562
563 pub fn into_on_message(self) -> PubSubStream {
570 self.stream
571 }
572
573 pub fn split(self) -> (PubSubSink, PubSubStream) {
576 (self.sink, self.stream)
577 }
578}
579
580impl Stream for PubSubStream {
581 type Item = Msg;
582
583 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
584 self.project().receiver.poll_recv(cx)
585 }
586}