blob: fb83df12ec0f6f7c476db64c8223a89d4fed9c6d [file] [log] [blame]
Andrew Walbrand1b91c72020-08-11 17:12:08 +01001use crate::lifetime::{has_async_lifetime, CollectLifetimes};
2use crate::parse::Item;
3use crate::receiver::{
4 has_self_in_block, has_self_in_sig, has_self_in_where_predicate, ReplaceReceiver,
5};
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::mem;
9use syn::punctuated::Punctuated;
10use syn::visit_mut::VisitMut;
11use syn::{
12 parse_quote, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, Pat, PatIdent,
13 Path, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, Type, TypeParam, TypeParamBound,
14 WhereClause,
15};
16
17impl ToTokens for Item {
18 fn to_tokens(&self, tokens: &mut TokenStream) {
19 match self {
20 Item::Trait(item) => item.to_tokens(tokens),
21 Item::Impl(item) => item.to_tokens(tokens),
22 }
23 }
24}
25
26#[derive(Clone, Copy)]
27enum Context<'a> {
28 Trait {
29 name: &'a Ident,
30 generics: &'a Generics,
31 supertraits: &'a Supertraits,
32 },
33 Impl {
34 impl_generics: &'a Generics,
35 receiver: &'a Type,
36 as_trait: &'a Path,
37 },
38}
39
40impl Context<'_> {
41 fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a GenericParam> {
42 let generics = match self {
43 Context::Trait { generics, .. } => generics,
44 Context::Impl { impl_generics, .. } => impl_generics,
45 };
46 generics.params.iter().filter(move |param| {
47 if let GenericParam::Lifetime(param) = param {
48 used.contains(&param.lifetime)
49 } else {
50 false
51 }
52 })
53 }
54}
55
56type Supertraits = Punctuated<TypeParamBound, Token![+]>;
57
58pub fn expand(input: &mut Item, is_local: bool) {
59 match input {
60 Item::Trait(input) => {
61 let context = Context::Trait {
62 name: &input.ident,
63 generics: &input.generics,
64 supertraits: &input.supertraits,
65 };
66 for inner in &mut input.items {
67 if let TraitItem::Method(method) = inner {
68 let sig = &mut method.sig;
69 if sig.asyncness.is_some() {
70 let block = &mut method.default;
71 let mut has_self = has_self_in_sig(sig);
72 if let Some(block) = block {
73 has_self |= has_self_in_block(block);
74 transform_block(context, sig, block, has_self, is_local);
Haibo Huang7ff84c32020-11-24 20:53:12 -080075 method
76 .attrs
77 .push(parse_quote!(#[allow(clippy::used_underscore_binding)]));
Andrew Walbrand1b91c72020-08-11 17:12:08 +010078 }
79 let has_default = method.default.is_some();
80 transform_sig(context, sig, has_self, has_default, is_local);
81 method.attrs.push(parse_quote!(#[must_use]));
82 }
83 }
84 }
85 }
86 Item::Impl(input) => {
87 let mut lifetimes = CollectLifetimes::new("'impl");
88 lifetimes.visit_type_mut(&mut *input.self_ty);
89 lifetimes.visit_path_mut(&mut input.trait_.as_mut().unwrap().1);
90 let params = &input.generics.params;
91 let elided = lifetimes.elided;
92 input.generics.params = parse_quote!(#(#elided,)* #params);
93
94 let context = Context::Impl {
95 impl_generics: &input.generics,
96 receiver: &input.self_ty,
97 as_trait: &input.trait_.as_ref().unwrap().1,
98 };
99 for inner in &mut input.items {
100 if let ImplItem::Method(method) = inner {
101 let sig = &mut method.sig;
102 if sig.asyncness.is_some() {
103 let block = &mut method.block;
104 let has_self = has_self_in_sig(sig) || has_self_in_block(block);
105 transform_block(context, sig, block, has_self, is_local);
106 transform_sig(context, sig, has_self, false, is_local);
Haibo Huang7ff84c32020-11-24 20:53:12 -0800107 method
108 .attrs
109 .push(parse_quote!(#[allow(clippy::used_underscore_binding)]));
Andrew Walbrand1b91c72020-08-11 17:12:08 +0100110 }
111 }
112 }
113 }
114 }
115}
116
117// Input:
118// async fn f<T>(&self, x: &T) -> Ret;
119//
120// Output:
121// fn f<'life0, 'life1, 'async_trait, T>(
122// &'life0 self,
123// x: &'life1 T,
124// ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
125// where
126// 'life0: 'async_trait,
127// 'life1: 'async_trait,
128// T: 'async_trait,
129// Self: Sync + 'async_trait;
130fn transform_sig(
131 context: Context,
132 sig: &mut Signature,
133 has_self: bool,
134 has_default: bool,
135 is_local: bool,
136) {
137 sig.fn_token.span = sig.asyncness.take().unwrap().span;
138
139 let ret = match &sig.output {
140 ReturnType::Default => quote!(()),
141 ReturnType::Type(_, ret) => quote!(#ret),
142 };
143
144 let mut lifetimes = CollectLifetimes::new("'life");
145 for arg in sig.inputs.iter_mut() {
146 match arg {
147 FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
148 FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
149 }
150 }
151
152 let where_clause = sig
153 .generics
154 .where_clause
155 .get_or_insert_with(|| WhereClause {
156 where_token: Default::default(),
157 predicates: Punctuated::new(),
158 });
159 for param in sig
160 .generics
161 .params
162 .iter()
163 .chain(context.lifetimes(&lifetimes.explicit))
164 {
165 match param {
166 GenericParam::Type(param) => {
167 let param = &param.ident;
168 where_clause
169 .predicates
170 .push(parse_quote!(#param: 'async_trait));
171 }
172 GenericParam::Lifetime(param) => {
173 let param = &param.lifetime;
174 where_clause
175 .predicates
176 .push(parse_quote!(#param: 'async_trait));
177 }
178 GenericParam::Const(_) => {}
179 }
180 }
181 for elided in lifetimes.elided {
182 sig.generics.params.push(parse_quote!(#elided));
183 where_clause
184 .predicates
185 .push(parse_quote!(#elided: 'async_trait));
186 }
187 sig.generics.params.push(parse_quote!('async_trait));
188 if has_self {
189 let bound: Ident = match sig.inputs.iter().next() {
190 Some(FnArg::Receiver(Receiver {
191 reference: Some(_),
192 mutability: None,
193 ..
194 })) => parse_quote!(Sync),
195 Some(FnArg::Typed(arg))
196 if match (arg.pat.as_ref(), arg.ty.as_ref()) {
197 (Pat::Ident(pat), Type::Reference(ty)) => {
198 pat.ident == "self" && ty.mutability.is_none()
199 }
200 _ => false,
201 } =>
202 {
203 parse_quote!(Sync)
204 }
205 _ => parse_quote!(Send),
206 };
207 let assume_bound = match context {
208 Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, &bound),
209 Context::Impl { .. } => true,
210 };
211 where_clause.predicates.push(if assume_bound || is_local {
212 parse_quote!(Self: 'async_trait)
213 } else {
214 parse_quote!(Self: ::core::marker::#bound + 'async_trait)
215 });
216 }
217
218 for (i, arg) in sig.inputs.iter_mut().enumerate() {
219 match arg {
220 FnArg::Receiver(Receiver {
221 reference: Some(_), ..
222 }) => {}
223 FnArg::Receiver(arg) => arg.mutability = None,
224 FnArg::Typed(arg) => {
225 if let Pat::Ident(ident) = &mut *arg.pat {
226 ident.by_ref = None;
227 ident.mutability = None;
228 } else {
229 let positional = positional_arg(i);
230 *arg.pat = parse_quote!(#positional);
231 }
232 }
233 }
234 }
235
236 let bounds = if is_local {
237 quote!('async_trait)
238 } else {
239 quote!(::core::marker::Send + 'async_trait)
240 };
241
242 sig.output = parse_quote! {
243 -> ::core::pin::Pin<Box<
244 dyn ::core::future::Future<Output = #ret> + #bounds
245 >>
246 };
247}
248
249// Input:
250// async fn f<T>(&self, x: &T) -> Ret {
251// self + x
252// }
253//
254// Output:
255// async fn f<T, AsyncTrait>(_self: &AsyncTrait, x: &T) -> Ret {
256// _self + x
257// }
258// Box::pin(async_trait_method::<T, Self>(self, x))
259fn transform_block(
260 context: Context,
261 sig: &mut Signature,
262 block: &mut Block,
263 has_self: bool,
264 is_local: bool,
265) {
266 if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
267 if block.stmts.len() == 1 && item.to_string() == ";" {
268 return;
269 }
270 }
271
272 let inner = format_ident!("__{}", sig.ident);
273 let args = sig.inputs.iter().enumerate().map(|(i, arg)| match arg {
274 FnArg::Receiver(Receiver { self_token, .. }) => quote!(#self_token),
275 FnArg::Typed(arg) => {
276 if let Pat::Ident(PatIdent { ident, .. }) = &*arg.pat {
277 quote!(#ident)
278 } else {
279 positional_arg(i).into_token_stream()
280 }
281 }
282 });
283
284 let mut standalone = sig.clone();
285 standalone.ident = inner.clone();
286
287 let generics = match context {
288 Context::Trait { generics, .. } => generics,
289 Context::Impl { impl_generics, .. } => impl_generics,
290 };
291
292 let mut outer_generics = generics.clone();
Haibo Huang62e9b292020-09-01 20:28:34 -0700293 for p in &mut outer_generics.params {
294 match p {
295 GenericParam::Type(t) => t.default = None,
296 GenericParam::Const(c) => c.default = None,
297 GenericParam::Lifetime(_) => {}
298 }
299 }
Andrew Walbrand1b91c72020-08-11 17:12:08 +0100300 if !has_self {
301 if let Some(mut where_clause) = outer_generics.where_clause {
302 where_clause.predicates = where_clause
303 .predicates
304 .into_iter()
305 .filter_map(|mut pred| {
306 if has_self_in_where_predicate(&mut pred) {
307 None
308 } else {
309 Some(pred)
310 }
311 })
312 .collect();
313 outer_generics.where_clause = Some(where_clause);
314 }
315 }
316
317 let fn_generics = mem::replace(&mut standalone.generics, outer_generics);
318 standalone.generics.params.extend(fn_generics.params);
319 if let Some(where_clause) = fn_generics.where_clause {
320 standalone
321 .generics
322 .make_where_clause()
323 .predicates
324 .extend(where_clause.predicates);
325 }
326
327 if has_async_lifetime(&mut standalone, block) {
328 standalone.generics.params.push(parse_quote!('async_trait));
329 }
330
331 let mut types = standalone
332 .generics
333 .type_params()
334 .map(|param| param.ident.clone())
335 .collect::<Vec<_>>();
336
337 let mut self_bound = None::<TypeParamBound>;
338 match standalone.inputs.iter_mut().next() {
339 Some(
340 arg @ FnArg::Receiver(Receiver {
341 reference: Some(_), ..
342 }),
343 ) => {
344 let (lifetime, mutability, self_token) = match arg {
345 FnArg::Receiver(Receiver {
346 reference: Some((_, lifetime)),
347 mutability,
348 self_token,
349 ..
350 }) => (lifetime, mutability, self_token),
351 _ => unreachable!(),
352 };
353 let under_self = Ident::new("_self", self_token.span);
354 match context {
355 Context::Trait { .. } => {
356 self_bound = Some(match mutability {
357 Some(_) => parse_quote!(::core::marker::Send),
358 None => parse_quote!(::core::marker::Sync),
359 });
360 *arg = parse_quote! {
361 #under_self: &#lifetime #mutability AsyncTrait
362 };
363 }
364 Context::Impl { receiver, .. } => {
365 let mut ty = quote!(#receiver);
366 if let Type::TraitObject(trait_object) = receiver {
367 if trait_object.dyn_token.is_none() {
368 ty = quote!(dyn #ty);
369 }
370 if trait_object.bounds.len() > 1 {
371 ty = quote!((#ty));
372 }
373 }
374 *arg = parse_quote! {
375 #under_self: &#lifetime #mutability #ty
376 };
377 }
378 }
379 }
380 Some(arg @ FnArg::Receiver(_)) => {
381 let (self_token, mutability) = match arg {
382 FnArg::Receiver(Receiver {
383 self_token,
384 mutability,
385 ..
386 }) => (self_token, mutability),
387 _ => unreachable!(),
388 };
389 let under_self = Ident::new("_self", self_token.span);
390 match context {
391 Context::Trait { .. } => {
392 self_bound = Some(parse_quote!(::core::marker::Send));
393 *arg = parse_quote! {
394 #mutability #under_self: AsyncTrait
395 };
396 }
397 Context::Impl { receiver, .. } => {
398 *arg = parse_quote! {
399 #mutability #under_self: #receiver
400 };
401 }
402 }
403 }
404 Some(FnArg::Typed(arg)) => {
405 if let Pat::Ident(arg) = &mut *arg.pat {
406 if arg.ident == "self" {
407 arg.ident = Ident::new("_self", arg.ident.span());
408 }
409 }
410 }
411 _ => {}
412 }
413
414 if let Context::Trait { name, generics, .. } = context {
415 if has_self {
416 let (_, generics, _) = generics.split_for_impl();
417 let mut self_param: TypeParam = parse_quote!(AsyncTrait: ?Sized + #name #generics);
418 if !is_local {
419 self_param.bounds.extend(self_bound);
420 }
Haibo Huang7ff84c32020-11-24 20:53:12 -0800421 let count = standalone
422 .generics
423 .params
424 .iter()
425 .take_while(|param| {
426 if let GenericParam::Const(_) = param {
427 false
428 } else {
429 true
430 }
431 })
432 .count();
Andrew Walbrand1b91c72020-08-11 17:12:08 +0100433 standalone
434 .generics
435 .params
Haibo Huang7ff84c32020-11-24 20:53:12 -0800436 .insert(count, GenericParam::Type(self_param));
Andrew Walbrand1b91c72020-08-11 17:12:08 +0100437 types.push(Ident::new("Self", Span::call_site()));
438 }
439 }
440
441 if let Some(where_clause) = &mut standalone.generics.where_clause {
442 // Work around an input bound like `where Self::Output: Send` expanding
443 // to `where <AsyncTrait>::Output: Send` which is illegal syntax because
444 // `where<T>` is reserved for future use... :(
445 where_clause.predicates.insert(0, parse_quote!((): Sized));
446 }
447
448 let mut replace = match context {
449 Context::Trait { .. } => ReplaceReceiver::with(parse_quote!(AsyncTrait)),
450 Context::Impl {
451 receiver, as_trait, ..
452 } => ReplaceReceiver::with_as_trait(receiver.clone(), as_trait.clone()),
453 };
454 replace.visit_signature_mut(&mut standalone);
455 replace.visit_block_mut(block);
456
457 let mut generics = types;
458 let consts = standalone
459 .generics
460 .const_params()
461 .map(|param| param.ident.clone());
462 generics.extend(consts);
463
464 let allow_non_snake_case = if sig.ident != sig.ident.to_string().to_lowercase() {
465 Some(quote!(non_snake_case,))
466 } else {
467 None
468 };
469
470 let brace = block.brace_token;
471 let box_pin = quote_spanned!(brace.span=> {
472 #[allow(
473 #allow_non_snake_case
Haibo Huangd8abf3d2020-08-17 15:39:53 -0700474 unused_parens, // https://github.com/dtolnay/async-trait/issues/118
Andrew Walbrand1b91c72020-08-11 17:12:08 +0100475 clippy::missing_docs_in_private_items,
476 clippy::needless_lifetimes,
477 clippy::ptr_arg,
Haibo Huangd8abf3d2020-08-17 15:39:53 -0700478 clippy::trivially_copy_pass_by_ref,
Andrew Walbrand1b91c72020-08-11 17:12:08 +0100479 clippy::type_repetition_in_bounds,
480 clippy::used_underscore_binding,
481 )]
482 #standalone #block
483 Box::pin(#inner::<#(#generics),*>(#(#args),*))
484 });
485 *block = parse_quote!(#box_pin);
486 block.brace_token = brace;
487}
488
489fn positional_arg(i: usize) -> Ident {
490 format_ident!("__arg{}", i)
491}
492
493fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool {
494 for bound in supertraits {
495 if let TypeParamBound::Trait(bound) = bound {
496 if bound.path.is_ident(marker) {
497 return true;
498 }
499 }
500 }
501 false
502}