blob: aac8e46c3c6757b0721d6a5a51022dbdfa511d6e [file] [log] [blame]
Andrew Walbrand1b91c72020-08-11 17:12:08 +01001use crate::lifetime::{has_async_lifetime, CollectLifetimes};
2use crate::parse::Item;
3use crate::receiver::{
4 has_self_in_block, has_self_in_sig, has_self_in_where_predicate, ReplaceReceiver,
5};
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::mem;
9use syn::punctuated::Punctuated;
10use syn::visit_mut::VisitMut;
11use syn::{
12 parse_quote, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, Pat, PatIdent,
13 Path, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, Type, TypeParam, TypeParamBound,
14 WhereClause,
15};
16
17impl ToTokens for Item {
18 fn to_tokens(&self, tokens: &mut TokenStream) {
19 match self {
20 Item::Trait(item) => item.to_tokens(tokens),
21 Item::Impl(item) => item.to_tokens(tokens),
22 }
23 }
24}
25
26#[derive(Clone, Copy)]
27enum Context<'a> {
28 Trait {
29 name: &'a Ident,
30 generics: &'a Generics,
31 supertraits: &'a Supertraits,
32 },
33 Impl {
34 impl_generics: &'a Generics,
35 receiver: &'a Type,
36 as_trait: &'a Path,
37 },
38}
39
40impl Context<'_> {
41 fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a GenericParam> {
42 let generics = match self {
43 Context::Trait { generics, .. } => generics,
44 Context::Impl { impl_generics, .. } => impl_generics,
45 };
46 generics.params.iter().filter(move |param| {
47 if let GenericParam::Lifetime(param) = param {
48 used.contains(&param.lifetime)
49 } else {
50 false
51 }
52 })
53 }
54}
55
56type Supertraits = Punctuated<TypeParamBound, Token![+]>;
57
58pub fn expand(input: &mut Item, is_local: bool) {
59 match input {
60 Item::Trait(input) => {
61 let context = Context::Trait {
62 name: &input.ident,
63 generics: &input.generics,
64 supertraits: &input.supertraits,
65 };
66 for inner in &mut input.items {
67 if let TraitItem::Method(method) = inner {
68 let sig = &mut method.sig;
69 if sig.asyncness.is_some() {
70 let block = &mut method.default;
71 let mut has_self = has_self_in_sig(sig);
72 if let Some(block) = block {
73 has_self |= has_self_in_block(block);
74 transform_block(context, sig, block, has_self, is_local);
Chih-Hung Hsiehf8b73ea2020-10-26 13:16:52 -070075 method.attrs.push(parse_quote!(#[allow(clippy::used_underscore_binding)]));
Andrew Walbrand1b91c72020-08-11 17:12:08 +010076 }
77 let has_default = method.default.is_some();
78 transform_sig(context, sig, has_self, has_default, is_local);
79 method.attrs.push(parse_quote!(#[must_use]));
80 }
81 }
82 }
83 }
84 Item::Impl(input) => {
85 let mut lifetimes = CollectLifetimes::new("'impl");
86 lifetimes.visit_type_mut(&mut *input.self_ty);
87 lifetimes.visit_path_mut(&mut input.trait_.as_mut().unwrap().1);
88 let params = &input.generics.params;
89 let elided = lifetimes.elided;
90 input.generics.params = parse_quote!(#(#elided,)* #params);
91
92 let context = Context::Impl {
93 impl_generics: &input.generics,
94 receiver: &input.self_ty,
95 as_trait: &input.trait_.as_ref().unwrap().1,
96 };
97 for inner in &mut input.items {
98 if let ImplItem::Method(method) = inner {
99 let sig = &mut method.sig;
100 if sig.asyncness.is_some() {
101 let block = &mut method.block;
102 let has_self = has_self_in_sig(sig) || has_self_in_block(block);
103 transform_block(context, sig, block, has_self, is_local);
104 transform_sig(context, sig, has_self, false, is_local);
Chih-Hung Hsiehf8b73ea2020-10-26 13:16:52 -0700105 method.attrs.push(parse_quote!(#[allow(clippy::used_underscore_binding)]));
Andrew Walbrand1b91c72020-08-11 17:12:08 +0100106 }
107 }
108 }
109 }
110 }
111}
112
// Rewrites the signature of an `async fn`, in place, into a plain `fn`. The
// new `fn` returns a pinned, boxed future, and every generic parameter and
// borrowed input is bounded by a fresh `'async_trait` lifetime, which also
// bounds the returned future.
//
// Parameters:
// - `context`: whether the method sits in a trait or in an impl. Supplies the
//   enclosing generics and, for traits, the supertraits.
// - `has_self`: whether the method mentions `self`/`Self`, as computed by the
//   caller via `has_self_in_sig`/`has_self_in_block`. When true, a
//   `Self: ... 'async_trait` predicate is added.
// - `has_default`: whether a trait method has a default body. It is only
//   consulted in trait context, to decide whether `Self` needs an explicit
//   `Send`/`Sync` bound.
// - `is_local`: when true, no `Send`/`Sync` bounds are emitted at all.
//
// Panics if `sig` has no `async` keyword (the `unwrap` on `asyncness`).
//
// Input:
//     async fn f<T>(&self, x: &T) -> Ret;
//
// Output:
//     fn f<'life0, 'life1, 'async_trait, T>(
//         &'life0 self,
//         x: &'life1 T,
//     ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
//     where
//         'life0: 'async_trait,
//         'life1: 'async_trait,
//         T: 'async_trait,
//         Self: Sync + 'async_trait;
fn transform_sig(
    context: Context,
    sig: &mut Signature,
    has_self: bool,
    has_default: bool,
    is_local: bool,
) {
    // Drop `async`. The plain `fn` token inherits its span, so the rewritten
    // signature keeps the source location of the original keyword.
    sig.fn_token.span = sig.asyncness.take().unwrap().span;

    // The future's `Output`: the declared return type, or `()` when omitted.
    let ret = match &sig.output {
        ReturnType::Default => quote!(()),
        ReturnType::Type(_, ret) => quote!(#ret),
    };

    // Walk the receiver and every argument type with `CollectLifetimes`
    // (crate::lifetime). As it is used below, it records lifetimes written in
    // the source in `.explicit`. It records the names it introduces for elided
    // lifetimes (prefix "'life") in `.elided`. Per the example above, it
    // presumably also rewrites the elided references to use those names.
    let mut lifetimes = CollectLifetimes::new("'life");
    for arg in sig.inputs.iter_mut() {
        match arg {
            FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
            FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
        }
    }

    // Make sure a where clause exists so predicates can be appended to it.
    let where_clause = sig
        .generics
        .where_clause
        .get_or_insert_with(|| WhereClause {
            where_token: Default::default(),
            predicates: Punctuated::new(),
        });
    // The returned future is bounded by 'async_trait, so each of the method's
    // own generic params must outlive it. So must each lifetime param of the
    // enclosing trait/impl that the arguments mention explicitly. Const
    // params carry no lifetime and are skipped.
    for param in sig
        .generics
        .params
        .iter()
        .chain(context.lifetimes(&lifetimes.explicit))
    {
        match param {
            GenericParam::Type(param) => {
                let param = &param.ident;
                where_clause
                    .predicates
                    .push(parse_quote!(#param: 'async_trait));
            }
            GenericParam::Lifetime(param) => {
                let param = &param.lifetime;
                where_clause
                    .predicates
                    .push(parse_quote!(#param: 'async_trait));
            }
            GenericParam::Const(_) => {}
        }
    }
    // Each elided lifetime becomes a declared generic param with the same
    // outlives bound.
    for elided in lifetimes.elided {
        sig.generics.params.push(parse_quote!(#elided));
        where_clause
            .predicates
            .push(parse_quote!(#elided: 'async_trait));
    }
    sig.generics.params.push(parse_quote!('async_trait));
    if has_self {
        // Choose the marker trait `Self` needs for the future to be `Send`.
        // A shared-reference receiver (`&self`, or a typed `self` whose type
        // is an immutable reference) only captures `&Self`, which is `Send`
        // exactly when `Self: Sync`. Every other case falls back to `Send`:
        // `&mut self`, by-value `self`, a non-reference typed `self`, or a
        // first argument that is not `self`.
        let bound: Ident = match sig.inputs.iter().next() {
            Some(FnArg::Receiver(Receiver {
                reference: Some(_),
                mutability: None,
                ..
            })) => parse_quote!(Sync),
            Some(FnArg::Typed(arg))
                if match (arg.pat.as_ref(), arg.ty.as_ref()) {
                    (Pat::Ident(pat), Type::Reference(ty)) => {
                        pat.ident == "self" && ty.mutability.is_none()
                    }
                    _ => false,
                } =>
            {
                parse_quote!(Sync)
            }
            _ => parse_quote!(Send),
        };
        // The marker bound may be left implicit in these cases:
        // - always, in impls;
        // - in traits, when the method has no default body;
        // - in traits, when the supertraits already name the marker (per
        //   `has_bound`).
        let assume_bound = match context {
            Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, &bound),
            Context::Impl { .. } => true,
        };
        // `Self` must always outlive 'async_trait. The explicit marker bound is
        // added only when it cannot be assumed and `is_local` is false.
        where_clause.predicates.push(if assume_bound || is_local {
            parse_quote!(Self: 'async_trait)
        } else {
            parse_quote!(Self: ::core::marker::#bound + 'async_trait)
        });
    }

    // Strip binding modifiers from the outer signature's patterns:
    // - reference receivers are left alone;
    // - by-value receivers lose `mut`;
    // - identifier patterns lose `ref`/`mut`;
    // - any other pattern is replaced by the positional name `__argN`, the same
    //   name `transform_block` forwards such arguments under.
    // When there is a body, `expand` has already run `transform_block` on a
    // clone of this signature, so the original patterns survive on the nested
    // function.
    for (i, arg) in sig.inputs.iter_mut().enumerate() {
        match arg {
            FnArg::Receiver(Receiver {
                reference: Some(_), ..
            }) => {}
            FnArg::Receiver(arg) => arg.mutability = None,
            FnArg::Typed(arg) => {
                if let Pat::Ident(ident) = &mut *arg.pat {
                    ident.by_ref = None;
                    ident.mutability = None;
                } else {
                    let positional = positional_arg(i);
                    *arg.pat = parse_quote!(#positional);
                }
            }
        }
    }

    // The boxed future is `Send` unless `is_local` is set.
    let bounds = if is_local {
        quote!('async_trait)
    } else {
        quote!(::core::marker::Send + 'async_trait)
    };

    sig.output = parse_quote! {
        -> ::core::pin::Pin<Box<
            dyn ::core::future::Future<Output = #ret> + #bounds
        >>
    };
}
244
245// Input:
246// async fn f<T>(&self, x: &T) -> Ret {
247// self + x
248// }
249//
250// Output:
251// async fn f<T, AsyncTrait>(_self: &AsyncTrait, x: &T) -> Ret {
252// _self + x
253// }
254// Box::pin(async_trait_method::<T, Self>(self, x))
255fn transform_block(
256 context: Context,
257 sig: &mut Signature,
258 block: &mut Block,
259 has_self: bool,
260 is_local: bool,
261) {
262 if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
263 if block.stmts.len() == 1 && item.to_string() == ";" {
264 return;
265 }
266 }
267
268 let inner = format_ident!("__{}", sig.ident);
269 let args = sig.inputs.iter().enumerate().map(|(i, arg)| match arg {
270 FnArg::Receiver(Receiver { self_token, .. }) => quote!(#self_token),
271 FnArg::Typed(arg) => {
272 if let Pat::Ident(PatIdent { ident, .. }) = &*arg.pat {
273 quote!(#ident)
274 } else {
275 positional_arg(i).into_token_stream()
276 }
277 }
278 });
279
280 let mut standalone = sig.clone();
281 standalone.ident = inner.clone();
282
283 let generics = match context {
284 Context::Trait { generics, .. } => generics,
285 Context::Impl { impl_generics, .. } => impl_generics,
286 };
287
288 let mut outer_generics = generics.clone();
Haibo Huang62e9b292020-09-01 20:28:34 -0700289 for p in &mut outer_generics.params {
290 match p {
291 GenericParam::Type(t) => t.default = None,
292 GenericParam::Const(c) => c.default = None,
293 GenericParam::Lifetime(_) => {}
294 }
295 }
Andrew Walbrand1b91c72020-08-11 17:12:08 +0100296 if !has_self {
297 if let Some(mut where_clause) = outer_generics.where_clause {
298 where_clause.predicates = where_clause
299 .predicates
300 .into_iter()
301 .filter_map(|mut pred| {
302 if has_self_in_where_predicate(&mut pred) {
303 None
304 } else {
305 Some(pred)
306 }
307 })
308 .collect();
309 outer_generics.where_clause = Some(where_clause);
310 }
311 }
312
313 let fn_generics = mem::replace(&mut standalone.generics, outer_generics);
314 standalone.generics.params.extend(fn_generics.params);
315 if let Some(where_clause) = fn_generics.where_clause {
316 standalone
317 .generics
318 .make_where_clause()
319 .predicates
320 .extend(where_clause.predicates);
321 }
322
323 if has_async_lifetime(&mut standalone, block) {
324 standalone.generics.params.push(parse_quote!('async_trait));
325 }
326
327 let mut types = standalone
328 .generics
329 .type_params()
330 .map(|param| param.ident.clone())
331 .collect::<Vec<_>>();
332
333 let mut self_bound = None::<TypeParamBound>;
334 match standalone.inputs.iter_mut().next() {
335 Some(
336 arg @ FnArg::Receiver(Receiver {
337 reference: Some(_), ..
338 }),
339 ) => {
340 let (lifetime, mutability, self_token) = match arg {
341 FnArg::Receiver(Receiver {
342 reference: Some((_, lifetime)),
343 mutability,
344 self_token,
345 ..
346 }) => (lifetime, mutability, self_token),
347 _ => unreachable!(),
348 };
349 let under_self = Ident::new("_self", self_token.span);
350 match context {
351 Context::Trait { .. } => {
352 self_bound = Some(match mutability {
353 Some(_) => parse_quote!(::core::marker::Send),
354 None => parse_quote!(::core::marker::Sync),
355 });
356 *arg = parse_quote! {
357 #under_self: &#lifetime #mutability AsyncTrait
358 };
359 }
360 Context::Impl { receiver, .. } => {
361 let mut ty = quote!(#receiver);
362 if let Type::TraitObject(trait_object) = receiver {
363 if trait_object.dyn_token.is_none() {
364 ty = quote!(dyn #ty);
365 }
366 if trait_object.bounds.len() > 1 {
367 ty = quote!((#ty));
368 }
369 }
370 *arg = parse_quote! {
371 #under_self: &#lifetime #mutability #ty
372 };
373 }
374 }
375 }
376 Some(arg @ FnArg::Receiver(_)) => {
377 let (self_token, mutability) = match arg {
378 FnArg::Receiver(Receiver {
379 self_token,
380 mutability,
381 ..
382 }) => (self_token, mutability),
383 _ => unreachable!(),
384 };
385 let under_self = Ident::new("_self", self_token.span);
386 match context {
387 Context::Trait { .. } => {
388 self_bound = Some(parse_quote!(::core::marker::Send));
389 *arg = parse_quote! {
390 #mutability #under_self: AsyncTrait
391 };
392 }
393 Context::Impl { receiver, .. } => {
394 *arg = parse_quote! {
395 #mutability #under_self: #receiver
396 };
397 }
398 }
399 }
400 Some(FnArg::Typed(arg)) => {
401 if let Pat::Ident(arg) = &mut *arg.pat {
402 if arg.ident == "self" {
403 arg.ident = Ident::new("_self", arg.ident.span());
404 }
405 }
406 }
407 _ => {}
408 }
409
410 if let Context::Trait { name, generics, .. } = context {
411 if has_self {
412 let (_, generics, _) = generics.split_for_impl();
413 let mut self_param: TypeParam = parse_quote!(AsyncTrait: ?Sized + #name #generics);
414 if !is_local {
415 self_param.bounds.extend(self_bound);
416 }
417 standalone
418 .generics
419 .params
420 .push(GenericParam::Type(self_param));
421 types.push(Ident::new("Self", Span::call_site()));
422 }
423 }
424
425 if let Some(where_clause) = &mut standalone.generics.where_clause {
426 // Work around an input bound like `where Self::Output: Send` expanding
427 // to `where <AsyncTrait>::Output: Send` which is illegal syntax because
428 // `where<T>` is reserved for future use... :(
429 where_clause.predicates.insert(0, parse_quote!((): Sized));
430 }
431
432 let mut replace = match context {
433 Context::Trait { .. } => ReplaceReceiver::with(parse_quote!(AsyncTrait)),
434 Context::Impl {
435 receiver, as_trait, ..
436 } => ReplaceReceiver::with_as_trait(receiver.clone(), as_trait.clone()),
437 };
438 replace.visit_signature_mut(&mut standalone);
439 replace.visit_block_mut(block);
440
441 let mut generics = types;
442 let consts = standalone
443 .generics
444 .const_params()
445 .map(|param| param.ident.clone());
446 generics.extend(consts);
447
448 let allow_non_snake_case = if sig.ident != sig.ident.to_string().to_lowercase() {
449 Some(quote!(non_snake_case,))
450 } else {
451 None
452 };
453
454 let brace = block.brace_token;
455 let box_pin = quote_spanned!(brace.span=> {
456 #[allow(
457 #allow_non_snake_case
Haibo Huangd8abf3d2020-08-17 15:39:53 -0700458 unused_parens, // https://github.com/dtolnay/async-trait/issues/118
Andrew Walbrand1b91c72020-08-11 17:12:08 +0100459 clippy::missing_docs_in_private_items,
460 clippy::needless_lifetimes,
461 clippy::ptr_arg,
Haibo Huangd8abf3d2020-08-17 15:39:53 -0700462 clippy::trivially_copy_pass_by_ref,
Andrew Walbrand1b91c72020-08-11 17:12:08 +0100463 clippy::type_repetition_in_bounds,
464 clippy::used_underscore_binding,
465 )]
466 #standalone #block
467 Box::pin(#inner::<#(#generics),*>(#(#args),*))
468 });
469 *block = parse_quote!(#box_pin);
470 block.brace_token = brace;
471}
472
473fn positional_arg(i: usize) -> Ident {
474 format_ident!("__arg{}", i)
475}
476
477fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool {
478 for bound in supertraits {
479 if let TypeParamBound::Trait(bound) = bound {
480 if bound.path.is_ident(marker) {
481 return true;
482 }
483 }
484 }
485 false
486}