blob: 49d529b1f6cafea959b956f4bf86d8981108e8fc [file] [log] [blame]
Stjepan Glavina921e8a02020-01-06 14:31:28 -06001use core::fmt;
2use core::future::Future;
3use core::marker::{PhantomData, Unpin};
4use core::pin::Pin;
5use core::ptr::NonNull;
6use core::sync::atomic::Ordering;
7use core::task::{Context, Poll};
Stjepan Glavina1479e862019-08-12 20:18:51 +02008
9use crate::header::Header;
10use crate::state::*;
11use crate::utils::abort_on_panic;
12
/// A handle that awaits the result of a task.
///
/// Awaiting a `JoinHandle` produces an `Option<R>`:
///
/// * `Some(result)` means the task ran to completion and produced `result` of type `R`.
/// * `None` means the task panicked or was cancelled.
pub struct JoinHandle<R, T> {
    /// Type-erased pointer to the task allocation, which is read as a `Header` at its start.
    pub(crate) raw_task: NonNull<()>,

    /// Carries the output type `R` and the tag type `T` without storing values of either.
    pub(crate) _marker: PhantomData<(R, T)>,
}
26
Stjepan Glavinafcfa4ab2019-11-25 18:39:17 +010027unsafe impl<R: Send, T> Send for JoinHandle<R, T> {}
Stjepan Glavina1479e862019-08-12 20:18:51 +020028unsafe impl<R, T> Sync for JoinHandle<R, T> {}
29
30impl<R, T> Unpin for JoinHandle<R, T> {}
31
32impl<R, T> JoinHandle<R, T> {
33 /// Cancels the task.
34 ///
Stjepan Glavina7a8962b2019-08-16 11:25:25 +020035 /// If the task has already completed, calling this method will have no effect.
Stjepan Glavina1479e862019-08-12 20:18:51 +020036 ///
Stjepan Glavina5c398cf2019-08-20 15:29:43 +020037 /// When a task is cancelled, its future will not be polled again.
Stjepan Glavina1479e862019-08-12 20:18:51 +020038 pub fn cancel(&self) {
39 let ptr = self.raw_task.as_ptr();
40 let header = ptr as *const Header;
41
42 unsafe {
43 let mut state = (*header).state.load(Ordering::Acquire);
44
45 loop {
46 // If the task has been completed or closed, it can't be cancelled.
47 if state & (COMPLETED | CLOSED) != 0 {
48 break;
49 }
50
51 // If the task is not scheduled nor running, we'll need to schedule it.
52 let new = if state & (SCHEDULED | RUNNING) == 0 {
53 (state | SCHEDULED | CLOSED) + REFERENCE
54 } else {
55 state | CLOSED
56 };
57
58 // Mark the task as closed.
59 match (*header).state.compare_exchange_weak(
60 state,
61 new,
62 Ordering::AcqRel,
63 Ordering::Acquire,
64 ) {
65 Ok(_) => {
Stjepan Glavina5c398cf2019-08-20 15:29:43 +020066 // If the task is not scheduled nor running, schedule it one more time so
67 // that its future gets dropped by the executor.
Stjepan Glavina1479e862019-08-12 20:18:51 +020068 if state & (SCHEDULED | RUNNING) == 0 {
69 ((*header).vtable.schedule)(ptr);
70 }
71
72 // Notify the awaiter that the task has been closed.
73 if state & AWAITER != 0 {
Stjepan Glavina921e8a02020-01-06 14:31:28 -060074 (*header).notify(None);
Stjepan Glavina1479e862019-08-12 20:18:51 +020075 }
76
77 break;
78 }
79 Err(s) => state = s,
80 }
81 }
82 }
83 }
84
85 /// Returns a reference to the tag stored inside the task.
Stjepan Glavina1479e862019-08-12 20:18:51 +020086 pub fn tag(&self) -> &T {
87 let offset = Header::offset_tag::<T>();
88 let ptr = self.raw_task.as_ptr();
89
90 unsafe {
91 let raw = (ptr as *mut u8).add(offset) as *const T;
92 &*raw
93 }
94 }
95}
96
97impl<R, T> Drop for JoinHandle<R, T> {
98 fn drop(&mut self) {
99 let ptr = self.raw_task.as_ptr();
100 let header = ptr as *const Header;
101
102 // A place where the output will be stored in case it needs to be dropped.
103 let mut output = None;
104
105 unsafe {
106 // Optimistically assume the `JoinHandle` is being dropped just after creating the
107 // task. This is a common case so if the handle is not used, the overhead of it is only
108 // one compare-exchange operation.
109 if let Err(mut state) = (*header).state.compare_exchange_weak(
110 SCHEDULED | HANDLE | REFERENCE,
111 SCHEDULED | REFERENCE,
112 Ordering::AcqRel,
113 Ordering::Acquire,
114 ) {
115 loop {
116 // If the task has been completed but not yet closed, that means its output
117 // must be dropped.
118 if state & COMPLETED != 0 && state & CLOSED == 0 {
119 // Mark the task as closed in order to grab its output.
120 match (*header).state.compare_exchange_weak(
121 state,
122 state | CLOSED,
123 Ordering::AcqRel,
124 Ordering::Acquire,
125 ) {
126 Ok(_) => {
127 // Read the output.
128 output =
129 Some((((*header).vtable.get_output)(ptr) as *mut R).read());
130
131 // Update the state variable because we're continuing the loop.
132 state |= CLOSED;
133 }
134 Err(s) => state = s,
135 }
136 } else {
Stjepan Glavina7a8962b2019-08-16 11:25:25 +0200137 // If this is the last reference to the task and it's not closed, then
138 // close it and schedule one more time so that its future gets dropped by
139 // the executor.
Stjepan Glavina1479e862019-08-12 20:18:51 +0200140 let new = if state & (!(REFERENCE - 1) | CLOSED) == 0 {
141 SCHEDULED | CLOSED | REFERENCE
142 } else {
143 state & !HANDLE
144 };
145
146 // Unset the handle flag.
147 match (*header).state.compare_exchange_weak(
148 state,
149 new,
150 Ordering::AcqRel,
151 Ordering::Acquire,
152 ) {
153 Ok(_) => {
154 // If this is the last reference to the task, we need to either
155 // schedule dropping its future or destroy it.
156 if state & !(REFERENCE - 1) == 0 {
157 if state & CLOSED == 0 {
158 ((*header).vtable.schedule)(ptr);
159 } else {
160 ((*header).vtable.destroy)(ptr);
161 }
162 }
163
164 break;
165 }
166 Err(s) => state = s,
167 }
168 }
169 }
170 }
171 }
172
173 // Drop the output if it was taken out of the task.
174 drop(output);
175 }
176}
177
178impl<R, T> Future for JoinHandle<R, T> {
179 type Output = Option<R>;
180
181 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
182 let ptr = self.raw_task.as_ptr();
183 let header = ptr as *const Header;
184
185 unsafe {
186 let mut state = (*header).state.load(Ordering::Acquire);
187
188 loop {
189 // If the task has been closed, notify the awaiter and return `None`.
190 if state & CLOSED != 0 {
191 // Even though the awaiter is most likely the current task, it could also be
192 // another task.
Stjepan Glavina921e8a02020-01-06 14:31:28 -0600193 (*header).notify(Some(cx.waker()));
Stjepan Glavina1479e862019-08-12 20:18:51 +0200194 return Poll::Ready(None);
195 }
196
197 // If the task is not completed, register the current task.
198 if state & COMPLETED == 0 {
199 // Replace the waker with one associated with the current task. We need a
200 // safeguard against panics because dropping the previous waker can panic.
201 abort_on_panic(|| {
Stjepan Glavina921e8a02020-01-06 14:31:28 -0600202 (*header).register(cx.waker());
Stjepan Glavina1479e862019-08-12 20:18:51 +0200203 });
204
205 // Reload the state after registering. It is possible that the task became
206 // completed or closed just before registration so we need to check for that.
207 state = (*header).state.load(Ordering::Acquire);
208
laizy2b0427a2019-11-20 21:55:50 +0800209 // If the task has been closed, return `None`. We do not need to notify the
210 // awaiter here, since we have replaced the waker above, and the executor can
211 // only set it back to `None`.
Stjepan Glavina1479e862019-08-12 20:18:51 +0200212 if state & CLOSED != 0 {
Stjepan Glavina1479e862019-08-12 20:18:51 +0200213 return Poll::Ready(None);
214 }
215
216 // If the task is still not completed, we're blocked on it.
217 if state & COMPLETED == 0 {
218 return Poll::Pending;
219 }
220 }
221
222 // Since the task is now completed, mark it as closed in order to grab its output.
223 match (*header).state.compare_exchange(
224 state,
225 state | CLOSED,
226 Ordering::AcqRel,
227 Ordering::Acquire,
228 ) {
229 Ok(_) => {
230 // Notify the awaiter. Even though the awaiter is most likely the current
231 // task, it could also be another task.
232 if state & AWAITER != 0 {
Stjepan Glavina921e8a02020-01-06 14:31:28 -0600233 (*header).notify(Some(cx.waker()));
Stjepan Glavina1479e862019-08-12 20:18:51 +0200234 }
235
236 // Take the output from the task.
237 let output = ((*header).vtable.get_output)(ptr) as *mut R;
238 return Poll::Ready(Some(output.read()));
239 }
240 Err(s) => state = s,
241 }
242 }
243 }
244 }
245}
246
247impl<R, T> fmt::Debug for JoinHandle<R, T> {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 let ptr = self.raw_task.as_ptr();
250 let header = ptr as *const Header;
251
252 f.debug_struct("JoinHandle")
253 .field("header", unsafe { &(*header) })
254 .finish()
255 }
256}