// futures_util/stream/stream/flatten_unordered.rs
use alloc::sync::Arc;
use core::{
    cell::UnsafeCell,
    convert::identity,
    fmt,
    hint,
    marker::PhantomData,
    num::NonZeroUsize,
    pin::Pin,
    sync::atomic::{AtomicU8, Ordering},
};

use pin_project_lite::pin_project;

use futures_core::{
    future::Future,
    ready,
    stream::{FusedStream, Stream},
    task::{Context, Poll, Waker},
};
#[cfg(feature = "sink")]
use futures_sink::Sink;
use futures_task::{waker, ArcWake};

use crate::stream::FuturesUnordered;

/// Flattening stream using the default `()` flow controller, which hands every
/// item of the base stream to the flattening logic (`FlowStep::Continue`).
pub type FlattenUnordered<St> = FlattenUnorderedWithFlowController<St, ()>;

/// No flags set.
const NONE: u8 = 0b00000;

/// The inner streams (the `FuturesUnordered` bucket) need to be polled.
const NEED_TO_POLL_INNER_STREAMS: u8 = 0b00001;

/// The base stream needs to be polled.
const NEED_TO_POLL_STREAM: u8 = 0b00010;

/// Both the base stream and the inner streams need to be polled.
const NEED_TO_POLL_ALL: u8 = NEED_TO_POLL_INNER_STREAMS | NEED_TO_POLL_STREAM;

/// `poll_next` is running; wakers only record what needs polling.
const POLLING: u8 = 0b00100;

/// A wrapped waker is forwarding a wake-up; polling cannot start meanwhile.
const WAKING: u8 = 0b01000;

/// A wake-up has already been delivered (or is guaranteed), so further wakes
/// only record flags until the next poll.
const WOKEN: u8 = 0b10000;

/// Poll state shared between the stream and its wrapped wakers.
///
/// The bit-set lives in an `Arc<AtomicU8>`, so every clone reads and updates
/// the same value.
#[derive(Debug, Clone)]
struct SharedPollState {
    state: Arc<AtomicU8>,
}

impl SharedPollState {
    /// Creates a shared state initialised to `value`.
    fn new(value: u8) -> Self {
        let state = Arc::new(AtomicU8::new(value));
        Self { state }
    }

    /// Enters the `POLLING` phase.
    ///
    /// Returns `None` while a waker holds `WAKING`. Otherwise the state becomes
    /// exactly `POLLING` and the previous value is returned with a bomb that
    /// calls `reset` if it is dropped before being deactivated.
    fn start_polling(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&Self) -> u8>)> {
        let previous = self
            .state
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                if current & WAKING != NONE {
                    None
                } else {
                    Some(POLLING)
                }
            })
            .ok()?;

        Some((previous, PollStateBomb::new(self, Self::reset)))
    }

    /// Records `to_poll` and tries to claim the right to deliver a wake-up.
    ///
    /// The flags are OR-ed in unconditionally (no store happens if nothing
    /// would change). `WAKING` is added only when the state was neither
    /// `WOKEN` nor `POLLING`. Returns the previous value and a bomb running
    /// `stop_waking` only if the state was fully idle (no `WOKEN`, `POLLING`
    /// or `WAKING`); otherwise returns `None`.
    fn start_waking(
        &self,
        to_poll: u8,
    ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&Self) -> u8>)> {
        let previous = self
            .state
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                let claim = if current & (WOKEN | POLLING) == NONE { WAKING } else { NONE };
                let next = current | to_poll | claim;
                if next == current {
                    None
                } else {
                    Some(next)
                }
            })
            .ok()?;

        // Only the waker that moved the state out of idle drives the wake-up.
        if previous & (WOKEN | POLLING | WAKING) != NONE {
            return None;
        }
        Some((previous, PollStateBomb::new(self, Self::stop_waking)))
    }

    /// Leaves the `POLLING` phase and returns the previous state.
    ///
    /// The new state is `to_poll` plus any `NEED_TO_POLL_*` flags recorded by
    /// wakers while polling, with `POLLING` and `WAKING` cleared. `WOKEN` is
    /// set when such flags were recorded or when `will_be_woken` is true.
    fn stop_polling(&self, to_poll: u8, will_be_woken: bool) -> u8 {
        let update = |current: u8| {
            let requested_meanwhile = current & NEED_TO_POLL_ALL;
            let woken = if requested_meanwhile != NONE || will_be_woken { WOKEN } else { NONE };
            Some((to_poll | requested_meanwhile | woken) & !(POLLING | WAKING))
        };

        // The closure never rejects, so this is always the `Ok` value.
        self.state
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, update)
            .unwrap_or_else(identity)
    }

    /// Leaves the `WAKING` phase, marking the state `WOKEN`, and returns the
    /// previous state.
    fn stop_waking(&self) -> u8 {
        let outcome = self.state.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
            let next = (current & !WAKING) | WOKEN;
            if next == current {
                None
            } else {
                Some(next)
            }
        });
        // Whether or not a store happened, both variants carry the previous value.
        let previous = match outcome {
            Ok(previous) | Err(previous) => previous,
        };

        debug_assert!(previous & (WOKEN | POLLING | WAKING) == WAKING);
        previous
    }

    /// Forces the state to `NEED_TO_POLL_ALL`, returning the previous value.
    fn reset(&self) -> u8 {
        self.state.swap(NEED_TO_POLL_ALL, Ordering::SeqCst)
    }
}

