futures_util/stream/stream/
flatten_unordered.rs1use alloc::sync::Arc;
2use core::{
3    cell::UnsafeCell,
4    convert::identity,
5    fmt,
6    marker::PhantomData,
7    num::NonZeroUsize,
8    pin::Pin,
9    sync::atomic::{AtomicU8, Ordering},
10};
11
12use pin_project_lite::pin_project;
13
14use futures_core::{
15    future::Future,
16    ready,
17    stream::{FusedStream, Stream},
18    task::{Context, Poll, Waker},
19};
20#[cfg(feature = "sink")]
21use futures_sink::Sink;
22use futures_task::{waker, ArcWake};
23
24use crate::stream::FuturesUnordered;
25
/// Flattening adapter that uses the default `()` flow controller, which treats
/// every item of the base stream as an inner stream to be flattened.
pub type FlattenUnordered<St> = FlattenUnorderedWithFlowController<St, ()>;
29
/// Empty state: nothing requested and no phase flag set.
const NONE: u8 = 0b0000_0000;

/// Request bit: the inner streams (the `FuturesUnordered` bucket) must be polled.
const NEED_TO_POLL_INNER_STREAMS: u8 = 0b0000_0001;

/// Request bit: the base stream must be polled.
const NEED_TO_POLL_STREAM: u8 = 0b0000_0010;

/// Both request bits together.
const NEED_TO_POLL_ALL: u8 = NEED_TO_POLL_STREAM | NEED_TO_POLL_INNER_STREAMS;

/// Phase flag: `poll_next` is running.
const POLLING: u8 = 0b0000_0100;

/// Phase flag: a wrapped waker is in the middle of waking.
const WAKING: u8 = 0b0000_1000;

/// Phase flag: a wake-up has been (or is about to be) delivered, so further
/// wakers only record their request bits.
const WOKEN: u8 = 0b0001_0000;
50
51#[derive(Clone, Debug)]
53struct SharedPollState {
54    state: Arc<AtomicU8>,
55}
56
57impl SharedPollState {
58    fn new(value: u8) -> Self {
60        Self { state: Arc::new(AtomicU8::new(value)) }
61    }
62
63    fn start_polling(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&Self) -> u8>)> {
66        let value = self
67            .state
68            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
69                if value & WAKING == NONE {
70                    Some(POLLING)
71                } else {
72                    None
73                }
74            })
75            .ok()?;
76        let bomb = PollStateBomb::new(self, Self::reset);
77
78        Some((value, bomb))
79    }
80
81    fn start_waking(
86        &self,
87        to_poll: u8,
88    ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&Self) -> u8>)> {
89        let value = self
90            .state
91            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
92                let mut next_value = value | to_poll;
93                if value & (WOKEN | POLLING) == NONE {
94                    next_value |= WAKING;
95                }
96
97                if next_value != value {
98                    Some(next_value)
99                } else {
100                    None
101                }
102            })
103            .ok()?;
104
105        if value & (WOKEN | POLLING | WAKING) == NONE {
107            let bomb = PollStateBomb::new(self, Self::stop_waking);
108
109            Some((value, bomb))
110        } else {
111            None
112        }
113    }
114
115    fn stop_polling(&self, to_poll: u8, will_be_woken: bool) -> u8 {
124        self.state
125            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut value| {
126                let mut next_value = to_poll;
127
128                value &= NEED_TO_POLL_ALL;
129                if value != NONE || will_be_woken {
130                    next_value |= WOKEN;
131                }
132                next_value |= value;
133
134                Some(next_value & !POLLING & !WAKING)
135            })
136            .unwrap()
137    }
138
139    fn stop_waking(&self) -> u8 {
141        let value = self
142            .state
143            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
144                let next_value = value & !WAKING | WOKEN;
145
146                if next_value != value {
147                    Some(next_value)
148                } else {
149                    None
150                }
151            })
152            .unwrap_or_else(identity);
153
154        debug_assert!(value & (WOKEN | POLLING | WAKING) == WAKING);
155        value
156    }
157
158    fn reset(&self) -> u8 {
160        self.state.swap(NEED_TO_POLL_ALL, Ordering::SeqCst)
161    }
162}
163
164struct PollStateBomb<'a, F: FnOnce(&SharedPollState) -> u8> {
166    state: &'a SharedPollState,
167    drop: Option<F>,
168}
169
170impl<'a, F: FnOnce(&SharedPollState) -> u8> PollStateBomb<'a, F> {
171    fn new(state: &'a SharedPollState, drop: F) -> Self {
173        Self { state, drop: Some(drop) }
174    }
175
176    fn deactivate(mut self) {
178        self.drop.take();
179    }
180}
181
182impl<F: FnOnce(&SharedPollState) -> u8> Drop for PollStateBomb<'_, F> {
183    fn drop(&mut self) {
184        if let Some(drop) = self.drop.take() {
185            (drop)(self.state);
186        }
187    }
188}
189
/// Waker wrapper that records, in the shared `poll_state`, which part of the
/// combined stream needs polling (`need_to_poll`). If it wins the right to wake,
/// it then forwards the wake-up to the task waker stored in `inner_waker`.
struct WrappedWaker {
    // Task waker to forward to; replaced on every `poll_next` call.
    inner_waker: UnsafeCell<Option<Waker>>,
    poll_state: SharedPollState,
    // Request bit(s) OR-ed into `poll_state` when this waker fires.
    need_to_poll: u8,
}

