blob: 2f4697a40a7ebf16d337cc7f310870599b696964 [file] [log] [blame]
use crate::lifetime::CollectLifetimes;
use crate::parse::Item;
use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::BTreeSet as Set;
use syn::punctuated::Punctuated;
use syn::visit_mut::{self, VisitMut};
use syn::{
parse_quote, Attribute, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, Pat,
PatIdent, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, Type, TypeParamBound,
TypePath, WhereClause,
};
macro_rules! parse_quote_spanned {
($span:expr=> $($t:tt)*) => {
syn::parse2(quote_spanned!($span=> $($t)*)).unwrap()
};
}
impl ToTokens for Item {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self {
Item::Trait(item) => item.to_tokens(tokens),
Item::Impl(item) => item.to_tokens(tokens),
}
}
}
#[derive(Clone, Copy)]
enum Context<'a> {
Trait {
generics: &'a Generics,
supertraits: &'a Supertraits,
},
Impl {
impl_generics: &'a Generics,
associated_type_impl_traits: &'a Set<Ident>,
},
}
impl Context<'_> {
fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a GenericParam> {
let generics = match self {
Context::Trait { generics, .. } => generics,
Context::Impl { impl_generics, .. } => impl_generics,
};
generics.params.iter().filter(move |param| {
if let GenericParam::Lifetime(param) = param {
used.contains(&param.lifetime)
} else {
false
}
})
}
}
type Supertraits = Punctuated<TypeParamBound, Token![+]>;
pub fn expand(input: &mut Item, is_local: bool) {
match input {
Item::Trait(input) => {
let context = Context::Trait {
generics: &input.generics,
supertraits: &input.supertraits,
};
for inner in &mut input.items {
if let TraitItem::Method(method) = inner {
let sig = &mut method.sig;
if sig.asyncness.is_some() {
let block = &mut method.default;
let mut has_self = has_self_in_sig(sig);
method.attrs.push(parse_quote!(#[must_use]));
if let Some(block) = block {
has_self |= has_self_in_block(block);
transform_block(context, sig, block);
method.attrs.push(lint_suppress_with_body());
} else {
method.attrs.push(lint_suppress_without_body());
}
let has_default = method.default.is_some();
transform_sig(context, sig, has_self, has_default, is_local);
}
}
}
}
Item::Impl(input) => {
let mut lifetimes = CollectLifetimes::new("'impl", input.impl_token.span);
lifetimes.visit_type_mut(&mut *input.self_ty);
lifetimes.visit_path_mut(&mut input.trait_.as_mut().unwrap().1);
let params = &input.generics.params;
let elided = lifetimes.elided;
input.generics.params = parse_quote!(#(#elided,)* #params);
let mut associated_type_impl_traits = Set::new();
for inner in &input.items {
if let ImplItem::Type(assoc) = inner {
if let Type::ImplTrait(_) = assoc.ty {
associated_type_impl_traits.insert(assoc.ident.clone());
}
}
}
let context = Context::Impl {
impl_generics: &input.generics,
associated_type_impl_traits: &associated_type_impl_traits,
};
for inner in &mut input.items {
if let ImplItem::Method(method) = inner {
let sig = &mut method.sig;
if sig.asyncness.is_some() {
let block = &mut method.block;
let has_self = has_self_in_sig(sig) || has_self_in_block(block);
transform_block(context, sig, block);
transform_sig(context, sig, has_self, false, is_local);
method.attrs.push(lint_suppress_with_body());
}
}
}
}
}
}
fn lint_suppress_with_body() -> Attribute {
parse_quote! {
#[allow(
clippy::let_unit_value,
clippy::no_effect_underscore_binding,
clippy::shadow_same,
clippy::type_complexity,
clippy::type_repetition_in_bounds,
clippy::used_underscore_binding
)]
}
}
fn lint_suppress_without_body() -> Attribute {
parse_quote! {
#[allow(
clippy::type_complexity,
clippy::type_repetition_in_bounds
)]
}
}
// Input:
// async fn f<T>(&self, x: &T) -> Ret;
//
// Output:
// fn f<'life0, 'life1, 'async_trait, T>(
// &'life0 self,
// x: &'life1 T,
// ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
// where
// 'life0: 'async_trait,
// 'life1: 'async_trait,
// T: 'async_trait,
// Self: Sync + 'async_trait;
fn transform_sig(
context: Context,
sig: &mut Signature,
has_self: bool,
has_default: bool,
is_local: bool,
) {
sig.fn_token.span = sig.asyncness.take().unwrap().span;
let ret = match &sig.output {
ReturnType::Default => quote!(()),
ReturnType::Type(_, ret) => quote!(#ret),
};
let default_span = sig
.ident
.span()
.join(sig.paren_token.span)
.unwrap_or_else(|| sig.ident.span());
let mut lifetimes = CollectLifetimes::new("'life", default_span);
for arg in sig.inputs.iter_mut() {
match arg {
FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
}
}
for param in sig
.generics
.params
.iter()
.chain(context.lifetimes(&lifetimes.explicit))
{
match param {
GenericParam::Type(param) => {
let param = &param.ident;
let span = param.span();
where_clause_or_default(&mut sig.generics.where_clause)
.predicates
.push(parse_quote_spanned!(span=> #param: 'async_trait));
}
GenericParam::Lifetime(param) => {
let param = &param.lifetime;
let span = param.span();
where_clause_or_default(&mut sig.generics.where_clause)
.predicates
.push(parse_quote_spanned!(span=> #param: 'async_trait));
}
GenericParam::Const(_) => {}
}
}
if sig.generics.lt_token.is_none() {
sig.generics.lt_token = Some(Token![<](sig.ident.span()));
}
if sig.generics.gt_token.is_none() {
sig.generics.gt_token = Some(Token![>](sig.paren_token.span));
}
for elided in lifetimes.elided {
sig.generics.params.push(parse_quote!(#elided));
where_clause_or_default(&mut sig.generics.where_clause)
.predicates
.push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
}
sig.generics
.params
.push(parse_quote_spanned!(default_span=> 'async_trait));
if has_self {
let bound_span = sig.ident.span();
let bound = match sig.inputs.iter().next() {
Some(FnArg::Receiver(Receiver {
reference: Some(_),
mutability: None,
..
})) => Ident::new("Sync", bound_span),
Some(FnArg::Typed(arg))
if match (arg.pat.as_ref(), arg.ty.as_ref()) {
(Pat::Ident(pat), Type::Reference(ty)) => {
pat.ident == "self" && ty.mutability.is_none()
}
_ => false,
} =>
{
Ident::new("Sync", bound_span)
}
_ => Ident::new("Send", bound_span),
};
let assume_bound = match context {
Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, &bound),
Context::Impl { .. } => true,
};
let where_clause = where_clause_or_default(&mut sig.generics.where_clause);
where_clause.predicates.push(if assume_bound || is_local {
parse_quote_spanned!(bound_span=> Self: 'async_trait)
} else {
parse_quote_spanned!(bound_span=> Self: ::core::marker::#bound + 'async_trait)
});
}
for (i, arg) in sig.inputs.iter_mut().enumerate() {
match arg {
FnArg::Receiver(Receiver {
reference: Some(_), ..
}) => {}
FnArg::Receiver(arg) => arg.mutability = None,
FnArg::Typed(arg) => {
if let Pat::Ident(ident) = &mut *arg.pat {
ident.by_ref = None;
ident.mutability = None;
} else {
let positional = positional_arg(i, &arg.pat);
let m = mut_pat(&mut arg.pat);
arg.pat = parse_quote!(#m #positional);
}
}
}
}
let ret_span = sig.ident.span();
let bounds = if is_local {
quote_spanned!(ret_span=> 'async_trait)
} else {
quote_spanned!(ret_span=> ::core::marker::Send + 'async_trait)
};
sig.output = parse_quote_spanned! {ret_span=>
-> ::core::pin::Pin<Box<
dyn ::core::future::Future<Output = #ret> + #bounds
>>
};
}
// Input:
// async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
// self + x + a + b
// }
//
// Output:
// Box::pin(async move {
// let ___ret: Ret = {
// let __self = self;
// let x = x;
// let (a, b) = __arg1;
//
// __self + x + a + b
// };
//
// ___ret
// })
fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
if block.stmts.len() == 1 && item.to_string() == ";" {
return;
}
}
let mut self_span = None;
let decls = sig
.inputs
.iter()
.enumerate()
.map(|(i, arg)| match arg {
FnArg::Receiver(Receiver {
self_token,
mutability,
..
}) => {
let ident = Ident::new("__self", self_token.span);
self_span = Some(self_token.span);
quote!(let #mutability #ident = #self_token;)
}
FnArg::Typed(arg) => {
if let Pat::Ident(PatIdent {
ident, mutability, ..
}) = &*arg.pat
{
if ident == "self" {
self_span = Some(ident.span());
let prefixed = Ident::new("__self", ident.span());
quote!(let #mutability #prefixed = #ident;)
} else {
quote!(let #mutability #ident = #ident;)
}
} else {
let pat = &arg.pat;
let ident = positional_arg(i, pat);
quote!(let #pat = #ident;)
}
}
})
.collect::<Vec<_>>();
if let Some(span) = self_span {
let mut replace_self = ReplaceSelf(span);
replace_self.visit_block_mut(block);
}
let stmts = &block.stmts;
let let_ret = match &mut sig.output {
ReturnType::Default => quote_spanned! {block.brace_token.span=>
#(#decls)*
let _: () = { #(#stmts)* };
},
ReturnType::Type(_, ret) => {
if contains_associated_type_impl_trait(context, ret) {
if decls.is_empty() {
quote!(#(#stmts)*)
} else {
quote!(#(#decls)* { #(#stmts)* })
}
} else {
quote_spanned! {block.brace_token.span=>
if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
return __ret;
}
#(#decls)*
let __ret: #ret = { #(#stmts)* };
#[allow(unreachable_code)]
__ret
}
}
}
};
let box_pin = quote_spanned!(block.brace_token.span=>
Box::pin(async move { #let_ret })
);
block.stmts = parse_quote!(#box_pin);
}
fn positional_arg(i: usize, pat: &Pat) -> Ident {
use syn::spanned::Spanned;
format_ident!("__arg{}", i, span = pat.span())
}
fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool {
for bound in supertraits {
if let TypeParamBound::Trait(bound) = bound {
if bound.path.is_ident(marker)
|| bound.path.segments.len() == 3
&& (bound.path.segments[0].ident == "std"
|| bound.path.segments[0].ident == "core")
&& bound.path.segments[1].ident == "marker"
&& bound.path.segments[2].ident == *marker
{
return true;
}
}
}
false
}
fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
struct AssociatedTypeImplTraits<'a> {
set: &'a Set<Ident>,
contains: bool,
}
impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
if ty.qself.is_none()
&& ty.path.segments.len() == 2
&& ty.path.segments[0].ident == "Self"
&& self.set.contains(&ty.path.segments[1].ident)
{
self.contains = true;
}
visit_mut::visit_type_path_mut(self, ty);
}
}
match context {
Context::Trait { .. } => false,
Context::Impl {
associated_type_impl_traits,
..
} => {
let mut visit = AssociatedTypeImplTraits {
set: associated_type_impl_traits,
contains: false,
};
visit.visit_type_mut(ret);
visit.contains
}
}
}
fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
clause.get_or_insert_with(|| WhereClause {
where_token: Default::default(),
predicates: Punctuated::new(),
})
}