Add waker and conversion to raw pointer (#16)
diff --git a/src/join_handle.rs b/src/join_handle.rs
index 49d529b..f4710cc 100644
--- a/src/join_handle.rs
+++ b/src/join_handle.rs
@@ -4,7 +4,7 @@
use core::pin::Pin;
use core::ptr::NonNull;
use core::sync::atomic::Ordering;
-use core::task::{Context, Poll};
+use core::task::{Context, Poll, Waker};
use crate::header::Header;
use crate::state::*;
@@ -92,6 +92,17 @@
&*raw
}
}
+
+ /// Returns a waker associated with the task.
+ pub fn waker(&self) -> Waker {
+ let ptr = self.raw_task.as_ptr();
+ let header = ptr as *const Header;
+
+ unsafe {
+ let raw_waker = ((*header).vtable.clone_waker)(ptr);
+ Waker::from_raw(raw_waker)
+ }
+ }
}
impl<R, T> Drop for JoinHandle<R, T> {
diff --git a/src/raw.rs b/src/raw.rs
index 6af184f..c783d26 100644
--- a/src/raw.rs
+++ b/src/raw.rs
@@ -35,6 +35,9 @@
/// Runs the task.
pub(crate) run: unsafe fn(*const ()),
+
+ /// Creates a new waker associated with the task.
+ pub(crate) clone_waker: unsafe fn(ptr: *const ()) -> RawWaker,
}
/// Memory layout of a task.
@@ -131,6 +134,7 @@
drop_task: Self::drop_task,
destroy: Self::destroy,
run: Self::run,
+ clone_waker: Self::clone_waker,
},
});
diff --git a/src/task.rs b/src/task.rs
index 80953f4..b26c082 100644
--- a/src/task.rs
+++ b/src/task.rs
@@ -4,7 +4,7 @@
use core::mem::{self, ManuallyDrop};
use core::pin::Pin;
use core::ptr::NonNull;
-use core::task::{Context, Poll};
+use core::task::{Context, Poll, Waker};
use crate::header::Header;
use crate::raw::RawTask;
@@ -264,6 +264,41 @@
&*raw
}
}
+
+ /// Converts this task into a raw pointer to the tag.
+ pub fn into_raw(self) -> *const T {
+ let offset = Header::offset_tag::<T>();
+ let ptr = self.raw_task.as_ptr();
+ mem::forget(self);
+
+ unsafe { (ptr as *mut u8).add(offset) as *const T }
+ }
+
+ /// Converts a raw pointer to the tag into a task.
+ ///
+ /// This method should only be used with raw pointers returned from [`into_raw`].
+ ///
+ /// [`into_raw`]: #method.into_raw
+ pub unsafe fn from_raw(raw: *const T) -> Task<T> {
+ let offset = Header::offset_tag::<T>();
+ let ptr = (raw as *mut u8).sub(offset) as *mut ();
+
+ Task {
+ raw_task: NonNull::new_unchecked(ptr),
+ _marker: PhantomData,
+ }
+ }
+
+ /// Returns a waker associated with this task.
+ pub fn waker(&self) -> Waker {
+ let ptr = self.raw_task.as_ptr();
+ let header = ptr as *const Header;
+
+ unsafe {
+ let raw_waker = ((*header).vtable.clone_waker)(ptr);
+ Waker::from_raw(raw_waker)
+ }
+ }
}
impl<T> Drop for Task<T> {
diff --git a/tests/basic.rs b/tests/basic.rs
index 8426d2a..432e14c 100644
--- a/tests/basic.rs
+++ b/tests/basic.rs
@@ -332,3 +332,39 @@
);
task.schedule();
}
+
+#[test]
+fn waker() {
+ let (s, r) = channel::unbounded();
+ let schedule = move |t| s.send(t).unwrap();
+ let (task, handle) = async_task::spawn(
+ future::poll_fn(|_| Poll::<()>::Pending),
+ schedule,
+ Box::new(0),
+ );
+
+ assert!(r.is_empty());
+ let w = task.waker();
+ task.run();
+ w.wake();
+
+ let task = r.recv().unwrap();
+ task.run();
+ handle.waker().wake();
+
+ r.recv().unwrap();
+}
+
+#[test]
+fn raw() {
+ let (task, _handle) = async_task::spawn(async {}, |_| panic!(), Box::new(AtomicUsize::new(7)));
+
+ let a = task.into_raw();
+ let task = unsafe {
+ (*a).fetch_add(1, Ordering::SeqCst);
+ Task::from_raw(a)
+ };
+
+ assert_eq!(task.tag().load(Ordering::SeqCst), 8);
+ task.run();
+}