blob: d4fed504bdae9c0fbf26146ebdecf6d7dd8cbc7d [file] [log] [blame]
Stjepan Glavina1479e862019-08-12 20:18:51 +02001use std::fmt;
2use std::future::Future;
3use std::marker::{PhantomData, Unpin};
4use std::pin::Pin;
5use std::ptr::NonNull;
6use std::sync::atomic::Ordering;
7use std::task::{Context, Poll};
8
9use crate::header::Header;
10use crate::state::*;
11use crate::utils::abort_on_panic;
12
/// A handle that awaits the result of a task.
///
/// This type is a future that resolves to an `Option<R>` where:
///
/// * `None` indicates the task has panicked or was cancelled
/// * `Some(res)` indicates the task has completed with `res`
///
/// The second type parameter `T` is the type of the tag stored inside the task, which can be
/// borrowed through [`JoinHandle::tag`].
pub struct JoinHandle<R, T> {
    /// A raw task pointer.
    ///
    /// Code operating on the handle casts this pointer to `*const Header` to reach the task's
    /// header (state word and vtable).
    pub(crate) raw_task: NonNull<()>,

    /// A marker capturing the generic types `R` (the task output) and `T` (the task tag).
    pub(crate) _marker: PhantomData<(R, T)>,
}
26
27unsafe impl<R, T> Send for JoinHandle<R, T> {}
28unsafe impl<R, T> Sync for JoinHandle<R, T> {}
29
30impl<R, T> Unpin for JoinHandle<R, T> {}
31
32impl<R, T> JoinHandle<R, T> {
33 /// Cancels the task.
34 ///
Stjepan Glavina7a8962b2019-08-16 11:25:25 +020035 /// If the task has already completed, calling this method will have no effect.
Stjepan Glavina1479e862019-08-12 20:18:51 +020036 ///
Stjepan Glavina7a8962b2019-08-16 11:25:25 +020037 /// When a task is cancelled, its future cannot be polled again and will be dropped instead.
Stjepan Glavina1479e862019-08-12 20:18:51 +020038 pub fn cancel(&self) {
39 let ptr = self.raw_task.as_ptr();
40 let header = ptr as *const Header;
41
42 unsafe {
43 let mut state = (*header).state.load(Ordering::Acquire);
44
45 loop {
46 // If the task has been completed or closed, it can't be cancelled.
47 if state & (COMPLETED | CLOSED) != 0 {
48 break;
49 }
50
51 // If the task is not scheduled nor running, we'll need to schedule it.
52 let new = if state & (SCHEDULED | RUNNING) == 0 {
53 (state | SCHEDULED | CLOSED) + REFERENCE
54 } else {
55 state | CLOSED
56 };
57
58 // Mark the task as closed.
59 match (*header).state.compare_exchange_weak(
60 state,
61 new,
62 Ordering::AcqRel,
63 Ordering::Acquire,
64 ) {
65 Ok(_) => {
66 // If the task is not scheduled nor running, schedule it so that its future
67 // gets dropped by the executor.
68 if state & (SCHEDULED | RUNNING) == 0 {
69 ((*header).vtable.schedule)(ptr);
70 }
71
72 // Notify the awaiter that the task has been closed.
73 if state & AWAITER != 0 {
74 (*header).notify();
75 }
76
77 break;
78 }
79 Err(s) => state = s,
80 }
81 }
82 }
83 }
84
85 /// Returns a reference to the tag stored inside the task.
Stjepan Glavina1479e862019-08-12 20:18:51 +020086 pub fn tag(&self) -> &T {
87 let offset = Header::offset_tag::<T>();
88 let ptr = self.raw_task.as_ptr();
89
90 unsafe {
91 let raw = (ptr as *mut u8).add(offset) as *const T;
92 &*raw
93 }
94 }
95}
96
97impl<R, T> Drop for JoinHandle<R, T> {
98 fn drop(&mut self) {
99 let ptr = self.raw_task.as_ptr();
100 let header = ptr as *const Header;
101
102 // A place where the output will be stored in case it needs to be dropped.
103 let mut output = None;
104
105 unsafe {
106 // Optimistically assume the `JoinHandle` is being dropped just after creating the
107 // task. This is a common case so if the handle is not used, the overhead of it is only
108 // one compare-exchange operation.
109 if let Err(mut state) = (*header).state.compare_exchange_weak(
110 SCHEDULED | HANDLE | REFERENCE,
111 SCHEDULED | REFERENCE,
112 Ordering::AcqRel,
113 Ordering::Acquire,
114 ) {
115 loop {
116 // If the task has been completed but not yet closed, that means its output
117 // must be dropped.
118 if state & COMPLETED != 0 && state & CLOSED == 0 {
119 // Mark the task as closed in order to grab its output.
120 match (*header).state.compare_exchange_weak(
121 state,
122 state | CLOSED,
123 Ordering::AcqRel,
124 Ordering::Acquire,
125 ) {
126 Ok(_) => {
127 // Read the output.
128 output =
129 Some((((*header).vtable.get_output)(ptr) as *mut R).read());
130
131 // Update the state variable because we're continuing the loop.
132 state |= CLOSED;
133 }
134 Err(s) => state = s,
135 }
136 } else {
Stjepan Glavina7a8962b2019-08-16 11:25:25 +0200137 // If this is the last reference to the task and it's not closed, then
138 // close it and schedule one more time so that its future gets dropped by
139 // the executor.
Stjepan Glavina1479e862019-08-12 20:18:51 +0200140 let new = if state & (!(REFERENCE - 1) | CLOSED) == 0 {
141 SCHEDULED | CLOSED | REFERENCE
142 } else {
143 state & !HANDLE
144 };
145
146 // Unset the handle flag.
147 match (*header).state.compare_exchange_weak(
148 state,
149 new,
150 Ordering::AcqRel,
151 Ordering::Acquire,
152 ) {
153 Ok(_) => {
154 // If this is the last reference to the task, we need to either
155 // schedule dropping its future or destroy it.
156 if state & !(REFERENCE - 1) == 0 {
157 if state & CLOSED == 0 {
158 ((*header).vtable.schedule)(ptr);
159 } else {
160 ((*header).vtable.destroy)(ptr);
161 }
162 }
163
164 break;
165 }
166 Err(s) => state = s,
167 }
168 }
169 }
170 }
171 }
172
173 // Drop the output if it was taken out of the task.
174 drop(output);
175 }
176}
177
178impl<R, T> Future for JoinHandle<R, T> {
179 type Output = Option<R>;
180
181 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
182 let ptr = self.raw_task.as_ptr();
183 let header = ptr as *const Header;
184
185 unsafe {
186 let mut state = (*header).state.load(Ordering::Acquire);
187
188 loop {
189 // If the task has been closed, notify the awaiter and return `None`.
190 if state & CLOSED != 0 {
191 // Even though the awaiter is most likely the current task, it could also be
192 // another task.
193 (*header).notify_unless(cx.waker());
194 return Poll::Ready(None);
195 }
196
197 // If the task is not completed, register the current task.
198 if state & COMPLETED == 0 {
199 // Replace the waker with one associated with the current task. We need a
200 // safeguard against panics because dropping the previous waker can panic.
201 abort_on_panic(|| {
202 (*header).swap_awaiter(Some(cx.waker().clone()));
203 });
204
205 // Reload the state after registering. It is possible that the task became
206 // completed or closed just before registration so we need to check for that.
207 state = (*header).state.load(Ordering::Acquire);
208
209 // If the task has been closed, notify the awaiter and return `None`.
210 if state & CLOSED != 0 {
211 // Even though the awaiter is most likely the current task, it could also
212 // be another task.
213 (*header).notify_unless(cx.waker());
214 return Poll::Ready(None);
215 }
216
217 // If the task is still not completed, we're blocked on it.
218 if state & COMPLETED == 0 {
219 return Poll::Pending;
220 }
221 }
222
223 // Since the task is now completed, mark it as closed in order to grab its output.
224 match (*header).state.compare_exchange(
225 state,
226 state | CLOSED,
227 Ordering::AcqRel,
228 Ordering::Acquire,
229 ) {
230 Ok(_) => {
231 // Notify the awaiter. Even though the awaiter is most likely the current
232 // task, it could also be another task.
233 if state & AWAITER != 0 {
234 (*header).notify_unless(cx.waker());
235 }
236
237 // Take the output from the task.
238 let output = ((*header).vtable.get_output)(ptr) as *mut R;
239 return Poll::Ready(Some(output.read()));
240 }
241 Err(s) => state = s,
242 }
243 }
244 }
245 }
246}
247
248impl<R, T> fmt::Debug for JoinHandle<R, T> {
249 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250 let ptr = self.raw_task.as_ptr();
251 let header = ptr as *const Header;
252
253 f.debug_struct("JoinHandle")
254 .field("header", unsafe { &(*header) })
255 .finish()
256 }
257}