/// Guard that calls its callback with the shared state when dropped, unless
/// `deactivate` was called first.
///
/// `start_polling` arms it with `reset`, so an abandoned poll (e.g. a panic
/// before `deactivate`) leaves a state that polls everything again.
/// `start_waking` arms it with `stop_waking`, which runs when the waker
/// releases it.
struct PollStateBomb<'a, F: FnOnce(&SharedPollState) -> u8> {
    state: &'a SharedPollState,
    on_drop: Option<F>,
}

impl<'a, F: FnOnce(&SharedPollState) -> u8> PollStateBomb<'a, F> {
    /// Arms a guard that will run `on_drop(state)` when dropped.
    fn new(state: &'a SharedPollState, on_drop: F) -> Self {
        let on_drop = Some(on_drop);
        Self { state, on_drop }
    }

    /// Disarms the guard so dropping it does nothing.
    fn deactivate(mut self) {
        self.on_drop = None;
    }
}

impl<F: FnOnce(&SharedPollState) -> u8> Drop for PollStateBomb<'_, F> {
    fn drop(&mut self) {
        if let Some(callback) = self.on_drop.take() {
            // The previous state value returned by the callback is not needed.
            callback(self.state);
        }
    }
}

190struct WrappedWaker {
193 inner_waker: UnsafeCell<Option<Waker>>,
194 poll_state: SharedPollState,
195 need_to_poll: u8,
196}
197
198unsafe impl Send for WrappedWaker {}
199unsafe impl Sync for WrappedWaker {}
200
201impl WrappedWaker {
202 unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) {
211 unsafe { *self_arc.inner_waker.get() = cx.waker().clone().into() }
212 }
213
214 fn start_waking(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
217 self.poll_state.start_waking(self.need_to_poll)
218 }
219}
220
221impl ArcWake for WrappedWaker {
222 fn wake_by_ref(self_arc: &Arc<Self>) {
223 if let Some((_, state_bomb)) = self_arc.start_waking() {
224 let waker_opt = unsafe { self_arc.inner_waker.get().as_ref().unwrap() };
226
227 if let Some(inner_waker) = waker_opt.clone() {
228 drop(state_bomb);
230
231 inner_waker.wake();
233 }
234 }
235 }
236}
237
238pin_project! {
239 #[must_use = "futures do nothing unless you `.await` or poll them"]
248 struct PollStreamFut<St> {
249 #[pin]
250 stream: Option<St>,
251 }
252}
253
254impl<St> PollStreamFut<St> {
255 fn new(stream: impl Into<Option<St>>) -> Self {
257 Self { stream: stream.into() }
258 }
259}
260
261impl<St: Stream + Unpin> Future for PollStreamFut<St> {
262 type Output = Option<(St::Item, Self)>;
263
264 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265 let mut stream = self.project().stream;
266
267 let item = if let Some(stream) = stream.as_mut().as_pin_mut() {
268 ready!(stream.poll_next(cx))
269 } else {
270 None
271 };
272 let next_item_fut = Self::new(stream.get_mut().take());
273 let out = item.map(|item| (item, next_item_fut));
274
275 Poll::Ready(out)
276 }
277}
278
pin_project! {
    #[project = FlattenUnorderedWithFlowControllerProj]
    /// Stream that flattens a stream of streams, polling the inner streams
    /// concurrently through a `FuturesUnordered` bucket.
    ///
    /// Each item of the base stream is routed through the `Fc` flow controller,
    /// which either turns it into a new inner stream or yields a value directly.
    #[must_use = "streams do nothing unless polled"]
    pub struct FlattenUnorderedWithFlowController<St, Fc> where St: Stream {
        // One pending "next item" future per live inner stream.
        #[pin]
        inner_streams: FuturesUnordered<PollStreamFut<St::Item>>,
        // The base stream producing the inner streams.
        #[pin]
        stream: St,
        // Bit-set state shared with both wrapped wakers.
        poll_state: SharedPollState,
        // When set, the base stream is not polled once this many inner streams are held.
        limit: Option<NonZeroUsize>,
        // Set once the base stream returned `Poll::Ready(None)`.
        is_stream_done: bool,
        // Waker used when polling the inner streams; requests `NEED_TO_POLL_INNER_STREAMS`.
        inner_streams_waker: Arc<WrappedWaker>,
        // Waker used when polling the base stream; requests `NEED_TO_POLL_STREAM`.
        stream_waker: Arc<WrappedWaker>,
        flow_controller: PhantomData<Fc>
    }
}

