1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#[cfg(test)]
mod tests;

use crate::fmt;
use crate::sync::{Condvar, Mutex};

/// 屏障使多个线程能够同步某些计算的开始。
///
///
/// # Examples
///
/// ```
/// use std::sync::{Arc, Barrier};
/// use std::thread;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
///     let c = Arc::clone(&barrier);
///     // 相同的消息将一起打印。
///     // 您将看不到任何交错。
///     handles.push(thread::spawn(move|| {
///         println!("before wait");
///         c.wait();
///         println!("after wait");
///     }));
/// }
/// // 等待其他线程完成。
/// for handle in handles {
///     handle.join().unwrap();
/// }
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Barrier {
    lock: Mutex<BarrierState>,
    cvar: Condvar,
    num_threads: usize,
}

// 双重屏障的内部状态
struct BarrierState {
    count: usize,
    generation_id: usize,
}

/// 当 [`Barrier`] 中的所有线程都汇合时,[`Barrier::wait()`] 将返回 `BarrierWaitResult`。
///
///
/// # Examples
///
/// ```
/// use std::sync::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait();
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub struct BarrierWaitResult(bool);

#[stable(feature = "std_debug", since = "1.16.0")]
impl fmt::Debug for Barrier {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Barrier").finish_non_exhaustive()
    }
}

impl Barrier {
    /// 创建一个新的屏障,该屏障可以阻止给定数量的线程。
    ///
    /// 屏障将阻塞调用 [`wait()`] 的 `n - 1` 个线程,然后在第 n 个线程调用 [`wait()`] 时立即唤醒所有线程。
    ///
    ///
    /// [`wait()`]: Barrier::wait
    ///
    /// # Examples
    ///
    /// ```
    /// use std::sync::Barrier;
    ///
    /// let barrier = Barrier::new(10);
    /// ```
    #[stable(feature = "rust1", since = "1.0.0")]
    pub fn new(n: usize) -> Barrier {
        Barrier {
            lock: Mutex::new(BarrierState { count: 0, generation_id: 0 }),
            cvar: Condvar::new(),
            num_threads: n,
        }
    }

    /// 阻塞当前线程,直到所有线程都在此处集合为止。
    ///
    /// 所有线程集合一次后,屏障可以重新使用,并且可以连续使用。
    ///
    /// 从该函数返回时,单个 (arbitrary) 线程将接收从 [`BarrierWaitResult::is_leader()`] 返回 `true` 的 [`BarrierWaitResult`],而所有其他线程将接收从 [`BarrierWaitResult::is_leader()`] 返回 `false` 的结果。
    ///
    ///
    /// # Examples
    ///
    /// ```
    /// use std::sync::{Arc, Barrier};
    /// use std::thread;
    ///
    /// let mut handles = Vec::with_capacity(10);
    /// let barrier = Arc::new(Barrier::new(10));
    /// for _ in 0..10 {
    ///     let c = Arc::clone(&barrier);
    ///     // 相同的消息将一起打印。
    ///     // 您将看不到任何交错。
    ///     handles.push(thread::spawn(move|| {
    ///         println!("before wait");
    ///         c.wait();
    ///         println!("after wait");
    ///     }));
    /// }
    /// // 等待其他线程完成。
    /// for handle in handles {
    ///     handle.join().unwrap();
    /// }
    /// ```
    ///
    ///
    ///
    #[stable(feature = "rust1", since = "1.0.0")]
    pub fn wait(&self) -> BarrierWaitResult {
        let mut lock = self.lock.lock().unwrap();
        let local_gen = lock.generation_id;
        lock.count += 1;
        if lock.count < self.num_threads {
            // 我们需要一个 while 循环来防止虚假唤醒。
            // https://en.wikipedia.org/wiki/Spurious_wakeup
            while local_gen == lock.generation_id && lock.count < self.num_threads {
                lock = self.cvar.wait(lock).unwrap();
            }
            BarrierWaitResult(false)
        } else {
            lock.count = 0;
            lock.generation_id = lock.generation_id.wrapping_add(1);
            self.cvar.notify_all();
            BarrierWaitResult(true)
        }
    }
}

#[stable(feature = "std_debug", since = "1.16.0")]
impl fmt::Debug for BarrierWaitResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BarrierWaitResult").field("is_leader", &self.is_leader()).finish()
    }
}

impl BarrierWaitResult {
    /// 如果此线程是调用 [`Barrier::wait()`] 的 "leader thread",则返回 `true`。
    ///
    /// 只有一个线程从其结果返回 `true`,所有其他线程将返回 `false`。
    ///
    ///
    /// # Examples
    ///
    /// ```
    /// use std::sync::Barrier;
    ///
    /// let barrier = Barrier::new(1);
    /// let barrier_wait_result = barrier.wait();
    /// println!("{:?}", barrier_wait_result.is_leader());
    /// ```
    ///
    #[stable(feature = "rust1", since = "1.0.0")]
    pub fn is_leader(&self) -> bool {
        self.0
    }
}