Chih-Hung Hsieh | 92ff605 | 2020-06-10 20:18:39 -0700 | [diff] [blame] | 1 | //! Oneof-related codegen functions. |
| 2 | |
| 3 | use code_writer::CodeWriter; |
| 4 | use field::FieldElem; |
| 5 | use field::FieldGen; |
| 6 | use message::MessageGen; |
| 7 | use protobuf::descriptor::FieldDescriptorProto_Type; |
| 8 | use protobuf_name::ProtobufAbsolutePath; |
| 9 | use rust_name::RustIdent; |
| 10 | use rust_types_values::RustType; |
| 11 | use scope::FieldWithContext; |
| 12 | use scope::OneofVariantWithContext; |
| 13 | use scope::OneofWithContext; |
| 14 | use scope::RootScope; |
| 15 | use scope::WithScope; |
| 16 | use serde; |
| 17 | use std::collections::HashSet; |
| 18 | use Customize; |
| 19 | |
| 20 | // oneof one { ... } |
| 21 | #[derive(Clone)] |
| 22 | pub(crate) struct OneofField<'a> { |
| 23 | pub elem: FieldElem<'a>, |
| 24 | pub oneof_rust_field_name: RustIdent, |
| 25 | pub oneof_type_name: RustType, |
| 26 | pub boxed: bool, |
| 27 | } |
| 28 | |
| 29 | impl<'a> OneofField<'a> { |
| 30 | // Detecting recursion: if oneof fields contains a self-reference |
| 31 | // or another message which has a reference to self, |
| 32 | // put oneof variant into a box. |
| 33 | fn need_boxed(field: &FieldWithContext, root_scope: &RootScope, owner_name: &str) -> bool { |
| 34 | let mut visited_messages = HashSet::new(); |
| 35 | let mut fields = vec![field.clone()]; |
| 36 | while let Some(field) = fields.pop() { |
| 37 | if field.field.get_field_type() == FieldDescriptorProto_Type::TYPE_MESSAGE { |
| 38 | let message_name = ProtobufAbsolutePath::from(field.field.get_type_name()); |
| 39 | if !visited_messages.insert(message_name.clone()) { |
| 40 | continue; |
| 41 | } |
| 42 | if message_name.path == owner_name { |
| 43 | return true; |
| 44 | } |
| 45 | let message = root_scope.find_message(&message_name); |
| 46 | fields.extend(message.fields().into_iter().filter(|f| f.is_oneof())); |
| 47 | } |
| 48 | } |
| 49 | false |
| 50 | } |
| 51 | |
| 52 | pub fn parse( |
| 53 | oneof: &OneofWithContext<'a>, |
| 54 | field: &FieldWithContext<'a>, |
| 55 | elem: FieldElem<'a>, |
| 56 | root_scope: &RootScope, |
| 57 | ) -> OneofField<'a> { |
| 58 | let boxed = OneofField::need_boxed(field, root_scope, &oneof.message.name_absolute().path); |
| 59 | |
| 60 | OneofField { |
| 61 | elem, |
| 62 | boxed, |
| 63 | oneof_rust_field_name: oneof.field_name().into(), |
| 64 | oneof_type_name: RustType::Oneof(oneof.rust_name().to_string()), |
| 65 | } |
| 66 | } |
| 67 | |
| 68 | pub fn rust_type(&self) -> RustType { |
| 69 | let t = self.elem.rust_storage_type(); |
| 70 | |
| 71 | if self.boxed { |
| 72 | RustType::Uniq(Box::new(t)) |
| 73 | } else { |
| 74 | t |
| 75 | } |
| 76 | } |
| 77 | } |
| 78 | |
| 79 | #[derive(Clone)] |
| 80 | pub(crate) struct OneofVariantGen<'a> { |
| 81 | oneof: &'a OneofGen<'a>, |
| 82 | variant: OneofVariantWithContext<'a>, |
| 83 | oneof_field: OneofField<'a>, |
| 84 | pub field: FieldGen<'a>, |
| 85 | path: String, |
| 86 | customize: Customize, |
| 87 | } |
| 88 | |
| 89 | impl<'a> OneofVariantGen<'a> { |
| 90 | fn parse( |
| 91 | oneof: &'a OneofGen<'a>, |
| 92 | variant: OneofVariantWithContext<'a>, |
| 93 | field: &'a FieldGen, |
| 94 | _root_scope: &RootScope, |
| 95 | customize: Customize, |
| 96 | ) -> OneofVariantGen<'a> { |
| 97 | OneofVariantGen { |
| 98 | oneof, |
| 99 | variant: variant.clone(), |
| 100 | field: field.clone(), |
| 101 | path: format!( |
| 102 | "{}::{}", |
| 103 | oneof.type_name.to_code(&field.customize), |
| 104 | field.rust_name |
| 105 | ), |
| 106 | oneof_field: OneofField::parse( |
| 107 | variant.oneof, |
| 108 | &field.proto_field, |
| 109 | field.oneof().elem.clone(), |
| 110 | oneof.message.root_scope, |
| 111 | ), |
| 112 | customize, |
| 113 | } |
| 114 | } |
| 115 | |
| 116 | fn rust_type(&self) -> RustType { |
| 117 | self.oneof_field.rust_type() |
| 118 | } |
| 119 | |
| 120 | pub fn path(&self) -> String { |
| 121 | self.path.clone() |
| 122 | } |
| 123 | } |
| 124 | |
| 125 | #[derive(Clone)] |
| 126 | pub(crate) struct OneofGen<'a> { |
| 127 | // Message containing this oneof |
| 128 | message: &'a MessageGen<'a>, |
| 129 | pub oneof: OneofWithContext<'a>, |
| 130 | type_name: RustType, |
| 131 | lite_runtime: bool, |
| 132 | customize: Customize, |
| 133 | } |
| 134 | |
| 135 | impl<'a> OneofGen<'a> { |
| 136 | pub fn parse( |
| 137 | message: &'a MessageGen, |
| 138 | oneof: OneofWithContext<'a>, |
| 139 | customize: &Customize, |
| 140 | ) -> OneofGen<'a> { |
| 141 | let rust_name = oneof.rust_name(); |
| 142 | OneofGen { |
| 143 | message, |
| 144 | oneof, |
| 145 | type_name: RustType::Oneof(rust_name.to_string()), |
| 146 | lite_runtime: message.lite_runtime, |
| 147 | customize: customize.clone(), |
| 148 | } |
| 149 | } |
| 150 | |
| 151 | pub fn variants_except_group(&'a self) -> Vec<OneofVariantGen<'a>> { |
| 152 | self.oneof |
| 153 | .variants() |
| 154 | .into_iter() |
| 155 | .filter_map(|v| { |
| 156 | let field = self |
| 157 | .message |
| 158 | .fields |
| 159 | .iter() |
| 160 | .filter(|f| f.proto_field.name() == v.field.get_name()) |
| 161 | .next() |
| 162 | .expect(&format!("field not found by name: {}", v.field.get_name())); |
| 163 | match field.proto_type { |
| 164 | FieldDescriptorProto_Type::TYPE_GROUP => None, |
| 165 | _ => Some(OneofVariantGen::parse( |
| 166 | self, |
| 167 | v, |
| 168 | field, |
| 169 | self.message.root_scope, |
| 170 | self.customize.clone(), |
| 171 | )), |
| 172 | } |
| 173 | }) |
| 174 | .collect() |
| 175 | } |
| 176 | |
| 177 | pub fn full_storage_type(&self) -> RustType { |
| 178 | RustType::Option(Box::new(self.type_name.clone())) |
| 179 | } |
| 180 | |
| 181 | pub fn write_enum(&self, w: &mut CodeWriter) { |
| 182 | let derive = vec!["Clone", "PartialEq", "Debug"]; |
| 183 | w.derive(&derive); |
Elliott Hughes | a013f1f | 2021-04-01 17:02:23 -0700 | [diff] [blame] | 184 | serde::write_serde_attr( |
| 185 | w, |
| 186 | &self.customize, |
| 187 | "derive(::serde::Serialize, ::serde::Deserialize)", |
| 188 | ); |
Chih-Hung Hsieh | 92ff605 | 2020-06-10 20:18:39 -0700 | [diff] [blame] | 189 | w.pub_enum(&self.type_name.to_code(&self.customize), |w| { |
| 190 | for variant in self.variants_except_group() { |
| 191 | w.write_line(&format!( |
| 192 | "{}({}),", |
| 193 | variant.field.rust_name, |
| 194 | &variant.rust_type().to_code(&self.customize) |
| 195 | )); |
| 196 | } |
| 197 | }); |
| 198 | } |
| 199 | } |