impl<St, Fc> fmt::Debug for FlattenUnorderedWithFlowController<St, Fc>
where
    St: Stream + fmt::Debug,
    St::Item: Stream + fmt::Debug,
{
    // The two wrapped wakers are intentionally left out of the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FlattenUnorderedWithFlowController")
            .field("poll_state", &self.poll_state)
            .field("inner_streams", &self.inner_streams)
            .field("limit", &self.limit)
            .field("stream", &self.stream)
            .field("is_stream_done", &self.is_stream_done)
            .field("flow_controller", &self.flow_controller)
            .finish()
    }
}

315impl<St, Fc> FlattenUnorderedWithFlowController<St, Fc>
316where
317 St: Stream,
318 Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
319 St::Item: Stream + Unpin,
320{
321 pub(crate) fn new(stream: St, limit: Option<usize>) -> Self {
322 let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM);
323
324 Self {
325 inner_streams: FuturesUnordered::new(),
326 stream,
327 is_stream_done: false,
328 limit: limit.and_then(NonZeroUsize::new),
329 inner_streams_waker: Arc::new(WrappedWaker {
330 inner_waker: UnsafeCell::new(None),
331 poll_state: poll_state.clone(),
332 need_to_poll: NEED_TO_POLL_INNER_STREAMS,
333 }),
334 stream_waker: Arc::new(WrappedWaker {
335 inner_waker: UnsafeCell::new(None),
336 poll_state: poll_state.clone(),
337 need_to_poll: NEED_TO_POLL_STREAM,
338 }),
339 poll_state,
340 flow_controller: PhantomData,
341 }
342 }
343
344 delegate_access_inner!(stream, St, ());
345}
346
/// Decides what the flattening stream does with each item produced by the
/// base stream.
pub trait FlowController<I, O> {
    /// Returns `FlowStep::Continue` to flatten `item` as an inner stream, or
    /// `FlowStep::Return` to yield a value straight to the caller.
    fn next_step(item: I) -> FlowStep<I, O>;
}

/// Default controller: every item of the base stream is flattened.
impl<I, O> FlowController<I, O> for () {
    fn next_step(item: I) -> FlowStep<I, O> {
        FlowStep::<I, O>::Continue(item)
    }
}

/// Outcome of `FlowController::next_step`.
#[derive(Clone, Debug)]
pub enum FlowStep<C, R> {
    /// Keep processing: the value is pushed into the set of inner streams.
    Continue(C),
    /// Yield the value to the caller immediately.
    Return(R),
}

