tokio/sync/barrier.rs
1use crate::loom::sync::Mutex;
2use crate::sync::watch;
3#[cfg(all(tokio_unstable, feature = "tracing"))]
4use crate::util::trace;
5
6/// A barrier enables multiple tasks to synchronize the beginning of some computation.
7///
8/// ```
9/// # #[tokio::main]
10/// # async fn main() {
11/// use tokio::sync::Barrier;
12/// use std::sync::Arc;
13///
14/// let mut handles = Vec::with_capacity(10);
15/// let barrier = Arc::new(Barrier::new(10));
16/// for _ in 0..10 {
17///     let c = barrier.clone();
18///     // The same messages will be printed together.
19///     // You will NOT see any interleaving.
20///     handles.push(tokio::spawn(async move {
21///         println!("before wait");
22///         let wait_result = c.wait().await;
23///         println!("after wait");
24///         wait_result
25///     }));
26/// }
27///
28/// // Will not resolve until all "after wait" messages have been printed
29/// let mut num_leaders = 0;
30/// for handle in handles {
31///     let wait_result = handle.await.unwrap();
32///     if wait_result.is_leader() {
33///         num_leaders += 1;
34///     }
35/// }
36///
37/// // Exactly one barrier will resolve as the "leader"
38/// assert_eq!(num_leaders, 1);
39/// # }
40/// ```
41#[derive(Debug)]
42pub struct Barrier {
43    state: Mutex<BarrierState>,
44    wait: watch::Receiver<usize>,
45    n: usize,
46    #[cfg(all(tokio_unstable, feature = "tracing"))]
47    resource_span: tracing::Span,
48}
49
/// Mutable bookkeeping for a [`Barrier`], kept behind the barrier's `Mutex`.
#[derive(Debug)]
struct BarrierState {
    /// Publishes the number of each generation when it is released, waking waiting tasks.
    waker: watch::Sender<usize>,
    /// How many tasks have called `wait` in the current generation; reset to 0 on release.
    arrived: usize,
    /// Number of the current generation; starts at 1 and is incremented on each release.
    generation: usize,
}
56
57impl Barrier {
58    /// Creates a new barrier that can block a given number of tasks.
59    ///
60    /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all
61    /// tasks at once when the `n`th task calls `wait`.
62    #[track_caller]
63    pub fn new(mut n: usize) -> Barrier {
64        let (waker, wait) = crate::sync::watch::channel(0);
65
66        if n == 0 {
67            // if n is 0, it's not clear what behavior the user wants.
68            // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
69            // .wait() immediately unblocks, so we adopt that here as well.
70            n = 1;
71        }
72
73        #[cfg(all(tokio_unstable, feature = "tracing"))]
74        let resource_span = {
75            let location = std::panic::Location::caller();
76            let resource_span = tracing::trace_span!(
77                parent: None,
78                "runtime.resource",
79                concrete_type = "Barrier",
80                kind = "Sync",
81                loc.file = location.file(),
82                loc.line = location.line(),
83                loc.col = location.column(),
84            );
85
86            resource_span.in_scope(|| {
87                tracing::trace!(
88                    target: "runtime::resource::state_update",
89                    size = n,
90                );
91
92                tracing::trace!(
93                    target: "runtime::resource::state_update",
94                    arrived = 0,
95                )
96            });
97            resource_span
98        };
99
100        Barrier {
101            state: Mutex::new(BarrierState {
102                waker,
103                arrived: 0,
104                generation: 1,
105            }),
106            n,
107            wait,
108            #[cfg(all(tokio_unstable, feature = "tracing"))]
109            resource_span,
110        }
111    }
112
113    /// Does not resolve until all tasks have rendezvoused here.
114    ///
115    /// Barriers are re-usable after all tasks have rendezvoused once, and can
116    /// be used continuously.
117    ///
118    /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
119    /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks
120    /// will receive a result that will return `false` from `is_leader`.
121    ///
122    /// # Cancel safety
123    ///
124    /// This method is not cancel safe.
125    pub async fn wait(&self) -> BarrierWaitResult {
126        #[cfg(all(tokio_unstable, feature = "tracing"))]
127        return trace::async_op(
128            || self.wait_internal(),
129            self.resource_span.clone(),
130            "Barrier::wait",
131            "poll",
132            false,
133        )
134        .await;
135
136        #[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
137        return self.wait_internal().await;
138    }
139    async fn wait_internal(&self) -> BarrierWaitResult {
140        crate::trace::async_trace_leaf().await;
141
142        // NOTE: we are taking a _synchronous_ lock here.
143        // It is okay to do so because the critical section is fast and never yields, so it cannot
144        // deadlock even if another future is concurrently holding the lock.
145        // It is _desirable_ to do so as synchronous Mutexes are, at least in theory, faster than
146        // the asynchronous counter-parts, so we should use them where possible [citation needed].
147        // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across
148        // a yield point, and thus marks the returned future as !Send.
149        let generation = {
150            let mut state = self.state.lock();
151            let generation = state.generation;
152            state.arrived += 1;
153            #[cfg(all(tokio_unstable, feature = "tracing"))]
154            tracing::trace!(
155                target: "runtime::resource::state_update",
156                arrived = 1,
157                arrived.op = "add",
158            );
159            #[cfg(all(tokio_unstable, feature = "tracing"))]
160            tracing::trace!(
161                target: "runtime::resource::async_op::state_update",
162                arrived = true,
163            );
164            if state.arrived == self.n {
165                #[cfg(all(tokio_unstable, feature = "tracing"))]
166                tracing::trace!(
167                    target: "runtime::resource::async_op::state_update",
168                    is_leader = true,
169                );
170                // we are the leader for this generation
171                // wake everyone, increment the generation, and return
172                state
173                    .waker
174                    .send(state.generation)
175                    .expect("there is at least one receiver");
176                state.arrived = 0;
177                state.generation += 1;
178                return BarrierWaitResult(true);
179            }
180
181            generation
182        };
183
184        // we're going to have to wait for the last of the generation to arrive
185        let mut wait = self.wait.clone();
186
187        loop {
188            let _ = wait.changed().await;
189
190            // note that the first time through the loop, this _will_ yield a generation
191            // immediately, since we cloned a receiver that has never seen any values.
192            if *wait.borrow() >= generation {
193                break;
194            }
195        }
196
197        BarrierWaitResult(false)
198    }
199}
200
/// The value handed back by `wait` once every task in the `Barrier` has rendezvoused.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Tells whether the task that received this result is the "leader task".
    ///
    /// Of all the tasks released together, exactly one gets a result for which this is
    /// `true`; every other task's result reports `false`.
    pub fn is_leader(&self) -> bool {
        let Self(leader) = *self;
        leader
    }
}