blob: e767709b711ad0a8f87c83ba52a9f00d98dc3b93 [file] [log] [blame]
Chih-Hung Hsieh92ff6052020-06-10 20:18:39 -07001//! Oneof-related codegen functions.
2
3use code_writer::CodeWriter;
4use field::FieldElem;
5use field::FieldGen;
6use message::MessageGen;
7use protobuf::descriptor::FieldDescriptorProto_Type;
8use protobuf_name::ProtobufAbsolutePath;
9use rust_name::RustIdent;
10use rust_types_values::RustType;
11use scope::FieldWithContext;
12use scope::OneofVariantWithContext;
13use scope::OneofWithContext;
14use scope::RootScope;
15use scope::WithScope;
16use serde;
17use std::collections::HashSet;
18use Customize;
19
20// oneof one { ... }
21#[derive(Clone)]
22pub(crate) struct OneofField<'a> {
23 pub elem: FieldElem<'a>,
24 pub oneof_rust_field_name: RustIdent,
25 pub oneof_type_name: RustType,
26 pub boxed: bool,
27}
28
29impl<'a> OneofField<'a> {
30 // Detecting recursion: if oneof fields contains a self-reference
31 // or another message which has a reference to self,
32 // put oneof variant into a box.
33 fn need_boxed(field: &FieldWithContext, root_scope: &RootScope, owner_name: &str) -> bool {
34 let mut visited_messages = HashSet::new();
35 let mut fields = vec![field.clone()];
36 while let Some(field) = fields.pop() {
37 if field.field.get_field_type() == FieldDescriptorProto_Type::TYPE_MESSAGE {
38 let message_name = ProtobufAbsolutePath::from(field.field.get_type_name());
39 if !visited_messages.insert(message_name.clone()) {
40 continue;
41 }
42 if message_name.path == owner_name {
43 return true;
44 }
45 let message = root_scope.find_message(&message_name);
46 fields.extend(message.fields().into_iter().filter(|f| f.is_oneof()));
47 }
48 }
49 false
50 }
51
52 pub fn parse(
53 oneof: &OneofWithContext<'a>,
54 field: &FieldWithContext<'a>,
55 elem: FieldElem<'a>,
56 root_scope: &RootScope,
57 ) -> OneofField<'a> {
58 let boxed = OneofField::need_boxed(field, root_scope, &oneof.message.name_absolute().path);
59
60 OneofField {
61 elem,
62 boxed,
63 oneof_rust_field_name: oneof.field_name().into(),
64 oneof_type_name: RustType::Oneof(oneof.rust_name().to_string()),
65 }
66 }
67
68 pub fn rust_type(&self) -> RustType {
69 let t = self.elem.rust_storage_type();
70
71 if self.boxed {
72 RustType::Uniq(Box::new(t))
73 } else {
74 t
75 }
76 }
77}
78
79#[derive(Clone)]
80pub(crate) struct OneofVariantGen<'a> {
81 oneof: &'a OneofGen<'a>,
82 variant: OneofVariantWithContext<'a>,
83 oneof_field: OneofField<'a>,
84 pub field: FieldGen<'a>,
85 path: String,
86 customize: Customize,
87}
88
89impl<'a> OneofVariantGen<'a> {
90 fn parse(
91 oneof: &'a OneofGen<'a>,
92 variant: OneofVariantWithContext<'a>,
93 field: &'a FieldGen,
94 _root_scope: &RootScope,
95 customize: Customize,
96 ) -> OneofVariantGen<'a> {
97 OneofVariantGen {
98 oneof,
99 variant: variant.clone(),
100 field: field.clone(),
101 path: format!(
102 "{}::{}",
103 oneof.type_name.to_code(&field.customize),
104 field.rust_name
105 ),
106 oneof_field: OneofField::parse(
107 variant.oneof,
108 &field.proto_field,
109 field.oneof().elem.clone(),
110 oneof.message.root_scope,
111 ),
112 customize,
113 }
114 }
115
116 fn rust_type(&self) -> RustType {
117 self.oneof_field.rust_type()
118 }
119
120 pub fn path(&self) -> String {
121 self.path.clone()
122 }
123}
124
125#[derive(Clone)]
126pub(crate) struct OneofGen<'a> {
127 // Message containing this oneof
128 message: &'a MessageGen<'a>,
129 pub oneof: OneofWithContext<'a>,
130 type_name: RustType,
131 lite_runtime: bool,
132 customize: Customize,
133}
134
135impl<'a> OneofGen<'a> {
136 pub fn parse(
137 message: &'a MessageGen,
138 oneof: OneofWithContext<'a>,
139 customize: &Customize,
140 ) -> OneofGen<'a> {
141 let rust_name = oneof.rust_name();
142 OneofGen {
143 message,
144 oneof,
145 type_name: RustType::Oneof(rust_name.to_string()),
146 lite_runtime: message.lite_runtime,
147 customize: customize.clone(),
148 }
149 }
150
151 pub fn variants_except_group(&'a self) -> Vec<OneofVariantGen<'a>> {
152 self.oneof
153 .variants()
154 .into_iter()
155 .filter_map(|v| {
156 let field = self
157 .message
158 .fields
159 .iter()
160 .filter(|f| f.proto_field.name() == v.field.get_name())
161 .next()
162 .expect(&format!("field not found by name: {}", v.field.get_name()));
163 match field.proto_type {
164 FieldDescriptorProto_Type::TYPE_GROUP => None,
165 _ => Some(OneofVariantGen::parse(
166 self,
167 v,
168 field,
169 self.message.root_scope,
170 self.customize.clone(),
171 )),
172 }
173 })
174 .collect()
175 }
176
177 pub fn full_storage_type(&self) -> RustType {
178 RustType::Option(Box::new(self.type_name.clone()))
179 }
180
181 pub fn write_enum(&self, w: &mut CodeWriter) {
182 let derive = vec!["Clone", "PartialEq", "Debug"];
183 w.derive(&derive);
Elliott Hughesa013f1f2021-04-01 17:02:23 -0700184 serde::write_serde_attr(
185 w,
186 &self.customize,
187 "derive(::serde::Serialize, ::serde::Deserialize)",
188 );
Chih-Hung Hsieh92ff6052020-06-10 20:18:39 -0700189 w.pub_enum(&self.type_name.to_code(&self.customize), |w| {
190 for variant in self.variants_except_group() {
191 w.write_line(&format!(
192 "{}({}),",
193 variant.field.rust_name,
194 &variant.rust_type().to_code(&self.customize)
195 ));
196 }
197 });
198 }
199}