368impl<St, Fc> FlattenUnorderedWithFlowControllerProj<'_, St, Fc>
369where
370 St: Stream,
371{
372 fn is_exceeded_limit(&self) -> bool {
374 self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get())
375 }
376}
377
378impl<St, Fc> FusedStream for FlattenUnorderedWithFlowController<St, Fc>
379where
380 St: FusedStream,
381 Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
382 St::Item: Stream + Unpin,
383{
384 fn is_terminated(&self) -> bool {
385 self.stream.is_terminated() && self.inner_streams.is_empty()
386 }
387}
388
389impl<St, Fc> Stream for FlattenUnorderedWithFlowController<St, Fc>
390where
391 St: Stream,
392 Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
393 St::Item: Stream + Unpin,
394{
395 type Item = <St::Item as Stream>::Item;
396
397 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398 let mut next_item = None;
399 let mut need_to_poll_next = NONE;
400
401 let mut this = self.as_mut().project();
402
403 let (mut poll_state_value, state_bomb) = loop {
405 if let Some(value) = this.poll_state.start_polling() {
406 break value;
407 }
408 };
409
410 unsafe {
412 WrappedWaker::replace_waker(this.stream_waker, cx);
413 WrappedWaker::replace_waker(this.inner_streams_waker, cx)
414 };
415
416 if poll_state_value & NEED_TO_POLL_STREAM != NONE {
417 let mut stream_waker = None;
418
419 loop {
424 if this.is_exceeded_limit() || *this.is_stream_done {
425 if !*this.is_stream_done {
427 need_to_poll_next |= NEED_TO_POLL_STREAM;
429 }
430
431 break;
432 } else {
433 let mut cx = Context::from_waker(
434 stream_waker.get_or_insert_with(|| waker(this.stream_waker.clone())),
435 );
436
437 match this.stream.as_mut().poll_next(&mut cx) {
438 Poll::Ready(Some(item)) => {
439 let next_item_fut = match Fc::next_step(item) {
440 FlowStep::Return(item) => {
442 need_to_poll_next |= NEED_TO_POLL_STREAM
443 | (poll_state_value & NEED_TO_POLL_INNER_STREAMS);
444 poll_state_value &= !NEED_TO_POLL_INNER_STREAMS;
445
446 next_item = Some(item);
447
448 break;
449 }
450 FlowStep::Continue(inner_stream) => {
452 PollStreamFut::new(inner_stream)
453 }
454 };
455 this.inner_streams.as_mut().push(next_item_fut);
457 poll_state_value |= NEED_TO_POLL_INNER_STREAMS;
459 }
460 Poll::Ready(None) => {
461 *this.is_stream_done = true;
463 }
464 Poll::Pending => {
465 break;
466 }
467 }
468 }
469 }
470 }
471
472 if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE {
473 let inner_streams_waker = waker(this.inner_streams_waker.clone());
474 let mut cx = Context::from_waker(&inner_streams_waker);
475
476 match this.inner_streams.as_mut().poll_next(&mut cx) {
477 Poll::Ready(Some(Some((item, next_item_fut)))) => {
478 this.inner_streams.as_mut().push(next_item_fut);
480 next_item = Some(item);
482 need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
484 }
485 Poll::Ready(Some(None)) => {
486 need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
488 }
489 _ => {}
490 }
491 }
492
493 state_bomb.deactivate();
495
496 let mut force_wake =
498 need_to_poll_next & NEED_TO_POLL_STREAM != NONE && !this.is_exceeded_limit()
500 || need_to_poll_next & NEED_TO_POLL_INNER_STREAMS != NONE;
502
503 poll_state_value = this.poll_state.stop_polling(need_to_poll_next, force_wake);
505 force_wake |= poll_state_value & NEED_TO_POLL_ALL != NONE;
507
508 let is_done = *this.is_stream_done && this.inner_streams.is_empty();
509
510 if next_item.is_some() || is_done {
511 Poll::Ready(next_item)
512 } else {
513 if force_wake {
514 cx.waker().wake_by_ref();
515 }
516
517 Poll::Pending
518 }
519 }
520}
521
#[cfg(feature = "sink")]
/// `Sink` methods are delegated to the `stream` field via `delegate_sink!`.
impl<St: Stream + Sink<Item>, Item, Fc> Sink<Item>
    for FlattenUnorderedWithFlowController<St, Fc>
{
    type Error = St::Error;

    delegate_sink!(stream, Item);
}