blob: f868c7dfa51a5044b84a5c76819c0310e1879999 [file] [log] [blame]
Andrew Walbrand1b91c72020-08-11 17:12:08 +01001use crate::lifetime::{has_async_lifetime, CollectLifetimes};
2use crate::parse::Item;
3use crate::receiver::{
4 has_self_in_block, has_self_in_sig, has_self_in_where_predicate, ReplaceReceiver,
5};
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::mem;
9use syn::punctuated::Punctuated;
10use syn::visit_mut::VisitMut;
11use syn::{
12 parse_quote, Block, FnArg, GenericParam, Generics, Ident, ImplItem, Lifetime, Pat, PatIdent,
13 Path, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, Type, TypeParam, TypeParamBound,
14 WhereClause,
15};
16
17impl ToTokens for Item {
18 fn to_tokens(&self, tokens: &mut TokenStream) {
19 match self {
20 Item::Trait(item) => item.to_tokens(tokens),
21 Item::Impl(item) => item.to_tokens(tokens),
22 }
23 }
24}
25
26#[derive(Clone, Copy)]
27enum Context<'a> {
28 Trait {
29 name: &'a Ident,
30 generics: &'a Generics,
31 supertraits: &'a Supertraits,
32 },
33 Impl {
34 impl_generics: &'a Generics,
35 receiver: &'a Type,
36 as_trait: &'a Path,
37 },
38}
39
40impl Context<'_> {
41 fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a GenericParam> {
42 let generics = match self {
43 Context::Trait { generics, .. } => generics,
44 Context::Impl { impl_generics, .. } => impl_generics,
45 };
46 generics.params.iter().filter(move |param| {
47 if let GenericParam::Lifetime(param) = param {
48 used.contains(&param.lifetime)
49 } else {
50 false
51 }
52 })
53 }
54}
55
56type Supertraits = Punctuated<TypeParamBound, Token![+]>;
57
58pub fn expand(input: &mut Item, is_local: bool) {
59 match input {
60 Item::Trait(input) => {
61 let context = Context::Trait {
62 name: &input.ident,
63 generics: &input.generics,
64 supertraits: &input.supertraits,
65 };
66 for inner in &mut input.items {
67 if let TraitItem::Method(method) = inner {
68 let sig = &mut method.sig;
69 if sig.asyncness.is_some() {
70 let block = &mut method.default;
71 let mut has_self = has_self_in_sig(sig);
72 if let Some(block) = block {
73 has_self |= has_self_in_block(block);
74 transform_block(context, sig, block, has_self, is_local);
75 }
76 let has_default = method.default.is_some();
77 transform_sig(context, sig, has_self, has_default, is_local);
78 method.attrs.push(parse_quote!(#[must_use]));
79 }
80 }
81 }
82 }
83 Item::Impl(input) => {
84 let mut lifetimes = CollectLifetimes::new("'impl");
85 lifetimes.visit_type_mut(&mut *input.self_ty);
86 lifetimes.visit_path_mut(&mut input.trait_.as_mut().unwrap().1);
87 let params = &input.generics.params;
88 let elided = lifetimes.elided;
89 input.generics.params = parse_quote!(#(#elided,)* #params);
90
91 let context = Context::Impl {
92 impl_generics: &input.generics,
93 receiver: &input.self_ty,
94 as_trait: &input.trait_.as_ref().unwrap().1,
95 };
96 for inner in &mut input.items {
97 if let ImplItem::Method(method) = inner {
98 let sig = &mut method.sig;
99 if sig.asyncness.is_some() {
100 let block = &mut method.block;
101 let has_self = has_self_in_sig(sig) || has_self_in_block(block);
102 transform_block(context, sig, block, has_self, is_local);
103 transform_sig(context, sig, has_self, false, is_local);
104 }
105 }
106 }
107 }
108 }
109}
110
// Rewrites the signature of an `async fn` into a plain `fn` returning a
// pinned, boxed future.
//
// Input:
//     async fn f<T>(&self, x: &T) -> Ret;
//
// Output (this method has no default body, so `Self` gets no marker bound):
//     fn f<'life0, 'life1, 'async_trait, T>(
//         &'life0 self,
//         x: &'life1 T,
//     ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
//     where
//         T: 'async_trait,
//         'life0: 'async_trait,
//         'life1: 'async_trait,
//         Self: 'async_trait;
//
// The final predicate becomes `Self: Sync + 'async_trait` (for a `&self` or
// `self: &Self` receiver) or `Self: Send + 'async_trait` (any other receiver)
// only when all of the following hold:
//   - `has_self` is true,
//   - the method is in a trait and has a default body,
//   - the trait does not already list that marker as a supertrait,
//   - `is_local` is false.
// When `is_local` is true, the `+ Send` on the returned future is also omitted.
//
// `context`     - the enclosing trait or impl.
// `sig`         - modified in place; `sig.asyncness` must be `Some` (it is
//                 taken and unwrapped).
// `has_self`    - whether the method mentions `self`/`Self` (as computed by the
//                 caller), which determines whether a `Self` predicate is added.
// `has_default` - whether a trait method has a default body; unused for impls.
// `is_local`    - if true, no `Send`/`Sync` bounds are emitted.
fn transform_sig(
    context: Context,
    sig: &mut Signature,
    has_self: bool,
    has_default: bool,
    is_local: bool,
) {
    // Remove `async`, keeping its span on the `fn` token.
    sig.fn_token.span = sig.asyncness.take().unwrap().span;

    // The future's output type; a missing return type means `()`.
    let ret = match &sig.output {
        ReturnType::Default => quote!(()),
        ReturnType::Type(_, ret) => quote!(#ret),
    };

    // Walk the receiver and every argument type with CollectLifetimes (prefix
    // `'life`). Judging by the example above, it names elided lifetimes
    // `'life0`, `'life1`, ...; its `explicit` and `elided` lists are consumed
    // below.
    let mut lifetimes = CollectLifetimes::new("'life");
    for arg in sig.inputs.iter_mut() {
        match arg {
            FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
            FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
        }
    }

    let where_clause = sig
        .generics
        .where_clause
        .get_or_insert_with(|| WhereClause {
            where_token: Default::default(),
            predicates: Punctuated::new(),
        });
    // Each of the method's own generics, plus any trait/impl lifetime that the
    // arguments mention, must outlive `'async_trait`, the lifetime of the boxed
    // future that may capture them.
    for param in sig
        .generics
        .params
        .iter()
        .chain(context.lifetimes(&lifetimes.explicit))
    {
        match param {
            GenericParam::Type(param) => {
                let param = &param.ident;
                where_clause
                    .predicates
                    .push(parse_quote!(#param: 'async_trait));
            }
            GenericParam::Lifetime(param) => {
                let param = &param.lifetime;
                where_clause
                    .predicates
                    .push(parse_quote!(#param: 'async_trait));
            }
            // Const parameters carry no lifetime.
            GenericParam::Const(_) => {}
        }
    }
    // Declare the newly named elided lifetimes and bound them the same way.
    for elided in lifetimes.elided {
        sig.generics.params.push(parse_quote!(#elided));
        where_clause
            .predicates
            .push(parse_quote!(#elided: 'async_trait));
    }
    sig.generics.params.push(parse_quote!('async_trait));
    if has_self {
        // A shared-reference receiver only needs `Self: Sync` for the future to
        // be `Send`; any other receiver needs `Self: Send`.
        let bound: Ident = match sig.inputs.iter().next() {
            Some(FnArg::Receiver(Receiver {
                reference: Some(_),
                mutability: None,
                ..
            })) => parse_quote!(Sync),
            // Typed receiver spelled `self: &T` (no `mut`).
            Some(FnArg::Typed(arg))
                if match (arg.pat.as_ref(), arg.ty.as_ref()) {
                    (Pat::Ident(pat), Type::Reference(ty)) => {
                        pat.ident == "self" && ty.mutability.is_none()
                    }
                    _ => false,
                } =>
            {
                parse_quote!(Sync)
            }
            _ => parse_quote!(Send),
        };
        // The marker bound can be omitted when it is already guaranteed. That is
        // the case for every impl method, for a trait method without a default
        // body, and for a trait that lists the marker as a supertrait.
        let assume_bound = match context {
            Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, &bound),
            Context::Impl { .. } => true,
        };
        where_clause.predicates.push(if assume_bound || is_local {
            parse_quote!(Self: 'async_trait)
        } else {
            parse_quote!(Self: ::core::marker::#bound + 'async_trait)
        });
    }

    // Make every argument pattern a plain identifier. The new body (built by
    // transform_block) forwards arguments by these names; non-identifier
    // patterns get the same positional name transform_block uses for them.
    for (i, arg) in sig.inputs.iter_mut().enumerate() {
        match arg {
            FnArg::Receiver(Receiver {
                reference: Some(_), ..
            }) => {}
            // By-value receiver: drop `mut` from `mut self`.
            FnArg::Receiver(arg) => arg.mutability = None,
            FnArg::Typed(arg) => {
                if let Pat::Ident(ident) = &mut *arg.pat {
                    ident.by_ref = None;
                    ident.mutability = None;
                } else {
                    let positional = positional_arg(i);
                    *arg.pat = parse_quote!(#positional);
                }
            }
        }
    }

    let bounds = if is_local {
        quote!('async_trait)
    } else {
        quote!(::core::marker::Send + 'async_trait)
    };

    sig.output = parse_quote! {
        -> ::core::pin::Pin<Box<
            dyn ::core::future::Future<Output = #ret> + #bounds
        >>
    };
}
242
// Moves the body of an `async fn` into a free-standing inner `async fn` and
// replaces the original body with a call to it whose result is boxed and
// pinned.
//
// Input:
//     async fn f<T>(&self, x: &T) -> Ret {
//         self + x
//     }
//
// Output (the new body of `f`; the inner fn is named `__` + the method name):
//     {
//         async fn __f<T, AsyncTrait>(_self: &AsyncTrait, x: &T) -> Ret {
//             _self + x
//         }
//         Box::pin(__f::<T, Self>(self, x))
//     }
//
// `AsyncTrait` stands in for `Self` only in a trait context when `has_self` is
// true. In an impl context, `_self` is typed with the impl's self type instead.
// `sig` is only read here: a clone of it becomes the inner fn's signature, and
// the callers in `expand` rewrite `sig` afterwards with `transform_sig`.
fn transform_block(
    context: Context,
    sig: &mut Signature,
    block: &mut Block,
    has_self: bool,
    is_local: bool,
) {
    // A body that consists solely of a verbatim `;` is left as-is.
    // NOTE(review): presumably a placeholder produced by the parser for a
    // bodiless method — confirm against parse.rs.
    if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
        if block.stmts.len() == 1 && item.to_string() == ";" {
            return;
        }
    }

    let inner = format_ident!("__{}", sig.ident);
    // Arguments forwarded to the inner fn:
    //   - `self` for a receiver,
    //   - the bound name for an identifier pattern,
    //   - otherwise the positional name that `transform_sig` assigns to that
    //     argument.
    let args = sig.inputs.iter().enumerate().map(|(i, arg)| match arg {
        FnArg::Receiver(Receiver { self_token, .. }) => quote!(#self_token),
        FnArg::Typed(arg) => {
            if let Pat::Ident(PatIdent { ident, .. }) = &*arg.pat {
                quote!(#ident)
            } else {
                positional_arg(i).into_token_stream()
            }
        }
    });

    let mut standalone = sig.clone();
    standalone.ident = inner.clone();

    let generics = match context {
        Context::Trait { generics, .. } => generics,
        Context::Impl { impl_generics, .. } => impl_generics,
    };

    // The inner fn is a free fn, so it must re-declare the enclosing
    // trait/impl generics itself. Defaults are stripped because function
    // generic parameters may not have them.
    let mut outer_generics = generics.clone();
    for p in &mut outer_generics.params {
        match p {
            GenericParam::Type(t) => t.default = None,
            GenericParam::Const(c) => c.default = None,
            GenericParam::Lifetime(_) => {}
        }
    }
    // Without `self`, no `AsyncTrait` parameter is introduced below, so
    // where-predicates that mention `Self` would have nothing to refer to;
    // drop them.
    if !has_self {
        if let Some(mut where_clause) = outer_generics.where_clause {
            where_clause.predicates = where_clause
                .predicates
                .into_iter()
                .filter_map(|mut pred| {
                    if has_self_in_where_predicate(&mut pred) {
                        None
                    } else {
                        Some(pred)
                    }
                })
                .collect();
            outer_generics.where_clause = Some(where_clause);
        }
    }

    // Inner fn generics: the trait/impl generics first, then the method's own.
    let fn_generics = mem::replace(&mut standalone.generics, outer_generics);
    standalone.generics.params.extend(fn_generics.params);
    if let Some(where_clause) = fn_generics.where_clause {
        standalone
            .generics
            .make_where_clause()
            .predicates
            .extend(where_clause.predicates);
    }

    // Presumably true when `'async_trait` is referenced by the signature or
    // body, in which case the inner fn must declare it.
    if has_async_lifetime(&mut standalone, block) {
        standalone.generics.params.push(parse_quote!('async_trait));
    }

    // Type arguments for the turbofish in the generated call. They are
    // collected before `AsyncTrait` is added; `Self` is appended for it below.
    let mut types = standalone
        .generics
        .type_params()
        .map(|param| param.ident.clone())
        .collect::<Vec<_>>();

    // `self` cannot be a parameter of a free fn: rename the receiver to
    // `_self` and give it an explicit type (`AsyncTrait` in a trait, the
    // implementing type in an impl). In a trait, also record the marker bound
    // `AsyncTrait` will need: `Sync` for `&self`, `Send` otherwise.
    let mut self_bound = None::<TypeParamBound>;
    match standalone.inputs.iter_mut().next() {
        // `&self` / `&mut self`, possibly with an explicit lifetime.
        Some(
            arg @ FnArg::Receiver(Receiver {
                reference: Some(_), ..
            }),
        ) => {
            let (lifetime, mutability, self_token) = match arg {
                FnArg::Receiver(Receiver {
                    reference: Some((_, lifetime)),
                    mutability,
                    self_token,
                    ..
                }) => (lifetime, mutability, self_token),
                _ => unreachable!(),
            };
            let under_self = Ident::new("_self", self_token.span);
            match context {
                Context::Trait { .. } => {
                    self_bound = Some(match mutability {
                        Some(_) => parse_quote!(::core::marker::Send),
                        None => parse_quote!(::core::marker::Sync),
                    });
                    *arg = parse_quote! {
                        #under_self: &#lifetime #mutability AsyncTrait
                    };
                }
                Context::Impl { receiver, .. } => {
                    // A trait-object self type must be written `dyn Trait`,
                    // and parenthesized when it has several bounds, to be valid
                    // after `&`.
                    let mut ty = quote!(#receiver);
                    if let Type::TraitObject(trait_object) = receiver {
                        if trait_object.dyn_token.is_none() {
                            ty = quote!(dyn #ty);
                        }
                        if trait_object.bounds.len() > 1 {
                            ty = quote!((#ty));
                        }
                    }
                    *arg = parse_quote! {
                        #under_self: &#lifetime #mutability #ty
                    };
                }
            }
        }
        // By-value `self` / `mut self`.
        Some(arg @ FnArg::Receiver(_)) => {
            let (self_token, mutability) = match arg {
                FnArg::Receiver(Receiver {
                    self_token,
                    mutability,
                    ..
                }) => (self_token, mutability),
                _ => unreachable!(),
            };
            let under_self = Ident::new("_self", self_token.span);
            match context {
                Context::Trait { .. } => {
                    self_bound = Some(parse_quote!(::core::marker::Send));
                    *arg = parse_quote! {
                        #mutability #under_self: AsyncTrait
                    };
                }
                Context::Impl { receiver, .. } => {
                    *arg = parse_quote! {
                        #mutability #under_self: #receiver
                    };
                }
            }
        }
        // Typed receiver such as `self: Box<Self>`: just rename the binding.
        Some(FnArg::Typed(arg)) => {
            if let Pat::Ident(arg) = &mut *arg.pat {
                if arg.ident == "self" {
                    arg.ident = Ident::new("_self", arg.ident.span());
                }
            }
        }
        _ => {}
    }

    // In a trait, `Self` becomes the inner fn's `AsyncTrait` type parameter,
    // bounded by the trait itself (and by the marker above unless `is_local`).
    // The caller passes `Self` for it.
    if let Context::Trait { name, generics, .. } = context {
        if has_self {
            let (_, generics, _) = generics.split_for_impl();
            let mut self_param: TypeParam = parse_quote!(AsyncTrait: ?Sized + #name #generics);
            if !is_local {
                self_param.bounds.extend(self_bound);
            }
            standalone
                .generics
                .params
                .push(GenericParam::Type(self_param));
            types.push(Ident::new("Self", Span::call_site()));
        }
    }

    if let Some(where_clause) = &mut standalone.generics.where_clause {
        // Work around an input bound like `where Self::Output: Send` expanding
        // to `where <AsyncTrait>::Output: Send` which is illegal syntax because
        // `where<T>` is reserved for future use... :(
        where_clause.predicates.insert(0, parse_quote!((): Sized));
    }

    // ReplaceReceiver presumably rewrites remaining `self`/`Self` references in
    // the inner signature and in the body to the replacements chosen above.
    let mut replace = match context {
        Context::Trait { .. } => ReplaceReceiver::with(parse_quote!(AsyncTrait)),
        Context::Impl {
            receiver, as_trait, ..
        } => ReplaceReceiver::with_as_trait(receiver.clone(), as_trait.clone()),
    };
    replace.visit_signature_mut(&mut standalone);
    replace.visit_block_mut(block);

    // Turbofish arguments: all type params (plus `Self`), then all const params.
    // NOTE(review): the inner fn declares its type/const params in `params`
    // order, with `AsyncTrait` pushed last. When const generics are present
    // (e.g. a method-level const param together with `AsyncTrait`), this
    // argument order may not match the declaration order — confirm.
    let mut generics = types;
    let consts = standalone
        .generics
        .const_params()
        .map(|param| param.ident.clone());
    generics.extend(consts);

    // The inner fn is named `__<method>`; silence the lint if the method name
    // is not all lowercase.
    let allow_non_snake_case = if sig.ident != sig.ident.to_string().to_lowercase() {
        Some(quote!(non_snake_case,))
    } else {
        None
    };

    // Build the new body with the original brace span so that diagnostics
    // point at the user's braces.
    let brace = block.brace_token;
    let box_pin = quote_spanned!(brace.span=> {
        #[allow(
            #allow_non_snake_case
            unused_parens, // https://github.com/dtolnay/async-trait/issues/118
            clippy::missing_docs_in_private_items,
            clippy::needless_lifetimes,
            clippy::ptr_arg,
            clippy::trivially_copy_pass_by_ref,
            clippy::type_repetition_in_bounds,
            clippy::used_underscore_binding,
        )]
        #standalone #block
        Box::pin(#inner::<#(#generics),*>(#(#args),*))
    });
    *block = parse_quote!(#box_pin);
    block.brace_token = brace;
}
470
471fn positional_arg(i: usize) -> Ident {
472 format_ident!("__arg{}", i)
473}
474
475fn has_bound(supertraits: &Supertraits, marker: &Ident) -> bool {
476 for bound in supertraits {
477 if let TypeParamBound::Trait(bound) = bound {
478 if bound.path.is_ident(marker) {
479 return true;
480 }
481 }
482 }
483 false
484}