blob: a123d4e5205c90adb61b17bc40301ae11b479e5a [file] [log] [blame]
Andrew Walbrandc8aa4b2022-01-13 12:20:07 +00001#![recursion_limit = "256"]
2// Copyright (c) 2020 Google LLC All rights reserved.
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
//! Implementation of the `FromArgs` and `argh(...)` derive attributes.
//!
//! For more thorough documentation, see the `argh` crate itself.
9extern crate proc_macro;
10
11use {
12 crate::{
13 errors::Errors,
14 parse_attrs::{FieldAttrs, FieldKind, TypeAttrs},
15 },
16 proc_macro2::{Span, TokenStream},
17 quote::{quote, quote_spanned, ToTokens},
18 std::str::FromStr,
19 syn::{spanned::Spanned, LitStr},
20};
21
22mod errors;
23mod help;
24mod parse_attrs;
25
26/// Entrypoint for `#[derive(FromArgs)]`.
27#[proc_macro_derive(FromArgs, attributes(argh))]
28pub fn argh_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
30 let gen = impl_from_args(&ast);
31 gen.into()
32}
33
34/// Transform the input into a token stream containing any generated implementations,
35/// as well as all errors that occurred.
36fn impl_from_args(input: &syn::DeriveInput) -> TokenStream {
37 let errors = &Errors::default();
38 if input.generics.params.len() != 0 {
39 errors.err(
40 &input.generics,
41 "`#![derive(FromArgs)]` cannot be applied to types with generic parameters",
42 );
43 }
44 let type_attrs = &TypeAttrs::parse(errors, input);
45 let mut output_tokens = match &input.data {
46 syn::Data::Struct(ds) => impl_from_args_struct(errors, &input.ident, type_attrs, ds),
47 syn::Data::Enum(de) => impl_from_args_enum(errors, &input.ident, type_attrs, de),
48 syn::Data::Union(_) => {
49 errors.err(input, "`#[derive(FromArgs)]` cannot be applied to unions");
50 TokenStream::new()
51 }
52 };
53 errors.to_tokens(&mut output_tokens);
54 output_tokens
55}
56
/// The kind of optionality a parameter has.
///
/// Computed per field in `StructField::new` from the field's kind, its type
/// wrapper (`Option<T>` / `Vec<T>`), and any `#[argh(default = "...")]` attribute.
enum Optionality {
    /// The value must be supplied: no `Option`/`Vec` wrapper and no `default`.
    None,
    /// The value falls back to these tokens, lexed from the `default` attribute's
    /// string literal, when it is not supplied.
    Defaulted(TokenStream),
    /// The value may be absent: an `Option`-typed option, positional or subcommand
    /// field, or any `switch`.
    Optional,
    /// The value may be supplied any number of times: a `Vec`-typed option or
    /// positional field.
    Repeating,
}
64
65impl PartialEq<Optionality> for Optionality {
66 fn eq(&self, other: &Optionality) -> bool {
67 use Optionality::*;
68 match (self, other) {
69 (None, None) | (Optional, Optional) | (Repeating, Repeating) => true,
70 // NB: (Defaulted, Defaulted) can't contain the same token streams
71 _ => false,
72 }
73 }
74}
75
76impl Optionality {
77 /// Whether or not this is `Optionality::None`
78 fn is_required(&self) -> bool {
79 if let Optionality::None = self {
80 true
81 } else {
82 false
83 }
84 }
85}
86
/// A field of a `#[derive(FromArgs)]` struct with attributes and some other
/// notable metadata appended.
struct StructField<'a> {
    /// The original parsed field
    field: &'a syn::Field,
    /// The parsed attributes of the field
    attrs: FieldAttrs,
    /// The field name. This is contained optionally inside `field`,
    /// but is duplicated non-optionally here to indicate that all fields that
    /// have reached this point must have a field name, and it no longer
    /// needs to be unwrapped.
    name: &'a syn::Ident,
    /// Similar to `name` above, this is contained optionally inside `FieldAttrs`,
    /// but here is fully present to indicate that we only have to consider fields
    /// with a valid `kind` at this point.
    kind: FieldKind,
    /// The type that values of this field parse as.
    ///
    /// For `option`/`positional` fields without a `default`, this is `T` when
    /// `field.ty` is `Option<T>` or `Vec<T>`. For `subcommand` fields it is `T` when
    /// `field.ty` is `Option<T>`. In every other case it is `&field.ty`. This
    /// enables consistent parsing code between optional and non-optional keyed and
    /// subcommand fields.
    ty_without_wrapper: &'a syn::Type,
    /// Whether the field is required, optional (`Option` field or `switch`),
    /// repeating (`Vec` field), or has a `default`.
    optionality: Optionality,
    /// The `--`-prefixed name of the option. Present for `switch` and `option`
    /// fields; `None` for `subcommand` and `positional` fields.
    long_name: Option<String>,
}
113
114impl<'a> StructField<'a> {
115 /// Attempts to parse a field of a `#[derive(FromArgs)]` struct, pulling out the
116 /// fields required for code generation.
117 fn new(errors: &Errors, field: &'a syn::Field, attrs: FieldAttrs) -> Option<Self> {
118 let name = field.ident.as_ref().expect("missing ident for named field");
119
120 // Ensure that one "kind" is present (switch, option, subcommand, positional)
121 let kind = if let Some(field_type) = &attrs.field_type {
122 field_type.kind
123 } else {
124 errors.err(
125 field,
126 concat!(
127 "Missing `argh` field kind attribute.\n",
128 "Expected one of: `switch`, `option`, `subcommand`, `positional`",
129 ),
130 );
131 return None;
132 };
133
134 // Parse out whether a field is optional (`Option` or `Vec`).
135 let optionality;
136 let ty_without_wrapper;
137 match kind {
138 FieldKind::Switch => {
139 if !ty_expect_switch(errors, &field.ty) {
140 return None;
141 }
142 optionality = Optionality::Optional;
143 ty_without_wrapper = &field.ty;
144 }
145 FieldKind::Option | FieldKind::Positional => {
146 if let Some(default) = &attrs.default {
147 let tokens = match TokenStream::from_str(&default.value()) {
148 Ok(tokens) => tokens,
149 Err(_) => {
150 errors.err(&default, "Invalid tokens: unable to lex `default` value");
151 return None;
152 }
153 };
154 // Set the span of the generated tokens to the string literal
155 let tokens: TokenStream = tokens
156 .into_iter()
157 .map(|mut tree| {
158 tree.set_span(default.span());
159 tree
160 })
161 .collect();
162 optionality = Optionality::Defaulted(tokens);
163 ty_without_wrapper = &field.ty;
164 } else {
165 let mut inner = None;
166 optionality = if let Some(x) = ty_inner(&["Option"], &field.ty) {
167 inner = Some(x);
168 Optionality::Optional
169 } else if let Some(x) = ty_inner(&["Vec"], &field.ty) {
170 inner = Some(x);
171 Optionality::Repeating
172 } else {
173 Optionality::None
174 };
175 ty_without_wrapper = inner.unwrap_or(&field.ty);
176 }
177 }
178 FieldKind::SubCommand => {
179 let inner = ty_inner(&["Option"], &field.ty);
180 optionality =
181 if inner.is_some() { Optionality::Optional } else { Optionality::None };
182 ty_without_wrapper = inner.unwrap_or(&field.ty);
183 }
184 }
185
186 // Determine the "long" name of options and switches.
187 // Defaults to the kebab-case'd field name if `#[argh(long = "...")]` is omitted.
188 let long_name = match kind {
189 FieldKind::Switch | FieldKind::Option => {
190 let long_name = attrs
191 .long
192 .as_ref()
193 .map(syn::LitStr::value)
194 .unwrap_or_else(|| heck::KebabCase::to_kebab_case(&*name.to_string()));
195 if long_name == "help" {
196 errors.err(field, "Custom `--help` flags are not supported.");
197 }
198 let long_name = format!("--{}", long_name);
199 Some(long_name)
200 }
201 FieldKind::SubCommand | FieldKind::Positional => None,
202 };
203
204 Some(StructField { field, attrs, kind, optionality, ty_without_wrapper, name, long_name })
205 }
206
207 pub(crate) fn arg_name(&self) -> String {
208 self.attrs.arg_name.as_ref().map(LitStr::value).unwrap_or_else(|| self.name.to_string())
209 }
210}
211
212/// Implements `FromArgs` and `TopLevelCommand` or `SubCommand` for a `#[derive(FromArgs)]` struct.
213fn impl_from_args_struct(
214 errors: &Errors,
215 name: &syn::Ident,
216 type_attrs: &TypeAttrs,
217 ds: &syn::DataStruct,
218) -> TokenStream {
219 let fields = match &ds.fields {
220 syn::Fields::Named(fields) => fields,
221 syn::Fields::Unnamed(_) => {
222 errors.err(
223 &ds.struct_token,
224 "`#![derive(FromArgs)]` is not currently supported on tuple structs",
225 );
226 return TokenStream::new();
227 }
228 syn::Fields::Unit => {
229 errors.err(&ds.struct_token, "#![derive(FromArgs)]` cannot be applied to unit structs");
230 return TokenStream::new();
231 }
232 };
233
234 let fields: Vec<_> = fields
235 .named
236 .iter()
237 .filter_map(|field| {
238 let attrs = FieldAttrs::parse(errors, field);
239 StructField::new(errors, field, attrs)
240 })
241 .collect();
242
243 ensure_only_last_positional_is_optional(errors, &fields);
244
245 let impl_span = Span::call_site();
246
247 let from_args_method = impl_from_args_struct_from_args(errors, type_attrs, &fields);
248
249 let redact_arg_values_method =
250 impl_from_args_struct_redact_arg_values(errors, type_attrs, &fields);
251
252 let top_or_sub_cmd_impl = top_or_sub_cmd_impl(errors, name, type_attrs);
253
254 let trait_impl = quote_spanned! { impl_span =>
255 impl argh::FromArgs for #name {
256 #from_args_method
257
258 #redact_arg_values_method
259 }
260
261 #top_or_sub_cmd_impl
262 };
263
264 trait_impl
265}
266
267fn impl_from_args_struct_from_args<'a>(
268 errors: &Errors,
269 type_attrs: &TypeAttrs,
270 fields: &'a [StructField<'a>],
271) -> TokenStream {
272 let init_fields = declare_local_storage_for_from_args_fields(&fields);
273 let unwrap_fields = unwrap_from_args_fields(&fields);
274 let positional_fields: Vec<&StructField<'_>> =
275 fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
276 let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
277 let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
278 let last_positional_is_repeating = positional_fields
279 .last()
280 .map(|field| field.optionality == Optionality::Repeating)
281 .unwrap_or(false);
282
283 let flag_output_table = fields.iter().filter_map(|field| {
284 let field_name = &field.field.ident;
285 match field.kind {
286 FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
287 FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
288 FieldKind::SubCommand | FieldKind::Positional => None,
289 }
290 });
291
292 let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(&fields);
293
294 let mut subcommands_iter =
295 fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
296
297 let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
298 while let Some(dup_subcommand) = subcommands_iter.next() {
299 errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
300 }
301
302 let impl_span = Span::call_site();
303
304 let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
305
306 let append_missing_requirements =
307 append_missing_requirements(&missing_requirements_ident, &fields);
308
309 let parse_subcommands = if let Some(subcommand) = subcommand {
310 let name = subcommand.name;
311 let ty = subcommand.ty_without_wrapper;
312 quote_spanned! { impl_span =>
313 Some(argh::ParseStructSubCommand {
314 subcommands: <#ty as argh::SubCommands>::COMMANDS,
315 parse_func: &mut |__command, __remaining_args| {
316 #name = Some(<#ty as argh::FromArgs>::from_args(__command, __remaining_args)?);
317 Ok(())
318 },
319 })
320 }
321 } else {
322 quote_spanned! { impl_span => None }
323 };
324
325 // Identifier referring to a value containing the name of the current command as an `&[&str]`.
326 let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
327 let help = help::help(errors, cmd_name_str_array_ident, type_attrs, &fields, subcommand);
328
329 let method_impl = quote_spanned! { impl_span =>
330 fn from_args(__cmd_name: &[&str], __args: &[&str])
331 -> std::result::Result<Self, argh::EarlyExit>
332 {
333 #( #init_fields )*
334
335 argh::parse_struct_args(
336 __cmd_name,
337 __args,
338 argh::ParseStructOptions {
339 arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
340 slots: &mut [ #( #flag_output_table, )* ],
341 },
342 argh::ParseStructPositionals {
343 positionals: &mut [
344 #(
345 argh::ParseStructPositional {
346 name: #positional_field_names,
347 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
348 },
349 )*
350 ],
351 last_is_repeating: #last_positional_is_repeating,
352 },
353 #parse_subcommands,
354 &|| #help,
355 )?;
356
357 let mut #missing_requirements_ident = argh::MissingRequirements::default();
358 #(
359 #append_missing_requirements
360 )*
361 #missing_requirements_ident.err_on_any()?;
362
363 Ok(Self {
364 #( #unwrap_fields, )*
365 })
366 }
367 };
368
369 method_impl
370}
371
372fn impl_from_args_struct_redact_arg_values<'a>(
373 errors: &Errors,
374 type_attrs: &TypeAttrs,
375 fields: &'a [StructField<'a>],
376) -> TokenStream {
377 let init_fields = declare_local_storage_for_redacted_fields(&fields);
378 let unwrap_fields = unwrap_redacted_fields(&fields);
379
380 let positional_fields: Vec<&StructField<'_>> =
381 fields.iter().filter(|field| field.kind == FieldKind::Positional).collect();
382 let positional_field_idents = positional_fields.iter().map(|field| &field.field.ident);
383 let positional_field_names = positional_fields.iter().map(|field| field.name.to_string());
384 let last_positional_is_repeating = positional_fields
385 .last()
386 .map(|field| field.optionality == Optionality::Repeating)
387 .unwrap_or(false);
388
389 let flag_output_table = fields.iter().filter_map(|field| {
390 let field_name = &field.field.ident;
391 match field.kind {
392 FieldKind::Option => Some(quote! { argh::ParseStructOption::Value(&mut #field_name) }),
393 FieldKind::Switch => Some(quote! { argh::ParseStructOption::Flag(&mut #field_name) }),
394 FieldKind::SubCommand | FieldKind::Positional => None,
395 }
396 });
397
398 let flag_str_to_output_table_map = flag_str_to_output_table_map_entries(&fields);
399
400 let mut subcommands_iter =
401 fields.iter().filter(|field| field.kind == FieldKind::SubCommand).fuse();
402
403 let subcommand: Option<&StructField<'_>> = subcommands_iter.next();
404 while let Some(dup_subcommand) = subcommands_iter.next() {
405 errors.duplicate_attrs("subcommand", subcommand.unwrap().field, dup_subcommand.field);
406 }
407
408 let impl_span = Span::call_site();
409
410 let missing_requirements_ident = syn::Ident::new("__missing_requirements", impl_span);
411
412 let append_missing_requirements =
413 append_missing_requirements(&missing_requirements_ident, &fields);
414
415 let redact_subcommands = if let Some(subcommand) = subcommand {
416 let name = subcommand.name;
417 let ty = subcommand.ty_without_wrapper;
418 quote_spanned! { impl_span =>
419 Some(argh::ParseStructSubCommand {
420 subcommands: <#ty as argh::SubCommands>::COMMANDS,
421 parse_func: &mut |__command, __remaining_args| {
422 #name = Some(<#ty as argh::FromArgs>::redact_arg_values(__command, __remaining_args)?);
423 Ok(())
424 },
425 })
426 }
427 } else {
428 quote_spanned! { impl_span => None }
429 };
430
431 let cmd_name = if type_attrs.is_subcommand.is_none() {
432 quote! { __cmd_name.last().expect("no command name").to_string() }
433 } else {
434 quote! { __cmd_name.last().expect("no subcommand name").to_string() }
435 };
436
437 // Identifier referring to a value containing the name of the current command as an `&[&str]`.
438 let cmd_name_str_array_ident = syn::Ident::new("__cmd_name", impl_span);
439 let help = help::help(errors, cmd_name_str_array_ident, type_attrs, &fields, subcommand);
440
441 let method_impl = quote_spanned! { impl_span =>
442 fn redact_arg_values(__cmd_name: &[&str], __args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
443 #( #init_fields )*
444
445 argh::parse_struct_args(
446 __cmd_name,
447 __args,
448 argh::ParseStructOptions {
449 arg_to_slot: &[ #( #flag_str_to_output_table_map ,)* ],
450 slots: &mut [ #( #flag_output_table, )* ],
451 },
452 argh::ParseStructPositionals {
453 positionals: &mut [
454 #(
455 argh::ParseStructPositional {
456 name: #positional_field_names,
457 slot: &mut #positional_field_idents as &mut argh::ParseValueSlot,
458 },
459 )*
460 ],
461 last_is_repeating: #last_positional_is_repeating,
462 },
463 #redact_subcommands,
464 &|| #help,
465 )?;
466
467 let mut #missing_requirements_ident = argh::MissingRequirements::default();
468 #(
469 #append_missing_requirements
470 )*
471 #missing_requirements_ident.err_on_any()?;
472
473 let mut __redacted = vec![
474 #cmd_name,
475 ];
476
477 #( #unwrap_fields )*
478
479 Ok(__redacted)
480 }
481 };
482
483 method_impl
484}
485
486/// Ensures that only the last positional arg is non-required.
487fn ensure_only_last_positional_is_optional(errors: &Errors, fields: &[StructField<'_>]) {
488 let mut first_non_required_span = None;
489 for field in fields {
490 if field.kind == FieldKind::Positional {
491 if let Some(first) = first_non_required_span {
492 errors.err_span(
493 first,
494 "Only the last positional argument may be `Option`, `Vec`, or defaulted.",
495 );
496 errors.err(&field.field, "Later positional argument declared here.");
497 return;
498 }
499 if !field.optionality.is_required() {
500 first_non_required_span = Some(field.field.span());
501 }
502 }
503 }
504}
505
506/// Implement `argh::TopLevelCommand` or `argh::SubCommand` as appropriate.
507fn top_or_sub_cmd_impl(errors: &Errors, name: &syn::Ident, type_attrs: &TypeAttrs) -> TokenStream {
508 let description =
509 help::require_description(errors, name.span(), &type_attrs.description, "type");
510 if type_attrs.is_subcommand.is_none() {
511 // Not a subcommand
512 quote! {
513 impl argh::TopLevelCommand for #name {}
514 }
515 } else {
516 let empty_str = syn::LitStr::new("", Span::call_site());
517 let subcommand_name = type_attrs.name.as_ref().unwrap_or_else(|| {
518 errors.err(name, "`#[argh(name = \"...\")]` attribute is required for subcommands");
519 &empty_str
520 });
521 quote! {
522 impl argh::SubCommand for #name {
523 const COMMAND: &'static argh::CommandInfo = &argh::CommandInfo {
524 name: #subcommand_name,
525 description: #description,
526 };
527 }
528 }
529 }
530}
531
532/// Declare a local slots to store each field in during parsing.
533///
534/// Most fields are stored in `Option<FieldType>` locals.
535/// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
536/// function that knows how to decode the appropriate value.
537fn declare_local_storage_for_from_args_fields<'a>(
538 fields: &'a [StructField<'a>],
539) -> impl Iterator<Item = TokenStream> + 'a {
540 fields.iter().map(|field| {
541 let field_name = &field.field.ident;
542 let field_type = &field.ty_without_wrapper;
543
544 // Wrap field types in `Option` if they aren't already `Option` or `Vec`-wrapped.
545 let field_slot_type = match field.optionality {
546 Optionality::Optional | Optionality::Repeating => (&field.field.ty).into_token_stream(),
547 Optionality::None | Optionality::Defaulted(_) => {
548 quote! { std::option::Option<#field_type> }
549 }
550 };
551
552 match field.kind {
553 FieldKind::Option | FieldKind::Positional => {
554 let from_str_fn = match &field.attrs.from_str_fn {
555 Some(from_str_fn) => from_str_fn.into_token_stream(),
556 None => {
557 quote! {
558 <#field_type as argh::FromArgValue>::from_arg_value
559 }
560 }
561 };
562
563 quote! {
564 let mut #field_name: argh::ParseValueSlotTy<#field_slot_type, #field_type>
565 = argh::ParseValueSlotTy {
566 slot: std::default::Default::default(),
567 parse_func: |_, value| { #from_str_fn(value) },
568 };
569 }
570 }
571 FieldKind::SubCommand => {
572 quote! { let mut #field_name: #field_slot_type = None; }
573 }
574 FieldKind::Switch => {
575 quote! { let mut #field_name: #field_slot_type = argh::Flag::default(); }
576 }
577 }
578 })
579}
580
581/// Unwrap non-optional fields and take options out of their tuple slots.
582fn unwrap_from_args_fields<'a>(
583 fields: &'a [StructField<'a>],
584) -> impl Iterator<Item = TokenStream> + 'a {
585 fields.iter().map(|field| {
586 let field_name = field.name;
587 match field.kind {
588 FieldKind::Option | FieldKind::Positional => match &field.optionality {
589 Optionality::None => quote! { #field_name: #field_name.slot.unwrap() },
590 Optionality::Optional | Optionality::Repeating => {
591 quote! { #field_name: #field_name.slot }
592 }
593 Optionality::Defaulted(tokens) => {
594 quote! {
595 #field_name: #field_name.slot.unwrap_or_else(|| #tokens)
596 }
597 }
598 },
599 FieldKind::Switch => field_name.into_token_stream(),
600 FieldKind::SubCommand => match field.optionality {
601 Optionality::None => quote! { #field_name: #field_name.unwrap() },
602 Optionality::Optional | Optionality::Repeating => field_name.into_token_stream(),
603 Optionality::Defaulted(_) => unreachable!(),
604 },
605 }
606 })
607}
608
609/// Declare a local slots to store each field in during parsing.
610///
611/// Most fields are stored in `Option<FieldType>` locals.
612/// `argh(option)` fields are stored in a `ParseValueSlotTy` along with a
613/// function that knows how to decode the appropriate value.
614fn declare_local_storage_for_redacted_fields<'a>(
615 fields: &'a [StructField<'a>],
616) -> impl Iterator<Item = TokenStream> + 'a {
617 fields.iter().map(|field| {
618 let field_name = &field.field.ident;
619
620 match field.kind {
621 FieldKind::Switch => {
622 quote! {
623 let mut #field_name = argh::RedactFlag {
624 slot: None,
625 };
626 }
627 }
628 FieldKind::Option => {
629 let field_slot_type = match field.optionality {
630 Optionality::Repeating => {
631 quote! { std::vec::Vec<String> }
632 }
633 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
634 quote! { std::option::Option<String> }
635 }
636 };
637
638 quote! {
639 let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
640 argh::ParseValueSlotTy {
641 slot: std::default::Default::default(),
642 parse_func: |arg, _| { Ok(arg.to_string()) },
643 };
644 }
645 }
646 FieldKind::Positional => {
647 let field_slot_type = match field.optionality {
648 Optionality::Repeating => {
649 quote! { std::vec::Vec<String> }
650 }
651 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
652 quote! { std::option::Option<String> }
653 }
654 };
655
656 let arg_name = field.arg_name();
657 quote! {
658 let mut #field_name: argh::ParseValueSlotTy::<#field_slot_type, String> =
659 argh::ParseValueSlotTy {
660 slot: std::default::Default::default(),
661 parse_func: |_, _| { Ok(#arg_name.to_string()) },
662 };
663 }
664 }
665 FieldKind::SubCommand => {
666 quote! { let mut #field_name: std::option::Option<std::vec::Vec<String>> = None; }
667 }
668 }
669 })
670}
671
672/// Unwrap non-optional fields and take options out of their tuple slots.
673fn unwrap_redacted_fields<'a>(
674 fields: &'a [StructField<'a>],
675) -> impl Iterator<Item = TokenStream> + 'a {
676 fields.iter().map(|field| {
677 let field_name = field.name;
678
679 match field.kind {
680 FieldKind::Switch => {
681 quote! {
682 if let Some(__field_name) = #field_name.slot {
683 __redacted.push(__field_name);
684 }
685 }
686 }
687 FieldKind::Option => match field.optionality {
688 Optionality::Repeating => {
689 quote! {
690 __redacted.extend(#field_name.slot.into_iter());
691 }
692 }
693 Optionality::None | Optionality::Optional | Optionality::Defaulted(_) => {
694 quote! {
695 if let Some(__field_name) = #field_name.slot {
696 __redacted.push(__field_name);
697 }
698 }
699 }
700 },
701 FieldKind::Positional => {
702 quote! {
703 __redacted.extend(#field_name.slot.into_iter());
704 }
705 }
706 FieldKind::SubCommand => {
707 quote! {
708 if let Some(__subcommand_args) = #field_name {
709 __redacted.extend(__subcommand_args.into_iter());
710 }
711 }
712 }
713 }
714 })
715}
716
717/// Entries of tokens like `("--some-flag-key", 5)` that map from a flag key string
718/// to an index in the output table.
719fn flag_str_to_output_table_map_entries<'a>(fields: &'a [StructField<'a>]) -> Vec<TokenStream> {
720 let mut flag_str_to_output_table_map = vec![];
721 for (i, (field, long_name)) in fields
722 .iter()
723 .filter_map(|field| field.long_name.as_ref().map(|long_name| (field, long_name)))
724 .enumerate()
725 {
726 if let Some(short) = &field.attrs.short {
727 let short = format!("-{}", short.value());
728 flag_str_to_output_table_map.push(quote! { (#short, #i) });
729 }
730
731 flag_str_to_output_table_map.push(quote! { (#long_name, #i) });
732 }
733 flag_str_to_output_table_map
734}
735
736/// For each non-optional field, add an entry to the `argh::MissingRequirements`.
737fn append_missing_requirements<'a>(
738 // missing_requirements_ident
739 mri: &syn::Ident,
740 fields: &'a [StructField<'a>],
741) -> impl Iterator<Item = TokenStream> + 'a {
742 let mri = mri.clone();
743 fields.iter().filter(|f| f.optionality.is_required()).map(move |field| {
744 let field_name = field.name;
745 match field.kind {
746 FieldKind::Switch => unreachable!("switches are always optional"),
747 FieldKind::Positional => {
748 let name = field.arg_name();
749 quote! {
750 if #field_name.slot.is_none() {
751 #mri.missing_positional_arg(#name)
752 }
753 }
754 }
755 FieldKind::Option => {
756 let name = field.long_name.as_ref().expect("options always have a long name");
757 quote! {
758 if #field_name.slot.is_none() {
759 #mri.missing_option(#name)
760 }
761 }
762 }
763 FieldKind::SubCommand => {
764 let ty = field.ty_without_wrapper;
765 quote! {
766 if #field_name.is_none() {
767 #mri.missing_subcommands(
768 <#ty as argh::SubCommands>::COMMANDS,
769 )
770 }
771 }
772 }
773 }
774 })
775}
776
777/// Require that a type can be a `switch`.
778/// Throws an error for all types except booleans and integers
779fn ty_expect_switch(errors: &Errors, ty: &syn::Type) -> bool {
780 fn ty_can_be_switch(ty: &syn::Type) -> bool {
781 if let syn::Type::Path(path) = ty {
782 if path.qself.is_some() {
783 return false;
784 }
785 if path.path.segments.len() != 1 {
786 return false;
787 }
788 let ident = &path.path.segments[0].ident;
789 ["bool", "u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]
790 .iter()
791 .any(|path| ident == path)
792 } else {
793 false
794 }
795 }
796
797 let res = ty_can_be_switch(ty);
798 if !res {
799 errors.err(ty, "switches must be of type `bool` or integer type");
800 }
801 res
802}
803
804/// Returns `Some(T)` if a type is `wrapper_name<T>` for any `wrapper_name` in `wrapper_names`.
805fn ty_inner<'a>(wrapper_names: &[&str], ty: &'a syn::Type) -> Option<&'a syn::Type> {
806 if let syn::Type::Path(path) = ty {
807 if path.qself.is_some() {
808 return None;
809 }
810 // Since we only check the last path segment, it isn't necessarily the case that
811 // we're referring to `std::vec::Vec` or `std::option::Option`, but there isn't
812 // a fool proof way to check these since name resolution happens after macro expansion,
813 // so this is likely "good enough" (so long as people don't have their own types called
814 // `Option` or `Vec` that take one generic parameter they're looking to parse).
815 let last_segment = path.path.segments.last()?;
816 if !wrapper_names.iter().any(|name| last_segment.ident == *name) {
817 return None;
818 }
819 if let syn::PathArguments::AngleBracketed(gen_args) = &last_segment.arguments {
820 let generic_arg = gen_args.args.first()?;
821 if let syn::GenericArgument::Type(ty) = &generic_arg {
822 return Some(ty);
823 }
824 }
825 }
826 None
827}
828
829/// Implements `FromArgs` and `SubCommands` for a `#![derive(FromArgs)]` enum.
830fn impl_from_args_enum(
831 errors: &Errors,
832 name: &syn::Ident,
833 type_attrs: &TypeAttrs,
834 de: &syn::DataEnum,
835) -> TokenStream {
836 parse_attrs::check_enum_type_attrs(errors, type_attrs, &de.enum_token.span);
837
838 // An enum variant like `<name>(<ty>)`
839 struct SubCommandVariant<'a> {
840 name: &'a syn::Ident,
841 ty: &'a syn::Type,
842 }
843
844 let variants: Vec<SubCommandVariant<'_>> = de
845 .variants
846 .iter()
847 .filter_map(|variant| {
848 parse_attrs::check_enum_variant_attrs(errors, variant);
849 let name = &variant.ident;
850 let ty = enum_only_single_field_unnamed_variants(errors, &variant.fields)?;
851 Some(SubCommandVariant { name, ty })
852 })
853 .collect();
854
855 let name_repeating = std::iter::repeat(name.clone());
856 let variant_ty = variants.iter().map(|x| x.ty).collect::<Vec<_>>();
857 let variant_names = variants.iter().map(|x| x.name).collect::<Vec<_>>();
858
859 quote! {
860 impl argh::FromArgs for #name {
861 fn from_args(command_name: &[&str], args: &[&str])
862 -> std::result::Result<Self, argh::EarlyExit>
863 {
864 let subcommand_name = *command_name.last().expect("no subcommand name");
865 #(
866 if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
867 return Ok(#name_repeating::#variant_names(
868 <#variant_ty as argh::FromArgs>::from_args(command_name, args)?
869 ));
870 }
871 )*
872 unreachable!("no subcommand matched")
873 }
874
875 fn redact_arg_values(command_name: &[&str], args: &[&str]) -> std::result::Result<Vec<String>, argh::EarlyExit> {
876 let subcommand_name = *command_name.last().expect("no subcommand name");
877 #(
878 if subcommand_name == <#variant_ty as argh::SubCommand>::COMMAND.name {
879 return <#variant_ty as argh::FromArgs>::redact_arg_values(command_name, args);
880 }
881 )*
882 unreachable!("no subcommand matched")
883 }
884 }
885
886 impl argh::SubCommands for #name {
887 const COMMANDS: &'static [&'static argh::CommandInfo] = &[#(
888 <#variant_ty as argh::SubCommand>::COMMAND,
889 )*];
890 }
891 }
892}
893
894/// Returns `Some(Bar)` if the field is a single-field unnamed variant like `Foo(Bar)`.
895/// Otherwise, generates an error.
896fn enum_only_single_field_unnamed_variants<'a>(
897 errors: &Errors,
898 variant_fields: &'a syn::Fields,
899) -> Option<&'a syn::Type> {
900 macro_rules! with_enum_suggestion {
901 ($help_text:literal) => {
902 concat!(
903 $help_text,
904 "\nInstead, use a variant with a single unnamed field for each subcommand:\n",
905 " enum MyCommandEnum {\n",
906 " SubCommandOne(SubCommandOne),\n",
907 " SubCommandTwo(SubCommandTwo),\n",
908 " }",
909 )
910 };
911 }
912
913 match variant_fields {
914 syn::Fields::Named(fields) => {
915 errors.err(
916 fields,
917 with_enum_suggestion!(
918 "`#![derive(FromArgs)]` `enum`s do not support variants with named fields."
919 ),
920 );
921 None
922 }
923 syn::Fields::Unit => {
924 errors.err(
925 variant_fields,
926 with_enum_suggestion!(
927 "`#![derive(FromArgs)]` does not support `enum`s with no variants."
928 ),
929 );
930 None
931 }
932 syn::Fields::Unnamed(fields) => {
933 if fields.unnamed.len() != 1 {
934 errors.err(
935 fields,
936 with_enum_suggestion!(
937 "`#![derive(FromArgs)]` `enum` variants must only contain one field."
938 ),
939 );
940 None
941 } else {
942 // `unwrap` is okay because of the length check above.
943 let first_field = fields.unnamed.first().unwrap();
944 Some(&first_field.ty)
945 }
946 }
947 }
948}