// SAFETY: `inner_waker` is only written by `replace_waker`, which `poll_next`
// calls after entering the `POLLING` phase. It is only read by `wake_by_ref`
// after a successful `start_waking`. `start_waking` cannot succeed while
// `POLLING` is set, and `start_polling` fails while `WAKING` is set, so the
// shared state machine is what serialises access to the cell.
unsafe impl Send for WrappedWaker {}
unsafe impl Sync for WrappedWaker {}
200
201impl WrappedWaker {
202    unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) {
211        unsafe { *self_arc.inner_waker.get() = cx.waker().clone().into() }
212    }
213
214    fn start_waking(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
217        self.poll_state.start_waking(self.need_to_poll)
218    }
219}
220
221impl ArcWake for WrappedWaker {
222    fn wake_by_ref(self_arc: &Arc<Self>) {
223        if let Some((_, state_bomb)) = self_arc.start_waking() {
224            let waker_opt = unsafe { self_arc.inner_waker.get().as_ref().unwrap() };
226
227            if let Some(inner_waker) = waker_opt.clone() {
228                drop(state_bomb);
230
231                inner_waker.wake();
233            }
234        }
235    }
236}
237
pin_project! {
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    /// Future that polls an optional inner stream for its next item.
    ///
    /// It resolves to `Some((item, next))` when the stream yields an item, where
    /// `next` is a new `PollStreamFut` that owns the same stream. It resolves to
    /// `None` when the stream is exhausted or absent. `Poll::Pending` from the
    /// stream is forwarded unchanged.
    struct PollStreamFut<St> {
        #[pin]
        stream: Option<St>,
    }
}
253
254impl<St> PollStreamFut<St> {
255    fn new(stream: impl Into<Option<St>>) -> Self {
257        Self { stream: stream.into() }
258    }
259}
260
261impl<St: Stream + Unpin> Future for PollStreamFut<St> {
262    type Output = Option<(St::Item, Self)>;
263
264    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265        let mut stream = self.project().stream;
266
267        let item = if let Some(stream) = stream.as_mut().as_pin_mut() {
268            ready!(stream.poll_next(cx))
269        } else {
270            None
271        };
272        let next_item_fut = Self::new(stream.get_mut().take());
273        let out = item.map(|item| (item, next_item_fut));
274
275        Poll::Ready(out)
276    }
277}
278
pin_project! {
    #[project = FlattenUnorderedWithFlowControllerProj]
    /// Stream adapter that flattens a stream of streams. It polls the inner
    /// streams concurrently and yields their items as they become ready.
    ///
    /// The flow controller `Fc` decides, for each item of the base stream,
    /// whether it is flattened as an inner stream or yielded directly
    /// (see `FlowController`).
    #[must_use = "streams do nothing unless polled"]
    pub struct FlattenUnorderedWithFlowController<St, Fc> where St: Stream {
        // Inner streams, each wrapped in a future resolving to its next item.
        #[pin]
        inner_streams: FuturesUnordered<PollStreamFut<St::Item>>,
        // The base stream that produces the inner streams.
        #[pin]
        stream: St,
        // Shared with both wrapped wakers; records what must be polled next.
        poll_state: SharedPollState,
        // Maximum number of buffered inner streams; `None` means unlimited.
        limit: Option<NonZeroUsize>,
        // Set once the base stream has returned `Ready(None)`.
        is_stream_done: bool,
        // Waker handed to `inner_streams`; requests `NEED_TO_POLL_INNER_STREAMS`.
        inner_streams_waker: Arc<WrappedWaker>,
        // Waker handed to the base stream; requests `NEED_TO_POLL_STREAM`.
        stream_waker: Arc<WrappedWaker>,
        flow_controller: PhantomData<Fc>
    }
}
297
298impl<St, Fc> fmt::Debug for FlattenUnorderedWithFlowController<St, Fc>
299where
300    St: Stream + fmt::Debug,
301    St::Item: Stream + fmt::Debug,
302{
303    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304        f.debug_struct("FlattenUnorderedWithFlowController")
305            .field("poll_state", &self.poll_state)
306            .field("inner_streams", &self.inner_streams)
307            .field("limit", &self.limit)
308            .field("stream", &self.stream)
309            .field("is_stream_done", &self.is_stream_done)
310            .field("flow_controller", &self.flow_controller)
311            .finish()
312    }
313}
314
315impl<St, Fc> FlattenUnorderedWithFlowController<St, Fc>
316where
317    St: Stream,
318    Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
319    St::Item: Stream + Unpin,
320{
321    pub(crate) fn new(stream: St, limit: Option<usize>) -> Self {
322        let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM);
323
324        Self {
325            inner_streams: FuturesUnordered::new(),
326            stream,
327            is_stream_done: false,
328            limit: limit.and_then(NonZeroUsize::new),
329            inner_streams_waker: Arc::new(WrappedWaker {
330                inner_waker: UnsafeCell::new(None),
331                poll_state: poll_state.clone(),
332                need_to_poll: NEED_TO_POLL_INNER_STREAMS,
333            }),
334            stream_waker: Arc::new(WrappedWaker {
335                inner_waker: UnsafeCell::new(None),
336                poll_state: poll_state.clone(),
337                need_to_poll: NEED_TO_POLL_STREAM,
338            }),
339            poll_state,
340            flow_controller: PhantomData,
341        }
342    }
343
344    delegate_access_inner!(stream, St, ());
345}
346
/// Decides, for each item produced by the base stream, how the flattening
/// adapter proceeds.
pub trait FlowController<I, O> {
    /// Maps an item of the base stream to the next [`FlowStep`].
    fn next_step(item: I) -> FlowStep<I, O>;
}
352
353impl<I, O> FlowController<I, O> for () {
354    fn next_step(item: I) -> FlowStep<I, O> {
355        FlowStep::Continue(item)
356    }
357}
358
/// Outcome of [`FlowController::next_step`].
#[derive(Debug, Clone)]
pub enum FlowStep<C, R> {
    /// Treat the value as an inner stream: it is pushed into the set of inner
    /// streams to be flattened.
    Continue(C),
    /// Yield the value directly from `poll_next` instead of flattening it.
    Return(R),
}
367
368impl<St, Fc> FlattenUnorderedWithFlowControllerProj<'_, St, Fc>
369where
370    St: Stream,
371{
372    fn is_exceeded_limit(&self) -> bool {
374        self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get())
375    }
376}
377
378impl<St, Fc> FusedStream for FlattenUnorderedWithFlowController<St, Fc>
379where
380    St: FusedStream,
381    Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
382    St::Item: Stream + Unpin,
383{
384    fn is_terminated(&self) -> bool {
385        self.stream.is_terminated() && self.inner_streams.is_empty()
386    }
387}
388
389impl<St, Fc> Stream for FlattenUnorderedWithFlowController<St, Fc>
390where
391    St: Stream,
392    Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
393    St::Item: Stream + Unpin,
394{
395    type Item = <St::Item as Stream>::Item;
396
397    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398        let mut next_item = None;
399        let mut need_to_poll_next = NONE;
400
401        let mut this = self.as_mut().project();
402
403        let (mut poll_state_value, state_bomb) = loop {
405            if let Some(value) = this.poll_state.start_polling() {
406                break value;
407            }
408        };
409
410        unsafe {
412            WrappedWaker::replace_waker(this.stream_waker, cx);
413            WrappedWaker::replace_waker(this.inner_streams_waker, cx)
414        };
415
416        if poll_state_value & NEED_TO_POLL_STREAM != NONE {
417            let mut stream_waker = None;
418
419            loop {
424                if this.is_exceeded_limit() || *this.is_stream_done {
425                    if !*this.is_stream_done {
427                        need_to_poll_next |= NEED_TO_POLL_STREAM;
429                    }
430
431                    break;
432                } else {
433                    let mut cx = Context::from_waker(
434                        stream_waker.get_or_insert_with(|| waker(this.stream_waker.clone())),
435                    );
436
437                    match this.stream.as_mut().poll_next(&mut cx) {
438                        Poll::Ready(Some(item)) => {
439                            let next_item_fut = match Fc::next_step(item) {
440                                FlowStep::Return(item) => {
442                                    need_to_poll_next |= NEED_TO_POLL_STREAM
443                                        | (poll_state_value & NEED_TO_POLL_INNER_STREAMS);
444                                    poll_state_value &= !NEED_TO_POLL_INNER_STREAMS;
445
446                                    next_item = Some(item);
447
448                                    break;
449                                }
450                                FlowStep::Continue(inner_stream) => {
452                                    PollStreamFut::new(inner_stream)
453                                }
454                            };
455                            this.inner_streams.as_mut().push(next_item_fut);
457                            poll_state_value |= NEED_TO_POLL_INNER_STREAMS;
459                        }
460                        Poll::Ready(None) => {
461                            *this.is_stream_done = true;
463                        }
464                        Poll::Pending => {
465                            break;
466                        }
467                    }
468                }
469            }
470        }
471
472        if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE {
473            let inner_streams_waker = waker(this.inner_streams_waker.clone());
474            let mut cx = Context::from_waker(&inner_streams_waker);
475
476            match this.inner_streams.as_mut().poll_next(&mut cx) {
477                Poll::Ready(Some(Some((item, next_item_fut)))) => {
478                    this.inner_streams.as_mut().push(next_item_fut);
480                    next_item = Some(item);
482                    need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
484                }
485                Poll::Ready(Some(None)) => {
486                    need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
488                }
489                _ => {}
490            }
491        }
492
493        state_bomb.deactivate();
495
496        let mut force_wake =
498            need_to_poll_next & NEED_TO_POLL_STREAM != NONE && !this.is_exceeded_limit()
500            || need_to_poll_next & NEED_TO_POLL_INNER_STREAMS != NONE;
502
503        poll_state_value = this.poll_state.stop_polling(need_to_poll_next, force_wake);
505        force_wake |= poll_state_value & NEED_TO_POLL_ALL != NONE;
507
508        let is_done = *this.is_stream_done && this.inner_streams.is_empty();
509
510        if next_item.is_some() || is_done {
511            Poll::Ready(next_item)
512        } else {
513            if force_wake {
514                cx.waker().wake_by_ref();
515            }
516
517            Poll::Pending
518        }
519    }
520}
521
// `Sink` implementation delegated (via `delegate_sink!`) to the `stream` field,
// i.e. the base stream; only compiled with the `sink` feature.
#[cfg(feature = "sink")]
impl<St, Item, Fc> Sink<Item> for FlattenUnorderedWithFlowController<St, Fc>
where
    St: Stream + Sink<Item>,
{
    type Error = St::Error;

    delegate_sink!(stream, Item);
}