Import async-trait crate.
Bug: 164106399
Bug: 158290206
Test: mm
Change-Id: I0c4a491b785134e4ac1d5572e5ff94b8699dc0bc
diff --git a/src/args.rs b/src/args.rs
new file mode 100644
index 0000000..72d97e9
--- /dev/null
+++ b/src/args.rs
@@ -0,0 +1,40 @@
+use proc_macro2::Span;
+use syn::parse::{Error, Parse, ParseStream, Result};
+use syn::Token;
+
+/// Arguments of the attribute: empty, or `?Send`.
+#[derive(Copy, Clone)]
+pub struct Args {
+    /// Set by `#[async_trait(?Send)]`: the generated futures are not
+    /// required to be `Send`.
+    pub local: bool,
+}
+
+mod kw {
+    syn::custom_keyword!(Send);
+}
+
+impl Parse for Args {
+    /// Accepts an empty argument list or exactly `?Send`. Any other input is
+    /// reported with the single error produced by `error()`.
+    fn parse(input: ParseStream) -> Result<Self> {
+        try_parse(input).ok().filter(|_| input.is_empty()).ok_or_else(error)
+    }
+}
+
+/// Consumes a leading `?Send` if present. Leftover tokens are checked by the
+/// caller.
+fn try_parse(input: ParseStream) -> Result<Args> {
+    let local = input.peek(Token![?]);
+    if local {
+        input.parse::<Token![?]>()?;
+        input.parse::<kw::Send>()?;
+    }
+    Ok(Args { local })
+}
+
+/// The error reported for any malformed attribute arguments.
+fn error() -> Error {
+    let msg = "expected #[async_trait] or #[async_trait(?Send)]";
+    Error::new(Span::call_site(), msg)
+}
diff --git a/src/expand.rs b/src/expand.rs
new file mode 100644
index 0000000..7d0cec2
--- /dev/null
+++ b/src/expand.rs
@@ -0,0 +1,521 @@
+use crate::lifetime::{has_async_lifetime, CollectLifetimes};
+use crate::parse::Item;
+use crate::receiver::{
+    has_self_in_block, has_self_in_sig, has_self_in_where_predicate, ReplaceReceiver,
+};
+use proc_macro2::{Span, TokenStream};
+use quote::{format_ident, quote, quote_spanned, ToTokens};
+use std::mem;
+use syn::punctuated::Punctuated;
+use syn::visit_mut::VisitMut;
+use syn::{
+    parse_quote, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, Pat, PatIdent,
+    Path, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, Type, TypeParam, TypeParamBound,
+    WhereClause,
+};
+
+impl ToTokens for Item {
+    fn to_tokens(&self, tokens: &mut TokenStream) {
+        match self {
+            Item::Trait(item) => item.to_tokens(tokens),
+            Item::Impl(item) => item.to_tokens(tokens),
+        }
+    }
+}
+
+/// The item whose async methods are being expanded: a trait definition
+/// (whose methods may carry default bodies) or a trait impl.
+#[derive(Clone, Copy)]
+enum Context<'a> {
+    Trait {
+        name: &'a Ident,
+        generics: &'a Generics,
+        supertraits: &'a Supertraits,
+    },
+    Impl {
+        impl_generics: &'a Generics,
+        receiver: &'a Type,
+        as_trait: &'a Path,
+    },
+}
+
+impl Context<'_> {
+    /// Lifetime parameters of the enclosing trait or impl that appear in `used`.
+    fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a GenericParam> {
+        let generics = match self {
+            Context::Trait { generics, .. } => generics,
+            Context::Impl { impl_generics, .. } => impl_generics,
+        };
+        generics.params.iter().filter(move |param| {
+            if let GenericParam::Lifetime(param) = param {
+                used.contains(&param.lifetime)
+            } else {
+                false
+            }
+        })
+    }
+}
+
+type Supertraits = Punctuated<TypeParamBound, Token![+]>;
+
+/// Rewrites every `async fn` of the trait or trait impl in place.
+///
+/// `is_local` is set by `#[async_trait(?Send)]`; it drops the `Send` bound
+/// from the boxed future and the `Send`/`Sync` bound on `Self`.
+pub fn expand(input: &mut Item, is_local: bool) {
+    match input {
+        Item::Trait(input) => {
+            let context = Context::Trait {
+                name: &input.ident,
+                generics: &input.generics,
+                supertraits: &input.supertraits,
+            };
+            for inner in &mut input.items {
+                if let TraitItem::Method(method) = inner {
+                    let sig = &mut method.sig;
+                    if sig.asyncness.is_some() {
+                        let block = &mut method.default;
+                        let mut has_self = has_self_in_sig(sig);
+                        if let Some(block) = block {
+                            has_self |= has_self_in_block(block);
+                            transform_block(context, sig, block, has_self, is_local);
+                        }
+                        let has_default = method.default.is_some();
+                        transform_sig(context, sig, has_self, has_default, is_local);
+                        // The method now returns a boxed future that does nothing
+                        // until it is awaited.
+                        method.attrs.push(parse_quote!(#[must_use]));
+                    }
+                }
+            }
+        }
+        Item::Impl(input) => {
+            // Name any elided lifetimes in the self type and trait path
+            // (`'impl0`, `'impl1`, ...) and declare them on the impl.
+            let mut lifetimes = CollectLifetimes::new("'impl");
+            lifetimes.visit_type_mut(&mut *input.self_ty);
+            let trait_path = &mut input.trait_.as_mut().expect("parse rejects inherent impls").1;
+            lifetimes.visit_path_mut(trait_path);
+            let params = &input.generics.params;
+            let elided = lifetimes.elided;
+            input.generics.params = parse_quote!(#(#elided,)* #params);
+
+            let context = Context::Impl {
+                impl_generics: &input.generics,
+                receiver: &input.self_ty,
+                as_trait: &input.trait_.as_ref().expect("parse rejects inherent impls").1,
+            };
+            for inner in &mut input.items {
+                if let ImplItem::Method(method) = inner {
+                    let sig = &mut method.sig;
+                    if sig.asyncness.is_some() {
+                        let block = &mut method.block;
+                        let has_self = has_self_in_sig(sig) || has_self_in_block(block);
+                        transform_block(context, sig, block, has_self, is_local);
+                        transform_sig(context, sig, has_self, false, is_local);
+                    }
+                }
+            }
+        }
+    }
+}
+
+// Input:
+//     async fn f<T>(&self, x: &T) -> Ret;
+//
+// Output:
+//     fn f<'life0, 'life1, 'async_trait, T>(
+//         &'life0 self,
+//         x: &'life1 T,
+//     ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
+//     where
+//         T: 'async_trait,
+//         'life0: 'async_trait,
+//         'life1: 'async_trait,
+//         Self: 'async_trait;
+//
+// A trait method with a default body also gets `Self: Sync` (for `&self`)
+// or `Self: Send` (other receivers) unless the trait has that supertrait.
+// With `is_local` (`?Send`) neither the future nor `Self` is bound by Send/Sync.
+fn transform_sig(
+    context: Context,
+    sig: &mut Signature,
+    has_self: bool,
+    has_default: bool,
+    is_local: bool,
+) {
+    sig.fn_token.span = sig.asyncness.take().unwrap().span;
+
+    let ret = match &sig.output {
+        ReturnType::Default => quote!(()),
+        ReturnType::Type(_, ret) => quote!(#ret),
+    };
+
+    // Give every elided / `'_` lifetime in the arguments a name.
+    let mut lifetimes = CollectLifetimes::new("'life");
+    for arg in sig.inputs.iter_mut() {
+        match arg {
+            FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
+            FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
+        }
+    }
+
+    // Everything the future may capture must outlive 'async_trait.
+    let where_clause = sig
+        .generics
+        .where_clause
+        .get_or_insert_with(|| WhereClause {
+            where_token: Default::default(),
+            predicates: Punctuated::new(),
+        });
+    for param in sig
+        .generics
+        .params
+        .iter()
+        .chain(context.lifetimes(&lifetimes.explicit))
+    {
+        match param {
+            GenericParam::Type(param) => {
+                let param = &param.ident;
+                where_clause
+                    .predicates
+                    .push(parse_quote!(#param: 'async_trait));
+            }
+            GenericParam::Lifetime(param) => {
+                let param = &param.lifetime;
+                where_clause
+                    .predicates
+                    .push(parse_quote!(#param: 'async_trait));
+            }
+            GenericParam::Const(_) => {}
+        }
+    }
+    for elided in lifetimes.elided {
+        sig.generics.params.push(parse_quote!(#elided));
+        where_clause
+            .predicates
+            .push(parse_quote!(#elided: 'async_trait));
+    }
+    sig.generics.params.push(parse_quote!('async_trait));
+    if has_self {
+        // `&self` only needs `Self: Sync` for the future to be Send; anything
+        // else that holds `Self` (by value or `&mut`) needs `Self: Send`.
+        let bound: Ident = match sig.inputs.iter().next() {
+            Some(FnArg::Receiver(Receiver {
+                reference: Some(_),
+                mutability: None,
+                ..
+            })) => parse_quote!(Sync),
+            Some(FnArg::Typed(arg))
+                if match (arg.pat.as_ref(), arg.ty.as_ref()) {
+                    (Pat::Ident(pat), Type::Reference(ty)) => {
+                        pat.ident == "self" && ty.mutability.is_none()
+                    }
+                    _ => false,
+                } =>
+            {
+                parse_quote!(Sync)
+            }
+            _ => parse_quote!(Send),
+        };
+        // Only a default body has to be Send for every possible `Self`; in an
+        // impl `Self` is concrete, and a trait method without a body builds no future.
+        let assume_bound = match context {
+            Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, &bound),
+            Context::Impl { .. } => true,
+        };
+        where_clause.predicates.push(if assume_bound || is_local {
+            parse_quote!(Self: 'async_trait)
+        } else {
+            parse_quote!(Self: ::core::marker::#bound + 'async_trait)
+        });
+    }
+
+    // Patterns and `mut`/`ref` modifiers belong to the inner function; the
+    // outer signature only needs plain argument names.
+    for (i, arg) in sig.inputs.iter_mut().enumerate() {
+        match arg {
+            FnArg::Receiver(Receiver {
+                reference: Some(_), ..
+            }) => {}
+            FnArg::Receiver(arg) => arg.mutability = None,
+            FnArg::Typed(arg) => {
+                if let Pat::Ident(ident) = &mut *arg.pat {
+                    ident.by_ref = None;
+                    ident.mutability = None;
+                } else {
+                    let positional = positional_arg(i);
+                    *arg.pat = parse_quote!(#positional);
+                }
+            }
+        }
+    }
+
+    let bounds = if is_local {
+        quote!('async_trait)
+    } else {
+        quote!(::core::marker::Send + 'async_trait)
+    };
+
+    sig.output = parse_quote! {
+        -> ::core::pin::Pin<Box<
+            dyn ::core::future::Future<Output = #ret> + #bounds
+        >>
+    };
+}
+
+// Input:
+//     async fn f<T>(&self, x: &T) -> Ret {
+//         self + x
+//     }
+//
+// Output:
+//     async fn __f<T, AsyncTrait>(_self: &AsyncTrait, x: &T) -> Ret {
+//         _self + x
+//     }
+//     Box::pin(__f::<T, Self>(self, x))
+//
+// The body moves into a free-standing inner async fn named `__<method>`. In
+// a trait, `Self` is replaced by a generic `AsyncTrait` parameter; in an
+// impl it is replaced by the concrete self type.
+fn transform_block(
+    context: Context,
+    sig: &mut Signature,
+    block: &mut Block,
+    has_self: bool,
+    is_local: bool,
+) {
+    // A bodiless impl method (`async fn f();`) comes through from syn as a
+    // block whose only statement is a verbatim `;`; leave it untouched.
+    if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
+        if block.stmts.len() == 1 && item.to_string() == ";" {
+            return;
+        }
+    }
+
+    let inner = format_ident!("__{}", sig.ident);
+    let args = sig.inputs.iter().enumerate().map(|(i, arg)| match arg {
+        FnArg::Receiver(Receiver { self_token, .. }) => quote!(#self_token),
+        FnArg::Typed(arg) => {
+            if let Pat::Ident(PatIdent { ident, .. }) = &*arg.pat {
+                quote!(#ident)
+            } else {
+                positional_arg(i).into_token_stream()
+            }
+        }
+    });
+
+    let mut standalone = sig.clone();
+    standalone.ident = inner.clone();
+
+    let generics = match context {
+        Context::Trait { generics, .. } => generics,
+        Context::Impl { impl_generics, .. } => impl_generics,
+    };
+
+    // Predicates mentioning `Self` cannot be carried over to an inner fn
+    // that has no `Self` parameter.
+    let mut outer_generics = generics.clone();
+    if !has_self {
+        if let Some(mut where_clause) = outer_generics.where_clause {
+            where_clause.predicates = where_clause
+                .predicates
+                .into_iter()
+                .filter_map(|mut pred| {
+                    if has_self_in_where_predicate(&mut pred) {
+                        None
+                    } else {
+                        Some(pred)
+                    }
+                })
+                .collect();
+            outer_generics.where_clause = Some(where_clause);
+        }
+    }
+
+    // Inner fn generics: the enclosing item's generics, then the method's own.
+    let fn_generics = mem::replace(&mut standalone.generics, outer_generics);
+    standalone.generics.params.extend(fn_generics.params);
+    if let Some(where_clause) = fn_generics.where_clause {
+        standalone
+            .generics
+            .make_where_clause()
+            .predicates
+            .extend(where_clause.predicates);
+    }
+
+    if has_async_lifetime(&mut standalone, block) {
+        standalone.generics.params.push(parse_quote!('async_trait));
+    }
+
+    let mut types = standalone
+        .generics
+        .type_params()
+        .map(|param| param.ident.clone())
+        .collect::<Vec<_>>();
+
+    // Rewrite the receiver as an ordinary `_self` argument.
+    let mut self_bound = None::<TypeParamBound>;
+    match standalone.inputs.iter_mut().next() {
+        Some(
+            arg @ FnArg::Receiver(Receiver {
+                reference: Some(_), ..
+            }),
+        ) => {
+            let (lifetime, mutability, self_token) = match arg {
+                FnArg::Receiver(Receiver {
+                    reference: Some((_, lifetime)),
+                    mutability,
+                    self_token,
+                    ..
+                }) => (lifetime, mutability, self_token),
+                _ => unreachable!(),
+            };
+            let under_self = Ident::new("_self", self_token.span);
+            match context {
+                Context::Trait { .. } => {
+                    self_bound = Some(match mutability {
+                        Some(_) => parse_quote!(::core::marker::Send),
+                        None => parse_quote!(::core::marker::Sync),
+                    });
+                    *arg = parse_quote! {
+                        #under_self: &#lifetime #mutability AsyncTrait
+                    };
+                }
+                Context::Impl { receiver, .. } => {
+                    let mut ty = quote!(#receiver);
+                    if let Type::TraitObject(trait_object) = receiver {
+                        if trait_object.dyn_token.is_none() {
+                            ty = quote!(dyn #ty);
+                        }
+                        if trait_object.bounds.len() > 1 {
+                            ty = quote!((#ty));
+                        }
+                    }
+                    *arg = parse_quote! {
+                        #under_self: &#lifetime #mutability #ty
+                    };
+                }
+            }
+        }
+        Some(arg @ FnArg::Receiver(_)) => {
+            let (self_token, mutability) = match arg {
+                FnArg::Receiver(Receiver {
+                    self_token,
+                    mutability,
+                    ..
+                }) => (self_token, mutability),
+                _ => unreachable!(),
+            };
+            let under_self = Ident::new("_self", self_token.span);
+            match context {
+                Context::Trait { .. } => {
+                    self_bound = Some(parse_quote!(::core::marker::Send));
+                    *arg = parse_quote! {
+                        #mutability #under_self: AsyncTrait
+                    };
+                }
+                Context::Impl { receiver, .. } => {
+                    *arg = parse_quote! {
+                        #mutability #under_self: #receiver
+                    };
+                }
+            }
+        }
+        Some(FnArg::Typed(arg)) => {
+            if let Pat::Ident(arg) = &mut *arg.pat {
+                if arg.ident == "self" {
+                    arg.ident = Ident::new("_self", arg.ident.span());
+                }
+            }
+        }
+        _ => {}
+    }
+
+    // In a trait, stand in for `Self` with a generic `AsyncTrait` parameter
+    // bounded by the trait itself.
+    if let Context::Trait { name, generics, .. } = context {
+        if has_self {
+            let (_, generics, _) = generics.split_for_impl();
+            let mut self_param: TypeParam = parse_quote!(AsyncTrait: ?Sized + #name #generics);
+            if !is_local {
+                self_param.bounds.extend(self_bound);
+            }
+            standalone
+                .generics
+                .params
+                .push(GenericParam::Type(self_param));
+            types.push(Ident::new("Self", Span::call_site()));
+        }
+    }
+
+    if let Some(where_clause) = &mut standalone.generics.where_clause {
+        // Work around an input bound like `where Self::Output: Send` expanding
+        // to `where <AsyncTrait>::Output: Send` which is illegal syntax because
+        // `where<T>` is reserved for future use... :(
+        where_clause.predicates.insert(0, parse_quote!((): Sized));
+    }
+
+    let mut replace = match context {
+        Context::Trait { .. } => ReplaceReceiver::with(parse_quote!(AsyncTrait)),
+        Context::Impl {
+            receiver, as_trait, ..
+        } => ReplaceReceiver::with_as_trait(receiver.clone(), as_trait.clone()),
+    };
+    replace.visit_signature_mut(&mut standalone);
+    replace.visit_block_mut(block);
+
+    let mut generics = types;
+    let consts = standalone
+        .generics
+        .const_params()
+        .map(|param| param.ident.clone());
+    generics.extend(consts);
+
+    let allow_non_snake_case = if sig.ident != sig.ident.to_string().to_lowercase() {
+        Some(quote!(non_snake_case,))
+    } else {
+        None
+    };
+
+    let brace = block.brace_token;
+    let box_pin = quote_spanned!(brace.span=> {
+        #[allow(
+            #allow_non_snake_case
+            clippy::missing_docs_in_private_items,
+            clippy::needless_lifetimes,
+            clippy::ptr_arg,
+            clippy::type_repetition_in_bounds,
+            clippy::used_underscore_binding,
+        )]
+        #standalone #block
+        Box::pin(#inner::<#(#generics),*>(#(#args),*))
+    });
+    *block = parse_quote!(#box_pin);
+    block.brace_token = brace;
+}
+
+/// Name given to the `i`-th argument when its pattern is not a plain
+/// identifier, so that it can be forwarded to the inner function.
+fn positional_arg(i: usize) -> Ident {
+    format_ident!("__arg{}", i)
+}
+
+/// Whether `marker` (`Send` or `Sync`) is listed among the trait's
+/// supertraits, written either bare or as `std::marker::X` / `core::marker::X`.
+fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool {
+    for bound in supertraits {
+        if let TypeParamBound::Trait(bound) = bound {
+            if bound.path.is_ident(marker)
+                || bound.path.segments.len() == 3
+                    && (bound.path.segments[0].ident == "std"
+                        || bound.path.segments[0].ident == "core")
+                    && bound.path.segments[1].ident == "marker"
+                    && bound.path.segments[2].ident == *marker
+            {
+                return true;
+            }
+        }
+    }
+    false
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..f0102d8
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,328 @@
+//! [![github]](https://github.com/dtolnay/async-trait) [![crates-io]](https://crates.io/crates/async-trait) [![docs-rs]](https://docs.rs/async-trait)
+//!
+//! [github]: https://img.shields.io/badge/github-8da0cb?style=for-the-badge&labelColor=555555&logo=github
+//! [crates-io]: https://img.shields.io/badge/crates.io-fc8d62?style=for-the-badge&labelColor=555555&logo=rust
+//! [docs-rs]: https://img.shields.io/badge/docs.rs-66c2a5?style=for-the-badge&labelColor=555555&logoColor=white&logo=
+//!
+//! <br>
+//!
+//! <h5>Type erasure for async trait methods</h5>
+//!
+//! The initial round of stabilizations for the async/await language feature in
+//! Rust 1.39 did not include support for async fn in traits. Trying to include
+//! an async fn in a trait produces the following error:
+//!
+//! ```compile_fail
+//! trait MyTrait {
+//! async fn f() {}
+//! }
+//! ```
+//!
+//! ```text
+//! error[E0706]: trait fns cannot be declared `async`
+//! --> src/main.rs:4:5
+//! |
+//! 4 | async fn f() {}
+//! | ^^^^^^^^^^^^^^^
+//! ```
+//!
+//! This crate provides an attribute macro to make async fn in traits work.
+//!
+//! Please refer to [*why async fn in traits are hard*][hard] for a deeper
+//! analysis of how this implementation differs from what the compiler and
+//! language hope to deliver in the future.
+//!
+//! [hard]: https://smallcultfollowing.com/babysteps/blog/2019/10/26/async-fn-in-traits-are-hard/
+//!
+//! <br>
+//!
+//! # Example
+//!
+//! This example implements the core of a highly effective advertising platform
+//! using async fn in a trait.
+//!
+//! The only thing to notice here is that we write an `#[async_trait]` macro on
+//! top of traits and trait impls that contain async fn, and then they work.
+//!
+//! ```
+//! use async_trait::async_trait;
+//!
+//! #[async_trait]
+//! trait Advertisement {
+//! async fn run(&self);
+//! }
+//!
+//! struct Modal;
+//!
+//! #[async_trait]
+//! impl Advertisement for Modal {
+//! async fn run(&self) {
+//! self.render_fullscreen().await;
+//! for _ in 0..4u16 {
+//! remind_user_to_join_mailing_list().await;
+//! }
+//! self.hide_for_now().await;
+//! }
+//! }
+//!
+//! struct AutoplayingVideo {
+//! media_url: String,
+//! }
+//!
+//! #[async_trait]
+//! impl Advertisement for AutoplayingVideo {
+//! async fn run(&self) {
+//! let stream = connect(&self.media_url).await;
+//! stream.play().await;
+//!
+//! // Video probably persuaded user to join our mailing list!
+//! Modal.run().await;
+//! }
+//! }
+//! #
+//! # impl Modal {
+//! # async fn render_fullscreen(&self) {}
+//! # async fn hide_for_now(&self) {}
+//! # }
+//! #
+//! # async fn remind_user_to_join_mailing_list() {}
+//! #
+//! # struct Stream;
+//! # async fn connect(_media_url: &str) -> Stream { Stream }
+//! # impl Stream {
+//! # async fn play(&self) {}
+//! # }
+//! ```
+//!
+//! <br><br>
+//!
+//! # Supported features
+//!
+//! It is the intention that all features of Rust traits should work nicely with
+//! #\[async_trait\], but the edge cases are numerous. Please file an issue if
+//! you see unexpected borrow checker errors, type errors, or warnings. There is
+//! no use of `unsafe` in the expanded code, so rest assured that if your code
+//! compiles it can't be that badly broken.
+//!
+//! > ☑ Self by value, by reference, by mut reference, or no self;<br>
+//! > ☑ Any number of arguments, any return value;<br>
+//! > ☑ Generic type parameters and lifetime parameters;<br>
+//! > ☑ Associated types;<br>
+//! > ☑ Having async and non-async functions in the same trait;<br>
+//! > ☑ Default implementations provided by the trait;<br>
+//! > ☑ Elided lifetimes;<br>
+//! > ☑ Dyn-capable traits.<br>
+//!
+//! <br>
+//!
+//! # Explanation
+//!
+//! Async fns get transformed into methods that return `Pin<Box<dyn Future +
+//! Send + 'async>>` and delegate to a private async freestanding function.
+//!
+//! For example the `impl Advertisement for AutoplayingVideo` above would be
+//! expanded as:
+//!
+//! ```
+//! # const IGNORE: &str = stringify! {
+//! impl Advertisement for AutoplayingVideo {
+//! fn run<'async>(
+//! &'async self,
+//! ) -> Pin<Box<dyn core::future::Future<Output = ()> + Send + 'async>>
+//! where
+//! Self: Sync + 'async,
+//! {
+//! async fn run(_self: &AutoplayingVideo) {
+//! /* the original method body */
+//! }
+//!
+//! Box::pin(run(self))
+//! }
+//! }
+//! # };
+//! ```
+//!
+//! <br><br>
+//!
+//! # Non-threadsafe futures
+//!
+//! Not all async traits need futures that are `dyn Future + Send`. To avoid
+//! having Send and Sync bounds placed on the async trait methods, invoke the
+//! async trait macro as `#[async_trait(?Send)]` on both the trait and the impl
+//! blocks.
+//!
+//! <br>
+//!
+//! # Elided lifetimes
+//!
+//! Be aware that async fn syntax does not allow lifetime elision outside of `&`
+//! and `&mut` references. (This is true even when not using #\[async_trait\].)
+//! Lifetimes must be named or marked by the placeholder `'_`.
+//!
+//! Fortunately the compiler is able to diagnose missing lifetimes with a good
+//! error message.
+//!
+//! ```compile_fail
+//! # use async_trait::async_trait;
+//! #
+//! type Elided<'a> = &'a usize;
+//!
+//! #[async_trait]
+//! trait Test {
+//! async fn test(not_okay: Elided, okay: &usize) {}
+//! }
+//! ```
+//!
+//! ```text
+//! error[E0726]: implicit elided lifetime not allowed here
+//! --> src/main.rs:9:29
+//! |
+//! 9 | async fn test(not_okay: Elided, okay: &usize) {}
+//! | ^^^^^^- help: indicate the anonymous lifetime: `<'_>`
+//! ```
+//!
+//! The fix is to name the lifetime or use `'_`.
+//!
+//! ```
+//! # use async_trait::async_trait;
+//! #
+//! # type Elided<'a> = &'a usize;
+//! #
+//! #[async_trait]
+//! trait Test {
+//! // either
+//! async fn test<'e>(elided: Elided<'e>) {}
+//! # }
+//! # #[async_trait]
+//! # trait Test2 {
+//! // or
+//! async fn test(elided: Elided<'_>) {}
+//! }
+//! ```
+//!
+//! <br><br>
+//!
+//! # Dyn traits
+//!
+//! Traits with async methods can be used as trait objects as long as they meet
+//! the usual requirements for dyn -- no methods with type parameters, no self
+//! by value, no associated types, etc.
+//!
+//! ```
+//! # use async_trait::async_trait;
+//! #
+//! #[async_trait]
+//! pub trait ObjectSafe {
+//! async fn f(&self);
+//! async fn g(&mut self);
+//! }
+//!
+//! # const IGNORE: &str = stringify! {
+//! impl ObjectSafe for MyType {...}
+//!
+//! let value: MyType = ...;
+//! # };
+//! #
+//! # struct MyType;
+//! #
+//! # #[async_trait]
+//! # impl ObjectSafe for MyType {
+//! # async fn f(&self) {}
+//! # async fn g(&mut self) {}
+//! # }
+//! #
+//! # let value: MyType = MyType;
+//! let object = &value as &dyn ObjectSafe; // make trait object
+//! ```
+//!
+//! The one wrinkle is in traits that provide default implementations of async
+//! methods. In order for the default implementation to produce a future that is
+//! Send, the async_trait macro must emit a bound of `Self: Sync` on trait
+//! methods that take `&self` and a bound `Self: Send` on trait methods that
+//! take `&mut self`. An example of the former is visible in the expanded code
+//! in the explanation section above.
+//!
+//! If you make a trait with async methods that have default implementations,
+//! everything will work except that the trait cannot be used as a trait object.
+//! Creating a value of type `&dyn Trait` will produce an error that looks like
+//! this:
+//!
+//! ```text
+//! error: the trait `Test` cannot be made into an object
+//! --> src/main.rs:8:5
+//! |
+//! 8 | async fn cannot_dyn(&self) {}
+//! | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+//! ```
+//!
+//! For traits that need to be object safe and need to have default
+//! implementations for some async methods, there are two resolutions. Either
+//! you can add Send and/or Sync as supertraits (Send if there are `&mut self`
+//! methods with default implementations, Sync if there are `&self` methods with
+//! default implementations) to constrain all implementors of the trait such that
+//! the default implementations are applicable to them:
+//!
+//! ```
+//! # use async_trait::async_trait;
+//! #
+//! #[async_trait]
+//! pub trait ObjectSafe: Sync { // added supertrait
+//! async fn can_dyn(&self) {}
+//! }
+//! #
+//! # struct MyType;
+//! #
+//! # #[async_trait]
+//! # impl ObjectSafe for MyType {}
+//! #
+//! # let value = MyType;
+//!
+//! let object = &value as &dyn ObjectSafe;
+//! ```
+//!
+//! or you can strike the problematic methods from your trait object by
+//! bounding them with `Self: Sized`:
+//!
+//! ```
+//! # use async_trait::async_trait;
+//! #
+//! #[async_trait]
+//! pub trait ObjectSafe {
+//! async fn cannot_dyn(&self) where Self: Sized {}
+//!
+//! // presumably other methods
+//! }
+//! #
+//! # struct MyType;
+//! #
+//! # #[async_trait]
+//! # impl ObjectSafe for MyType {}
+//! #
+//! # let value = MyType;
+//!
+//! let object = &value as &dyn ObjectSafe;
+//! ```
+
+extern crate proc_macro;
+
+mod args;
+mod expand;
+mod lifetime;
+mod parse;
+mod receiver;
+mod respan;
+
+use crate::args::Args;
+use crate::expand::expand;
+use crate::parse::Item;
+use proc_macro::TokenStream;
+use quote::quote;
+use syn::parse_macro_input;
+
+#[proc_macro_attribute]
+pub fn async_trait(args: TokenStream, input: TokenStream) -> TokenStream {
+    let Args { local } = parse_macro_input!(args as Args);
+    let mut item = parse_macro_input!(input as Item);
+    expand(&mut item, local); // rewrites every async fn of the trait/impl in place
+    quote!(#item).into()
+}
diff --git a/src/lifetime.rs b/src/lifetime.rs
new file mode 100644
index 0000000..9d2066b
--- /dev/null
+++ b/src/lifetime.rs
@@ -0,0 +1,95 @@
+use proc_macro2::Span;
+use syn::visit_mut::{self, VisitMut};
+use syn::{Block, GenericArgument, Item, Lifetime, Receiver, Signature, TypeReference};
+
+/// Reports whether the lifetime `'async_trait` is mentioned anywhere in `sig`
+/// or `block`, not counting nested items.
+pub fn has_async_lifetime(sig: &mut Signature, block: &mut Block) -> bool {
+    let mut finder = HasAsyncLifetime(false);
+    finder.visit_signature_mut(sig);
+    finder.visit_block_mut(block);
+    let HasAsyncLifetime(found) = finder;
+    found
+}
+
+/// Flag that latches to `true` once `'async_trait` has been seen.
+struct HasAsyncLifetime(bool);
+
+impl VisitMut for HasAsyncLifetime {
+    fn visit_lifetime_mut(&mut self, life: &mut Lifetime) {
+        if life.to_string() == "'async_trait" {
+            self.0 = true;
+        }
+    }
+
+    fn visit_item_mut(&mut self, _: &mut Item) {
+        // Nested items are a separate scope; intentionally not visited.
+    }
+}
+
+/// Visitor that gives a fresh name (`{name}0`, `{name}1`, ...) to every elided
+/// or `'_` lifetime it encounters and records lifetimes that were already named.
+pub struct CollectLifetimes {
+    /// The freshly named lifetimes, in the order they were introduced.
+    pub elided: Vec<Lifetime>,
+    /// Lifetimes written explicitly in the visited syntax.
+    pub explicit: Vec<Lifetime>,
+    /// Prefix of generated lifetime names, including the leading `'`.
+    pub name: &'static str,
+}
+
+impl CollectLifetimes {
+    pub fn new(name: &'static str) -> Self {
+        CollectLifetimes {
+            name,
+            elided: Vec::new(),
+            explicit: Vec::new(),
+        }
+    }
+
+    /// Names an omitted lifetime (e.g. in `&T`) or processes an existing one.
+    fn visit_opt_lifetime(&mut self, lifetime: &mut Option<Lifetime>) {
+        if let Some(existing) = lifetime {
+            self.visit_lifetime(existing);
+        } else {
+            *lifetime = Some(self.next_lifetime());
+        }
+    }
+
+    /// Replaces `'_` with a fresh lifetime; records any other lifetime as explicit.
+    fn visit_lifetime(&mut self, lifetime: &mut Lifetime) {
+        if lifetime.ident != "_" {
+            self.explicit.push(lifetime.clone());
+            return;
+        }
+        *lifetime = self.next_lifetime();
+    }
+
+    /// Allocates the next generated lifetime and remembers it in `elided`.
+    fn next_lifetime(&mut self) -> Lifetime {
+        let ordinal = self.elided.len();
+        let fresh = Lifetime::new(&format!("{}{}", self.name, ordinal), Span::call_site());
+        self.elided.push(fresh.clone());
+        fresh
+    }
+}
+
+impl VisitMut for CollectLifetimes {
+    fn visit_receiver_mut(&mut self, arg: &mut Receiver) {
+        if let Some((_, lifetime)) = arg.reference.as_mut() {
+            self.visit_opt_lifetime(lifetime);
+        }
+    }
+
+    fn visit_type_reference_mut(&mut self, ty: &mut TypeReference) {
+        self.visit_opt_lifetime(&mut ty.lifetime);
+        visit_mut::visit_type_reference_mut(self, ty);
+    }
+
+    fn visit_generic_argument_mut(&mut self, arg: &mut GenericArgument) {
+        if let GenericArgument::Lifetime(lifetime) = arg {
+            self.visit_lifetime(lifetime);
+        }
+        visit_mut::visit_generic_argument_mut(self, arg);
+    }
+}
diff --git a/src/parse.rs b/src/parse.rs
new file mode 100644
index 0000000..ebd2535
--- /dev/null
+++ b/src/parse.rs
@@ -0,0 +1,43 @@
+use proc_macro2::Span;
+use syn::parse::{Error, Parse, ParseStream, Result};
+use syn::{Attribute, ItemImpl, ItemTrait, Token};
+
+/// An item accepted by `#[async_trait]`: a trait definition or a trait impl.
+pub enum Item {
+    Trait(ItemTrait),
+    Impl(ItemImpl),
+}
+
+impl Parse for Item {
+    /// Parses a (possibly `unsafe`) trait definition or trait impl together
+    /// with its outer attributes. Inherent impls are rejected.
+    fn parse(input: ParseStream) -> Result<Self> {
+        let attrs = input.call(Attribute::parse_outer)?;
+
+        // Peek past an optional `unsafe` on a fork so the item parsers below
+        // still see the complete item.
+        let mut lookahead = input.lookahead1();
+        if lookahead.peek(Token![unsafe]) {
+            let fork = input.fork();
+            fork.parse::<Token![unsafe]>()?;
+            lookahead = fork.lookahead1();
+        }
+
+        let is_trait = lookahead.peek(Token![pub]) || lookahead.peek(Token![trait]);
+        if is_trait {
+            let mut item: ItemTrait = input.parse()?;
+            item.attrs = attrs;
+            return Ok(Item::Trait(item));
+        }
+        if !lookahead.peek(Token![impl]) {
+            return Err(lookahead.error());
+        }
+
+        let mut item: ItemImpl = input.parse()?;
+        if item.trait_.is_none() {
+            return Err(Error::new(Span::call_site(), "expected a trait impl"));
+        }
+        item.attrs = attrs;
+        Ok(Item::Impl(item))
+    }
+}
diff --git a/src/receiver.rs b/src/receiver.rs
new file mode 100644
index 0000000..1e9e397
--- /dev/null
+++ b/src/receiver.rs
@@ -0,0 +1,336 @@
+use crate::respan::respan;
+use proc_macro2::{Group, Spacing, Span, TokenStream, TokenTree};
+use quote::{quote, quote_spanned};
+use std::iter::FromIterator;
+use std::mem;
+use syn::punctuated::Punctuated;
+use syn::visit_mut::{self, VisitMut};
+use syn::{
+    parse_quote, Block, Error, ExprPath, ExprStruct, Ident, Item, Macro, PatPath, PatStruct,
+    PatTupleStruct, Path, PathArguments, QSelf, Receiver, Signature, Token, Type, TypePath,
+    WherePredicate,
+};
+
+/// Whether `sig` has a `self` receiver or mentions `Self`.
+pub fn has_self_in_sig(sig: &mut Signature) -> bool {
+    let mut visitor = HasSelf(false);
+    visitor.visit_signature_mut(sig);
+    visitor.0
+}
+
+/// Whether the where-clause predicate mentions `Self`.
+pub fn has_self_in_where_predicate(where_predicate: &mut WherePredicate) -> bool {
+    let mut visitor = HasSelf(false);
+    visitor.visit_where_predicate_mut(where_predicate);
+    visitor.0
+}
+
+/// Whether `block` mentions `Self`, not looking inside nested items.
+pub fn has_self_in_block(block: &mut Block) -> bool {
+    let mut visitor = HasSelf(false);
+    visitor.visit_block_mut(block);
+    visitor.0
+}
+
+/// Visitor flag set when a path starting with `Self` or a `self` receiver is
+/// found. Tokens inside macro invocations are not inspected.
+struct HasSelf(bool);
+
+impl VisitMut for HasSelf {
+    fn visit_expr_path_mut(&mut self, expr: &mut ExprPath) {
+        self.0 |= expr.path.segments[0].ident == "Self";
+        visit_mut::visit_expr_path_mut(self, expr);
+    }
+
+    fn visit_pat_path_mut(&mut self, pat: &mut PatPath) {
+        self.0 |= pat.path.segments[0].ident == "Self";
+        visit_mut::visit_pat_path_mut(self, pat);
+    }
+
+    fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
+        self.0 |= ty.path.segments[0].ident == "Self";
+        visit_mut::visit_type_path_mut(self, ty);
+    }
+
+    fn visit_receiver_mut(&mut self, _arg: &mut Receiver) {
+        self.0 = true;
+    }
+
+    fn visit_item_mut(&mut self, _: &mut Item) {
+        // Do not recurse into nested items.
+    }
+}
+
+/// Visitor that rewrites `Self` and `self` so a method body can be moved
+/// into a free-standing inner function.
+pub struct ReplaceReceiver {
+    /// Type substituted for `Self`.
+    pub with: Type,
+    /// In an impl, the implemented trait; used to write associated-type
+    /// paths as `<Type as Trait>::Assoc`.
+    pub as_trait: Option<Path>,
+}
+
+impl ReplaceReceiver {
+    /// Replace `Self` with `ty` (used for trait default bodies).
+    pub fn with(ty: Type) -> Self {
+        ReplaceReceiver {
+            with: ty,
+            as_trait: None,
+        }
+    }
+
+    /// Replace `Self` with `ty`, qualifying associated types with `as_trait`
+    /// (used for impls).
+    pub fn with_as_trait(ty: Type, as_trait: Path) -> Self {
+        ReplaceReceiver {
+            with: ty,
+            as_trait: Some(as_trait),
+        }
+    }
+
+    /// The replacement type, with all of its top-level tokens given `span`.
+    fn self_ty(&self, span: Span) -> Type {
+        respan(&self.with, span)
+    }
+
+    fn self_to_qself_type(&self, qself: &mut Option<QSelf>, path: &mut Path) {
+        let include_as_trait = true;
+        self.self_to_qself(qself, path, include_as_trait);
+    }
+
+    fn self_to_qself_expr(&self, qself: &mut Option<QSelf>, path: &mut Path) {
+        let include_as_trait = false;
+        self.self_to_qself(qself, path, include_as_trait);
+    }
+
+    /// Rewrites a path beginning with a bare `Self`:
+    /// - `Self` alone becomes the replacement type's path;
+    /// - `Self::Rest` becomes `<Type>::Rest`, or `<Type as Trait>::Rest` when
+    ///   `include_as_trait` is set and a trait is known.
+    /// Paths with a leading `::` or with arguments on `Self` are untouched.
+    fn self_to_qself(&self, qself: &mut Option<QSelf>, path: &mut Path, include_as_trait: bool) {
+        if path.leading_colon.is_some() {
+            return;
+        }
+
+        let first = &path.segments[0];
+        if first.ident != "Self" || !first.arguments.is_empty() {
+            return;
+        }
+
+        if path.segments.len() == 1 {
+            self.self_to_expr_path(path);
+            return;
+        }
+
+        let span = first.ident.span();
+        *qself = Some(QSelf {
+            lt_token: Token![<](span),
+            ty: Box::new(self.self_ty(span)),
+            position: 0,
+            as_token: None,
+            gt_token: Token![>](span),
+        });
+
+        if let Some(as_trait) = self.as_trait.as_ref().filter(|_| include_as_trait) {
+            let as_trait = as_trait.clone();
+            path.leading_colon = as_trait.leading_colon;
+            qself.as_mut().unwrap().position = as_trait.segments.len();
+
+            let segments = mem::replace(&mut path.segments, as_trait.segments);
+            path.segments.push_punct(Default::default());
+            path.segments.extend(segments.into_pairs().skip(1));
+        } else {
+            path.leading_colon = Some(**path.segments.pairs().next().unwrap().punct().unwrap());
+
+            let segments = mem::replace(&mut path.segments, Punctuated::new());
+            path.segments = segments.into_pairs().skip(1).collect();
+        }
+    }
+
+    /// Replaces a leading bare `Self` in a value/pattern path with the
+    /// replacement type's path, using turbofish (`Type::<T>`) for generic
+    /// arguments. If the replacement type is not a path, the path is replaced
+    /// by one that reports a compile error.
+    fn self_to_expr_path(&self, path: &mut Path) {
+        if path.leading_colon.is_some() {
+            return;
+        }
+
+        let first = &path.segments[0];
+        if first.ident != "Self" || !first.arguments.is_empty() {
+            return;
+        }
+
+        if let Type::Path(self_ty) = self.self_ty(first.ident.span()) {
+            let variant = mem::replace(path, self_ty.path);
+            for segment in &mut path.segments {
+                if let PathArguments::AngleBracketed(bracketed) = &mut segment.arguments {
+                    if bracketed.colon2_token.is_none() && !bracketed.args.is_empty() {
+                        bracketed.colon2_token = Some(Default::default());
+                    }
+                }
+            }
+            if variant.segments.len() > 1 {
+                path.segments.push_punct(Default::default());
+                path.segments.extend(variant.segments.into_pairs().skip(1));
+            }
+        } else {
+            let span = path.segments[0].ident.span();
+            let msg = "Self type of this impl is unsupported in expression position";
+            let error = Error::new(span, msg).to_compile_error();
+            *path = parse_quote!(::core::marker::PhantomData::<#error>);
+        }
+    }
+
+    /// Rewrites `self` to `_self` and `Self` to the replacement type (or to
+    /// `AsyncTrait` when no trait path is set, i.e. in a trait) inside a raw
+    /// token stream, recursing into groups; `Self::` becomes `<Type>::`.
+    /// Returns whether anything changed; `tokens` is only rebuilt in that case.
+    fn visit_token_stream(&self, tokens: &mut TokenStream) -> bool {
+        let mut out = Vec::new();
+        let mut modified = false;
+        let mut iter = tokens.clone().into_iter().peekable();
+        while let Some(tt) = iter.next() {
+            match tt {
+                TokenTree::Ident(mut ident) => {
+                    modified |= prepend_underscore_to_self(&mut ident);
+                    if ident == "Self" {
+                        modified = true;
+                        if self.as_trait.is_none() {
+                            let ident = Ident::new("AsyncTrait", ident.span());
+                            out.push(TokenTree::Ident(ident));
+                        } else {
+                            let self_ty = self.self_ty(ident.span());
+                            match iter.peek() {
+                                Some(TokenTree::Punct(p))
+                                    if p.as_char() == ':' && p.spacing() == Spacing::Joint =>
+                                {
+                                    let next = iter.next().unwrap();
+                                    match iter.peek() {
+                                        Some(TokenTree::Punct(p)) if p.as_char() == ':' => {
+                                            let span = ident.span();
+                                            out.extend(quote_spanned!(span=> <#self_ty>));
+                                        }
+                                        _ => out.extend(quote!(#self_ty)),
+                                    }
+                                    out.push(next);
+                                }
+                                _ => out.extend(quote!(#self_ty)),
+                            }
+                        }
+                    } else {
+                        out.push(TokenTree::Ident(ident));
+                    }
+                }
+                TokenTree::Group(group) => {
+                    let mut content = group.stream();
+                    modified |= self.visit_token_stream(&mut content);
+                    let mut new = Group::new(group.delimiter(), content);
+                    new.set_span(group.span());
+                    out.push(TokenTree::Group(new));
+                }
+                other => out.push(other),
+            }
+        }
+        if modified {
+            *tokens = TokenStream::from_iter(out);
+        }
+        modified
+    }
+}
+
+impl VisitMut for ReplaceReceiver {
+    // `Self` -> `Receiver`
+    fn visit_type_mut(&mut self, ty: &mut Type) {
+        if let Type::Path(node) = ty {
+            if node.qself.is_none() && node.path.is_ident("Self") {
+                *ty = self.self_ty(node.path.segments[0].ident.span());
+            } else {
+                self.visit_type_path_mut(node);
+            }
+        } else {
+            visit_mut::visit_type_mut(self, ty);
+        }
+    }
+
+    // `Self::Assoc` -> `<Receiver>::Assoc`
+    fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
+        if ty.qself.is_none() {
+            self.self_to_qself_type(&mut ty.qself, &mut ty.path);
+        }
+        visit_mut::visit_type_path_mut(self, ty);
+    }
+
+    // `Self::method` -> `<Receiver>::method`
+    fn visit_expr_path_mut(&mut self, expr: &mut ExprPath) {
+        if expr.qself.is_none() {
+            prepend_underscore_to_self(&mut expr.path.segments[0].ident);
+            self.self_to_qself_expr(&mut expr.qself, &mut expr.path);
+        }
+        visit_mut::visit_expr_path_mut(self, expr);
+    }
+
+    fn visit_expr_struct_mut(&mut self, expr: &mut ExprStruct) {
+        self.self_to_expr_path(&mut expr.path);
+        visit_mut::visit_expr_struct_mut(self, expr);
+    }
+
+    fn visit_pat_path_mut(&mut self, pat: &mut PatPath) {
+        if pat.qself.is_none() {
+            self.self_to_qself_expr(&mut pat.qself, &mut pat.path);
+        }
+        visit_mut::visit_pat_path_mut(self, pat);
+    }
+
+    fn visit_pat_struct_mut(&mut self, pat: &mut PatStruct) {
+        self.self_to_expr_path(&mut pat.path);
+        visit_mut::visit_pat_struct_mut(self, pat);
+    }
+
+    fn visit_pat_tuple_struct_mut(&mut self, pat: &mut PatTupleStruct) {
+        self.self_to_expr_path(&mut pat.path);
+        visit_mut::visit_pat_tuple_struct_mut(self, pat);
+    }
+
+    fn visit_item_mut(&mut self, i: &mut Item) {
+        match i {
+            // Visit `macro_rules!` because locally defined macros can refer to `self`.
+            Item::Macro(i) if i.mac.path.is_ident("macro_rules") => {
+                self.visit_macro_mut(&mut i.mac)
+            }
+            // Otherwise, do not recurse into nested items.
+            _ => {}
+        }
+    }
+
+    fn visit_macro_mut(&mut self, i: &mut Macro) {
+        // We can't tell in general whether `self` inside a macro invocation
+        // refers to the self in the argument list or a different self
+        // introduced within the macro. Heuristic: if the macro input contains
+        // `fn`, then `self` is more likely to refer to something other than the
+        // outer function's self argument.
+        if !contains_fn(i.tokens.clone()) {
+            self.visit_token_stream(&mut i.tokens);
+        }
+    }
+}
+
+/// Whether `tokens` contains the identifier `fn` at any nesting depth.
+fn contains_fn(tokens: TokenStream) -> bool {
+    tokens.into_iter().any(|tt| match tt {
+        TokenTree::Ident(ident) => ident == "fn",
+        TokenTree::Group(group) => contains_fn(group.stream()),
+        _ => false,
+    })
+}
+
+/// Renames `self` to `_self` in place; returns whether a rename happened.
+fn prepend_underscore_to_self(ident: &mut Ident) -> bool {
+    let modified = ident == "self";
+    if modified {
+        *ident = Ident::new("_self", ident.span());
+    }
+    modified
+}
diff --git a/src/respan.rs b/src/respan.rs
new file mode 100644
index 0000000..38f6612
--- /dev/null
+++ b/src/respan.rs
@@ -0,0 +1,25 @@
+use proc_macro2::{Span, TokenStream};
+use quote::ToTokens;
+use syn::parse::Parse;
+
+/// Returns a copy of `node` whose top-level token trees all carry `span`.
+///
+/// Works by printing `node` to tokens, respanning them, and parsing the result
+/// back; panics if that re-parse fails.
+pub(crate) fn respan<T>(node: &T, span: Span) -> T
+where
+    T: ToTokens + Parse,
+{
+    let respanned = respan_tokens(node.to_token_stream(), span);
+    syn::parse2(respanned).unwrap()
+}
+
+/// Sets `span` on each top-level token tree. For a delimited group this sets
+/// only the delimiter span; tokens inside the group keep their spans.
+fn respan_tokens(tokens: TokenStream, span: Span) -> TokenStream {
+    let mut trees: Vec<_> = tokens.into_iter().collect();
+    for tree in &mut trees {
+        tree.set_span(span);
+    }
+    trees.into_iter().collect()
+}