blob: fa98bae0b0157ded37a001893bc725aba2285709 [file] [log] [blame]
Jeff Vander Stoepbf372732021-02-18 09:39:46 +01001//! Futures task based helpers
2
3#![allow(clippy::mutex_atomic)]
4
5use std::future::Future;
6use std::mem;
7use std::ops;
8use std::pin::Pin;
9use std::sync::{Arc, Condvar, Mutex};
10use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
11
12use tokio_stream::Stream;
13
14/// TODO: dox
15pub fn spawn<T>(task: T) -> Spawn<T> {
16 Spawn {
17 task: MockTask::new(),
18 future: Box::pin(task),
19 }
20}
21
/// A future or stream spawned on a mock task.
///
/// Created by [`spawn`]. The inherent `poll`, `poll_next` and `enter` methods
/// run the inner value under the mock task's waker. The `Future` and `Stream`
/// trait impls instead forward the caller's context unchanged.
#[derive(Debug)]
pub struct Spawn<T> {
    /// Mock task whose waker is used by the inherent polling methods.
    task: MockTask,
    /// The spawned value, pinned on the heap so `T` need not be `Unpin`.
    future: Pin<Box<T>>,
}
28
/// Mock task
///
/// A mock task is able to intercept and track wake notifications.
///
/// Cloning a `MockTask` clones the inner `Arc`, so all clones share the same
/// wake state.
#[derive(Debug, Clone)]
struct MockTask {
    /// Shared wake state. Every `Waker` handed out by this task holds an
    /// additional strong reference to this same `Arc`.
    waker: Arc<ThreadWaker>,
}
36
/// Wake state shared between a `MockTask` and the `Waker`s it creates.
#[derive(Debug)]
struct ThreadWaker {
    /// Current state: one of `IDLE`, `WAKE` or `SLEEP`.
    state: Mutex<usize>,
    /// Signalled by `wake` when the previous state was `SLEEP`.
    condvar: Condvar,
}
42
/// No wake has been recorded since the state was last cleared.
const IDLE: usize = 0;
/// A wake notification has been recorded.
const WAKE: usize = 1;
/// According to `ThreadWaker::wake`, the "other half" is sleeping on the
/// condvar and must be signalled.
///
/// NOTE(review): nothing visible in this file ever stores `SLEEP`. It is only
/// compared against in `ThreadWaker::wake` and treated as unreachable in
/// `ThreadWaker::is_woken`.
const SLEEP: usize = 2;
46
47impl<T> Spawn<T> {
48 /// Consumes `self` returning the inner value
49 pub fn into_inner(self) -> T
50 where
51 T: Unpin,
52 {
53 *Pin::into_inner(self.future)
54 }
55
56 /// Returns `true` if the inner future has received a wake notification
57 /// since the last call to `enter`.
58 pub fn is_woken(&self) -> bool {
59 self.task.is_woken()
60 }
61
62 /// Returns the number of references to the task waker
63 ///
64 /// The task itself holds a reference. The return value will never be zero.
65 pub fn waker_ref_count(&self) -> usize {
66 self.task.waker_ref_count()
67 }
68
69 /// Enter the task context
70 pub fn enter<F, R>(&mut self, f: F) -> R
71 where
72 F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
73 {
74 let fut = self.future.as_mut();
75 self.task.enter(|cx| f(cx, fut))
76 }
77}
78
79impl<T: Unpin> ops::Deref for Spawn<T> {
80 type Target = T;
81
82 fn deref(&self) -> &T {
83 &self.future
84 }
85}
86
87impl<T: Unpin> ops::DerefMut for Spawn<T> {
88 fn deref_mut(&mut self) -> &mut T {
89 &mut self.future
90 }
91}
92
93impl<T: Future> Spawn<T> {
94 /// Polls a future
95 pub fn poll(&mut self) -> Poll<T::Output> {
96 let fut = self.future.as_mut();
97 self.task.enter(|cx| fut.poll(cx))
98 }
99}
100
101impl<T: Stream> Spawn<T> {
102 /// Polls a stream
103 pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
104 let stream = self.future.as_mut();
105 self.task.enter(|cx| stream.poll_next(cx))
106 }
107}
108
109impl<T: Future> Future for Spawn<T> {
110 type Output = T::Output;
111
112 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
113 self.future.as_mut().poll(cx)
114 }
115}
116
117impl<T: Stream> Stream for Spawn<T> {
118 type Item = T::Item;
119
120 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
121 self.future.as_mut().poll_next(cx)
122 }
123}
124
125impl MockTask {
126 /// Creates new mock task
127 fn new() -> Self {
128 MockTask {
129 waker: Arc::new(ThreadWaker::new()),
130 }
131 }
132
133 /// Runs a closure from the context of the task.
134 ///
135 /// Any wake notifications resulting from the execution of the closure are
136 /// tracked.
137 fn enter<F, R>(&mut self, f: F) -> R
138 where
139 F: FnOnce(&mut Context<'_>) -> R,
140 {
141 self.waker.clear();
142 let waker = self.waker();
143 let mut cx = Context::from_waker(&waker);
144
145 f(&mut cx)
146 }
147
148 /// Returns `true` if the inner future has received a wake notification
149 /// since the last call to `enter`.
150 fn is_woken(&self) -> bool {
151 self.waker.is_woken()
152 }
153
154 /// Returns the number of references to the task waker
155 ///
156 /// The task itself holds a reference. The return value will never be zero.
157 fn waker_ref_count(&self) -> usize {
158 Arc::strong_count(&self.waker)
159 }
160
161 fn waker(&self) -> Waker {
162 unsafe {
163 let raw = to_raw(self.waker.clone());
164 Waker::from_raw(raw)
165 }
166 }
167}
168
169impl Default for MockTask {
170 fn default() -> Self {
171 Self::new()
172 }
173}
174
175impl ThreadWaker {
176 fn new() -> Self {
177 ThreadWaker {
178 state: Mutex::new(IDLE),
179 condvar: Condvar::new(),
180 }
181 }
182
183 /// Clears any previously received wakes, avoiding potential spurrious
184 /// wake notifications. This should only be called immediately before running the
185 /// task.
186 fn clear(&self) {
187 *self.state.lock().unwrap() = IDLE;
188 }
189
190 fn is_woken(&self) -> bool {
191 match *self.state.lock().unwrap() {
192 IDLE => false,
193 WAKE => true,
194 _ => unreachable!(),
195 }
196 }
197
198 fn wake(&self) {
199 // First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
200 let mut state = self.state.lock().unwrap();
201 let prev = *state;
202
203 if prev == WAKE {
204 return;
205 }
206
207 *state = WAKE;
208
209 if prev == IDLE {
210 return;
211 }
212
213 // The other half is sleeping, so we wake it up.
214 assert_eq!(prev, SLEEP);
215 self.condvar.notify_one();
216 }
217}
218
/// Vtable for wakers whose data pointer is an `Arc<ThreadWaker>` converted by
/// `to_raw`. Each entry rebuilds that `Arc` via `from_raw`.
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
220
221unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
222 RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
223}
224
225unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
226 Arc::from_raw(raw as *const ThreadWaker)
227}
228
229unsafe fn clone(raw: *const ()) -> RawWaker {
230 let waker = from_raw(raw);
231
232 // Increment the ref count
233 mem::forget(waker.clone());
234
235 to_raw(waker)
236}
237
238unsafe fn wake(raw: *const ()) {
239 let waker = from_raw(raw);
240 waker.wake();
241}
242
243unsafe fn wake_by_ref(raw: *const ()) {
244 let waker = from_raw(raw);
245 waker.wake();
246
247 // We don't actually own a reference to the unparker
248 mem::forget(waker);
249}
250
251unsafe fn drop_waker(raw: *const ()) {
252 let _ = from_raw(raw);
253}