Parse unions in a DeriveInput
diff --git a/src/derive.rs b/src/derive.rs
index 0fffff1..3bf125c 100644
--- a/src/derive.rs
+++ b/src/derive.rs
@@ -81,19 +81,31 @@
use synom::Synom;
+ enum DeriveInputKind {
+ Struct(Token![struct]),
+ Enum(Token![enum]),
+ Union(Token![union]),
+ }
+
+ impl Synom for DeriveInputKind {
+ named!(parse -> Self, alt!(
+ keyword!(struct) => { DeriveInputKind::Struct }
+ |
+ keyword!(enum) => { DeriveInputKind::Enum }
+ |
+ keyword!(union) => { DeriveInputKind::Union }
+ ));
+ }
+
impl Synom for DeriveInput {
named!(parse -> Self, do_parse!(
attrs: many0!(Attribute::parse_outer) >>
vis: syn!(Visibility) >>
- which: alt!(
- keyword!(struct) => { Ok }
- |
- keyword!(enum) => { Err }
- ) >>
+ which: syn!(DeriveInputKind) >>
id: syn!(Ident) >>
generics: syn!(Generics) >>
item: switch!(value!(which),
- Ok(s) => map!(data_struct, move |(wh, fields, semi)| DeriveInput {
+ DeriveInputKind::Struct(s) => map!(data_struct, move |(wh, fields, semi)| DeriveInput {
ident: id,
vis: vis,
attrs: attrs,
@@ -108,7 +120,7 @@
}),
})
|
- Err(e) => map!(data_enum, move |(wh, brace, variants)| DeriveInput {
+ DeriveInputKind::Enum(e) => map!(data_enum, move |(wh, brace, variants)| DeriveInput {
ident: id,
vis: vis,
attrs: attrs,
@@ -122,6 +134,20 @@
enum_token: e,
}),
})
+ |
+ DeriveInputKind::Union(u) => map!(data_union, move |(wh, fields)| DeriveInput {
+ ident: id,
+ vis: vis,
+ attrs: attrs,
+ generics: Generics {
+ where_clause: wh,
+ ..generics
+ },
+ data: Data::Union(DataUnion {
+ union_token: u,
+ fields: fields,
+ }),
+ })
) >>
(item)
));
@@ -157,6 +183,11 @@
data: braces!(Punctuated::parse_terminated) >>
(wh, data.0, data.1)
));
+
+ named!(data_union -> (Option<WhereClause>, FieldsNamed), tuple!(
+ option!(syn!(WhereClause)),
+ syn!(FieldsNamed)
+ ));
}
#[cfg(feature = "printing")]
diff --git a/tests/test_derive_input.rs b/tests/test_derive_input.rs
index 7e08045..acc6ce0 100644
--- a/tests/test_derive_input.rs
+++ b/tests/test_derive_input.rs
@@ -14,6 +14,7 @@
use proc_macro2::Delimiter::{Brace, Parenthesis};
use proc_macro2::*;
use syn::*;
+use syn::punctuated::Punctuated;
use std::iter::FromIterator;
@@ -159,6 +160,69 @@
}
#[test]
+fn test_union() {
+ let raw = "
+ union MaybeUninit<T> {
+ uninit: (),
+ value: T
+ }
+ ";
+
+ let expected = DeriveInput {
+ ident: ident("MaybeUninit"),
+ vis: Visibility::Inherited,
+ attrs: Vec::new(),
+ generics: Generics {
+ lt_token: Some(Default::default()),
+ params: punctuated![
+ GenericParam::Type(TypeParam {
+ attrs: Vec::new(),
+ ident: ident("T"),
+ bounds: Default::default(),
+ default: None,
+ colon_token: None,
+ eq_token: None,
+ }),
+ ],
+ gt_token: Some(Default::default()),
+ where_clause: None,
+ },
+ data: Data::Union(DataUnion {
+ union_token: Default::default(),
+ fields: FieldsNamed {
+ brace_token: Default::default(),
+ named: punctuated![
+ Field {
+ ident: Some(ident("uninit")),
+ colon_token: Some(Default::default()),
+ vis: Visibility::Inherited,
+ attrs: Vec::new(),
+ ty: TypeTuple {
+ paren_token: Default::default(),
+ elems: Punctuated::new(),
+ }.into(),
+ },
+ Field {
+ ident: Some(ident("value")),
+ colon_token: Some(Default::default()),
+ vis: Visibility::Inherited,
+ attrs: Vec::new(),
+ ty: TypePath {
+ qself: None,
+ path: ident("T").into(),
+ }.into(),
+ },
+ ],
+ },
+ }),
+ };
+
+ let actual = syn::parse_str(raw).unwrap();
+
+ assert_eq!(expected, actual);
+}
+
+#[test]
#[cfg(feature = "full")]
fn test_enum() {
let raw = r#"