blob: 75202c4e30d33ad2f1473dd6a410a2c3d073b877 [file] [log] [blame]
temporal40ee5512008-07-10 02:12:20 +00001# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.
3# http://code.google.com/p/protobuf/
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# This code is meant to work on Python 2.4 and above only.
18#
19# TODO(robinson): Helpers for verbose, common checks like seeing if a
20# descriptor's cpp_type is CPPTYPE_MESSAGE.
21
22"""Contains a metaclass and helper functions used to create
23protocol message classes from Descriptor objects at runtime.
24
25Recall that a metaclass is the "type" of a class.
26(A class is to a metaclass what an instance is to a class.)
27
28In this case, we use the GeneratedProtocolMessageType metaclass
29to inject all the useful functionality into the classes
30output by the protocol compiler at compile-time.
31
32The upshot of all this is that the real implementation
33details for ALL pure-Python protocol buffers are *here in
34this file*.
35"""
36
37__author__ = 'robinson@google.com (Will Robinson)'
38
39import heapq
40import threading
41import weakref
42# We use "as" to avoid name collisions with variables.
43from google.protobuf.internal import decoder
44from google.protobuf.internal import encoder
45from google.protobuf.internal import message_listener as message_listener_mod
46from google.protobuf.internal import wire_format
47from google.protobuf import descriptor as descriptor_mod
48from google.protobuf import message as message_mod
49
# Shorthand alias for descriptor_mod.FieldDescriptor.  The helpers below read
# its LABEL_*, CPPTYPE_*, TYPE_* and MAX_CPPTYPE constants through this name.
_FieldDescriptor = descriptor_mod.FieldDescriptor
51
52
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  Every class built by this metaclass receives implementations of the methods
  declared on the Message base class, one property per field of the protocol
  message, and a __slots__ declaration.  The slots keep callers from
  accidentally assigning to attributes that do not correspond to any field
  (and therefore would never be serialized or parsed).

  The protocol compiler emits classes that use this metaclass; clients can
  also build message classes by hand at runtime:

    mydescriptor = Descriptor(.....)
    class MyProtoClass(Message):
      __metaclass__ = GeneratedProtocolMessageType
      DESCRIPTOR = mydescriptor
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...
  """

  # Name of the class-dictionary entry that holds the message Descriptor.
  # Must stay consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the new message class.

    __slots__ only takes effect if it is present in the class dictionary
    before type.__new__ runs, so the slot list (and the class attributes for
    nested extensions) are injected here rather than in __init__.

    Args:
      name: Name of the class being created; passed through to type.
      bases: Base classes of the new class (normally message.Message);
        passed through to type.
      dictionary: The class dictionary.  dictionary[_DESCRIPTOR_KEY] must
        hold the Descriptor describing this protocol message type.

    Returns:
      The newly-allocated class.
    """
    key = GeneratedProtocolMessageType._DESCRIPTOR_KEY
    message_descriptor = dictionary[key]
    _AddSlots(message_descriptor, dictionary)
    _AddClassAttributesForNestedExtensions(message_descriptor, dictionary)
    parent = super(GeneratedProtocolMessageType, cls)
    return parent.__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Fills in the freshly-allocated message class.

    Installs enum value constants, an __init__ method, implementations of
    the Message methods, and a property for every field of the message type.

    Args:
      name: Name of the class being created; passed through to type.
      bases: Base classes of the new class (normally message.Message);
        passed through to type.
      dictionary: The class dictionary.  dictionary[_DESCRIPTOR_KEY] must
        hold the Descriptor describing this protocol message type.
    """
    key = GeneratedProtocolMessageType._DESCRIPTOR_KEY
    message_descriptor = dictionary[key]
    # We act as a "friend" of the descriptor: the first class built from a
    # given descriptor is recorded as its _concrete_class.  Classes built
    # later from the same descriptor leave that attribute untouched.
    if not hasattr(message_descriptor, '_concrete_class'):
      message_descriptor._concrete_class = cls
    cls._known_extensions = []
    _AddEnumValues(message_descriptor, cls)
    _AddInitMethod(message_descriptor, cls)
    _AddPropertiesForFields(message_descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(message_descriptor, cls)
    _AddPrivateHelperMethods(cls)
    parent = super(GeneratedProtocolMessageType, cls)
    parent.__init__(cls, name, bases, dictionary)
141
142
143# Stateless helpers for GeneratedProtocolMessageType below.
144# Outside clients should not access these directly.
145#
146# I opted not to make any of these methods on the metaclass, to make it more
147# clear that I'm not really using any state there and to keep clients from
148# thinking that they have direct access to these construction helpers.
149
150
151def _PropertyName(proto_field_name):
152 """Returns the name of the public property attribute which
153 clients can use to get and (in some cases) set the value
154 of a protocol message field.
155
156 Args:
157 proto_field_name: The protocol message field name, exactly
158 as it appears (or would appear) in a .proto file.
159 """
160 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
161 # nnorwitz makes my day by writing:
162 # """
163 # FYI. See the keyword module in the stdlib. This could be as simple as:
164 #
165 # if keyword.iskeyword(proto_field_name):
166 # return proto_field_name + "_"
167 # return proto_field_name
168 # """
169 return proto_field_name
170
171
172def _ValueFieldName(proto_field_name):
173 """Returns the name of the (internal) instance attribute which objects
174 should use to store the current value for a given protocol message field.
175
176 Args:
177 proto_field_name: The protocol message field name, exactly
178 as it appears (or would appear) in a .proto file.
179 """
180 return '_value_' + proto_field_name
181
182
183def _HasFieldName(proto_field_name):
184 """Returns the name of the (internal) instance attribute which
185 objects should use to store a boolean telling whether this field
186 is explicitly set or not.
187
188 Args:
189 proto_field_name: The protocol message field name, exactly
190 as it appears (or would appear) in a .proto file.
191 """
192 return '_has_' + proto_field_name
193
194
def _AddSlots(message_descriptor, dictionary):
  """Stores a __slots__ list in the class dictionary for a message type.

  The slot list contains, in order:
    * one value attribute per field,
    * one "has" attribute per non-repeated field,
    * the bookkeeping attributes used by the generated methods.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary that receives the '__slots__' entry.
  """
  fields = message_descriptor.fields
  slots = []
  for field in fields:
    slots.append(_ValueFieldName(field.name))
  for field in fields:
    if field.label != _FieldDescriptor.LABEL_REPEATED:
      slots.append(_HasFieldName(field.name))
  slots += ['Extensions',
            '_cached_byte_size',
            '_cached_byte_size_dirty',
            '_called_transition_to_nonempty',
            '_listener',
            '_lock',
            '__weakref__']
  dictionary['__slots__'] = slots
213
214
215def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
216 extension_dict = descriptor.extensions_by_name
217 for extension_name, extension_field in extension_dict.iteritems():
218 assert extension_name not in dictionary
219 dictionary[extension_name] = extension_field
220
221
222def _AddEnumValues(descriptor, cls):
223 """Sets class-level attributes for all enum fields defined in this message.
224
225 Args:
226 descriptor: Descriptor object for this message type.
227 cls: Class we're constructing for this message type.
228 """
229 for enum_type in descriptor.enum_types:
230 for enum_value in enum_type.values:
231 setattr(cls, enum_value.name, enum_value.number)
232
233
def _DefaultValueForField(message, field):
  """Returns the initial value to store for one field of |message|.

  Repeated fields get a fresh, empty container whose _Listener points back
  at |message|.  Every other field gets the descriptor's default_value,
  which must be None for non-repeated message fields.

  Args:
    message: Message instance containing this field, or a weakref proxy
      of same.
    field: FieldDescriptor object for this field.

  Returns:
    A default value for this field.  For repeated fields it may refer back
    to |message| through a _Listener.

  Raises:
    ValueError: if a repeated field declares a non-empty default value.
  """
  # TODO(robinson): Only the repeated fields need a reference to 'message' (so
  # that they can set the 'has' bit on the containing Message when someone
  # append()s a value). We could special-case this, and avoid an extra
  # function call on __init__() and Clear() for non-repeated fields.

  # TODO(robinson): Find a better place for the default value assertion in this
  # function. No need to repeat them every time the client calls Clear('foo').
  # (We should probably just assert these things once and as early as possible,
  # by tightening checking in the descriptor classes.)
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      assert field.default_value is None
    return field.default_value

  if field.default_value != []:
    raise ValueError('Repeated field default value not empty list: %s' % (
        field.default_value))
  listener = _Listener(message, None)
  if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
    return _RepeatedScalarFieldContainer(listener,
                                         _VALUE_CHECKERS[field.cpp_type])
  # Pass the message *type*, not its _concrete_class: that attribute may not
  # be set yet, depending on the order in which classes are initialized.
  return _RepeatedCompositeFieldContainer(listener, field.message_type)
271
272
def _AddInitMethod(message_descriptor, cls):
  """Installs an __init__ on |cls| that puts every field in its default state.

  Args:
    message_descriptor: Descriptor for the message type.
    cls: The class being constructed.
  """
  fields = message_descriptor.fields

  def init(self):
    # The byte-size cache starts out clean, holding 0.
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = False
    self._listener = message_listener_mod.NullMessageListener()
    self._called_transition_to_nonempty = False
    # TODO(robinson): We should only create a lock if we really need one
    # in this class.
    self._lock = threading.Lock()
    for field in fields:
      initial = _DefaultValueForField(self, field)
      setattr(self, _ValueFieldName(field.name), initial)
      # Only non-repeated fields carry a "has" flag.
      if field.label != _FieldDescriptor.LABEL_REPEATED:
        setattr(self, _HasFieldName(field.name), False)
    self.Extensions = _ExtensionDict(self, cls._known_extensions)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
295
296
def _AddPropertiesForFields(descriptor, cls):
  """Adds one public property to |cls| for each field of the message type."""
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)
301
302
def _AddPropertiesForField(field, cls):
  """Adds the public property for a single protocol message field.

  Clients use the property to read the field and, for non-repeated scalar
  fields only, to assign it directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Fail loudly if a new C++ type is introduced that this dispatch may need
  # to handle specially.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
323
324
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a read-only public property for a "repeated" message field.

  Reading the property returns the field's container, a
  _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer.  Values
  added through that container are type-checked (for scalars) and set any
  needed "has" bits as a side effect.  Assigning to the property raises
  AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  value_attr = _ValueFieldName(proto_field_name)

  def getter(self):
    return getattr(self, value_attr)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to reject assignment with an error message that
  # names the offending field.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, _PropertyName(proto_field_name),
          property(fget=getter, fset=setter, doc=doc))
356
357
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a read/write public property for a non-repeated scalar field.

  Assigning through the property type-checks the new value, sets the
  field's "has" bit, marks the cached byte size dirty and notifies the
  transition-to-nonempty callback before storing the value.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  value_attr = _ValueFieldName(proto_field_name)
  has_attr = _HasFieldName(proto_field_name)
  property_name = _PropertyName(proto_field_name)
  checker = _VALUE_CHECKERS[field.cpp_type]

  def getter(self):
    return getattr(self, value_attr)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # CheckValue runs before any state on |self| is modified.
    checker.CheckValue(new_value)
    setattr(self, has_attr, True)
    self._MarkByteSizeDirty()
    self._MaybeCallTransitionToNonemptyCallback()
    setattr(self, value_attr, new_value)
  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(fget=getter, fset=setter, doc=doc))
391
392
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a read-only public property for a non-repeated composite field.

  A composite field is a "group" or "message" field.  The sub-message is
  created lazily on first read.  Assigning to the property raises
  AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  value_attr = _ValueFieldName(proto_field_name)
  has_attr = _HasFieldName(proto_field_name)
  property_name = _PropertyName(proto_field_name)
  message_type = field.message_type

  def getter(self):
    # TODO(robinson): Appropriately scary note about double-checked locking.
    # Fast path: the sub-message already exists, so no lock is taken.
    child = getattr(self, value_attr)
    if child is not None:
      return child
    self._lock.acquire()
    try:
      # Re-read under the lock: another thread may have created the
      # sub-message after the unlocked read above.
      child = getattr(self, value_attr)
      if child is None:
        child = message_type._concrete_class()
        child._SetListener(_Listener(self, has_attr))
        setattr(self, value_attr, child)
    finally:
      self._lock.release()
    return child
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to reject assignment with an error message that
  # names the offending field.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(fget=getter, fset=setter, doc=doc))
439
440
441def _AddStaticMethods(cls):
442 # TODO(robinson): This probably needs to be thread-safe(?)
443 def RegisterExtension(extension_handle):
444 extension_handle.containing_type = cls.DESCRIPTOR
445 cls._known_extensions.append(extension_handle)
446 cls.RegisterExtension = staticmethod(RegisterExtension)
447
448
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields() on |cls|.

  ListFields() returns a list of (field_descriptor, value) pairs covering:
    * every non-empty repeated field,
    * every non-repeated field whose "has" bit is set,
    * every set extension.
  The pairs are in ascending field-number order.
  """
  # Non-extension fields are sorted once, here, when the class is built.
  # Extensions are sorted on each ListFields() call, though we could do
  # better if we have to.
  fields = sorted(message_descriptor.fields, key=lambda f: f.number)
  # Precompute the per-field attribute names once.  This is built as a real
  # list rather than via zip(): ListFields() iterates it on every call, and a
  # lazy iterator would be exhausted after the first call.
  field_infos = [(_HasFieldName(f.name), _ValueFieldName(f.name), f)
                 for f in fields]

  def ListFields(self):
    # Yields (field_number, field_descriptor, value) for every set
    # non-extension field.  field_infos is already sorted by field number.
    def SortedSetFieldsIter():
      for has_field_name, value_field_name, field_descriptor in field_infos:
        if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
          value = getattr(self, value_field_name)
          if len(value) > 0:
            yield (field_descriptor.number, field_descriptor, value)
        elif getattr(self, has_field_name):
          yield (field_descriptor.number, field_descriptor,
                 getattr(self, value_field_name))

    # The same (field_number, field_descriptor, value) triples for every
    # set extension, sorted by field number.
    # TODO(robinson): It's not necessary to repeat this with each
    # serialization call. We can do better.
    sorted_extension_fields = sorted(
        [(f.number, f, v) for f, v in self.Extensions._ListSetExtensions()])

    # Merge both sorted streams, then strip off the leading field numbers.
    all_set_fields = _ImergeSorted(SortedSetFieldsIter(),
                                   sorted_extension_fields)
    return [field[1:] for field in all_set_fields]

  cls.ListFields = ListFields
497
def _AddHasFieldMethod(cls):
  """Helper for _AddMessageMethods(): installs HasField() on |cls|."""
  def HasField(self, field_name):
    # Only non-repeated fields have a "has" slot.  A repeated or unknown
    # field name therefore surfaces as an AttributeError, which is reported
    # as ValueError.
    has_attr = _HasFieldName(field_name)
    try:
      return getattr(self, has_attr)
    except AttributeError:
      raise ValueError('Protocol message has no "%s" field.' % field_name)
  cls.HasField = HasField
506
507
def _AddClearFieldMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearField() on |cls|."""
  def ClearField(self, field_name):
    try:
      field = self.DESCRIPTOR.fields_by_name[field_name]
    except KeyError:
      raise ValueError('Protocol message has no "%s" field.' % field_name)
    value_attr = _ValueFieldName(field.name)
    has_attr = _HasFieldName(field.name)
    fresh_value = _DefaultValueForField(self, field)

    if field.label == _FieldDescriptor.LABEL_REPEATED:
      self._MarkByteSizeDirty()
    else:
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        previous = getattr(self, value_attr)
        if previous is not None:
          # Snip the old sub-message out of the object tree by resetting
          # its listener.
          previous._SetListener(None)
      if getattr(self, has_attr):
        setattr(self, has_attr, False)
        # Dirty our cached size (and our parents') only when state actually
        # changes.
        self._MarkByteSizeDirty()
    setattr(self, value_attr, fresh_value)
  cls.ClearField = ClearField
534
535
536def _AddClearExtensionMethod(cls):
537 """Helper for _AddMessageMethods()."""
538 def ClearExtension(self, extension_handle):
539 self.Extensions._ClearExtension(extension_handle)
540 cls.ClearExtension = ClearExtension
541
542
543def _AddClearMethod(cls):
544 """Helper for _AddMessageMethods()."""
545 def Clear(self):
546 # Clear fields.
547 fields = self.DESCRIPTOR.fields
548 for field in fields:
549 self.ClearField(field.name)
550 # Clear extensions.
551 extensions = self.Extensions._ListSetExtensions()
552 for extension in extensions:
553 self.ClearExtension(extension[0])
554 cls.Clear = Clear
555
556
557def _AddHasExtensionMethod(cls):
558 """Helper for _AddMessageMethods()."""
559 def HasExtension(self, extension_handle):
560 return self.Extensions._HasExtension(extension_handle)
561 cls.HasExtension = HasExtension
562
563
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on |cls|."""
  def __eq__(self, other):
    """Returns True iff |other| is a message of the same type as |self| with
    the same set fields, equal field values and equal extensions."""
    if self is other:
      return True

    # Anything that is not a message of this very type is unequal.  Without
    # this guard, comparing against a non-message (e.g. None) raised
    # AttributeError.  Comparing against another message type either raised
    # ValueError from HasField() or compared unrelated fields by name.
    if (not isinstance(other, message_mod.Message) or
        other.DESCRIPTOR != self.DESCRIPTOR):
      return False

    # Compare all fields contained directly in this message.
    for field_descriptor in message_descriptor.fields:
      property_name = _PropertyName(field_descriptor.name)
      # Non-repeated field equality requires matching "has" bits as well as
      # equal values.
      if field_descriptor.label != _FieldDescriptor.LABEL_REPEATED:
        self_has = self.HasField(property_name)
        if self_has != other.HasField(property_name):
          return False
        if not self_has:
          # If the "has" bit for this field is False, we must stop here.
          # Otherwise we will recurse forever on recursively-defined protos.
          continue
      if getattr(self, property_name) != getattr(other, property_name):
        return False

    # Compare the extensions present in both messages.
    return self.Extensions == other.Extensions
  cls.__eq__ = __eq__
591
592
593def _AddSetListenerMethod(cls):
594 """Helper for _AddMessageMethods()."""
595 def SetListener(self, listener):
596 if listener is None:
597 self._listener = message_listener_mod.NullMessageListener()
598 else:
599 self._listener = listener
600 cls._SetListener = SetListener
601
602
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any other
  additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (The field number is stored
      as part of a varint-encoded tag, so it affects the total bytes
      required to serialize the value.)
    field_type: The type of the field.  One of the TYPE_* constants within
      FieldDescriptor.

  Returns:
    The serialized size of |value| in bytes.

  Raises:
    message_mod.EncodeError: if |field_type| has no registered size function.
  """
  try:
    fn = _TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  # Call the size function outside the try, so a KeyError raised while
  # sizing the value is not misreported as an unrecognized field type.
  return fn(field_number, value)
621
622
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ByteSize() on |cls|.

  ByteSize() returns the cached size while the cache is clean.  Otherwise it
  sums the serialized sizes of all regular fields and set extensions, then
  caches the result.
  """

  def BytesForField(message, field, value):
    """Returns the number of bytes required to serialize one field.

    The field may be repeated or not, composite or not.

    Args:
      message: The Message instance containing the field (not used in the
        computation).
      field: A FieldDescriptor describing the field of interest.
      value: The field's current value (the container, for repeated
        fields).

    Returns:
      The number of bytes required to serialize |value|, including space
      for tags and any other necessary information.
    """
    if _MessageSetField(field):
      return wire_format.MessageSetItemByteSize(field.number, value)

    field_number, field_type = field.number, field.type
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      elements = value
    else:
      elements = [value]
    return sum(_BytesForNonRepeatedElement(element, field_number, field_type)
               for element in elements)

  # Precompute the per-field attribute names once.  This must be a real
  # list: ByteSize() walks it on every call, and a lazy iterator (e.g.
  # itertools.izip, or zip() on Python 3) would be exhausted after the first
  # call, silently dropping every regular field from later size
  # computations.
  field_infos = [(_HasFieldName(f.name), _ValueFieldName(f.name), f)
                 for f in message_descriptor.fields]

  def ByteSize(self):
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    size = 0
    # Hardcoded fields first.  Repeated fields are always visited (an empty
    # container contributes 0 bytes); other fields only when set.
    for has_field_name, value_field_name, field in field_infos:
      if (field.label == _FieldDescriptor.LABEL_REPEATED
          or getattr(self, has_field_name)):
        size += BytesForField(self, field, getattr(self, value_field_name))
    # Extensions next.
    for field, value in self.Extensions._ListSetExtensions():
      size += BytesForField(self, field, value)

    self._cached_byte_size = size
    self._cached_byte_size_dirty = False
    return size
  cls.ByteSize = ByteSize
678
679
def _MessageSetField(field_descriptor):
  """Tells whether a field is serialized with the message set wire format.

  That is the case only for non-repeated, message-typed extensions of a
  message whose options enable message_set_wire_format.

  Args:
    field_descriptor: Descriptor of the field.

  Returns:
    A true value if the field uses the message set wire format, a false
    value otherwise.
  """
  is_extension = field_descriptor.is_extension
  if not is_extension:
    return is_extension
  if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    return False
  if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
    return False
  options = field_descriptor.containing_type.GetOptions()
  return options.message_set_wire_format
694
695
def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder):
  """Appends the serialization of a single value to encoder.

  Args:
    value: Value to serialize.
    field_number: Field number of this value.
    field_descriptor: Descriptor of the field to serialize.
    encoder: encoder.Encoder object to which we should serialize this value.

  Raises:
    message_mod.EncodeError: if the field's type has no registered
      serialization method.
  """
  if _MessageSetField(field_descriptor):
    encoder.AppendMessageSetItem(field_number, value)
    return

  try:
    method = _TYPE_TO_SERIALIZE_METHOD[field_descriptor.type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' %
                                  field_descriptor.type)
  # Serialize outside the try, so a KeyError raised while encoding the value
  # is not misreported as an unrecognized field type.
  method(encoder, field_number, value)
715
716
717def _ImergeSorted(*streams):
718 """Merges N sorted iterators into a single sorted iterator.
719 Each element in streams must be an iterable that yields
720 its elements in sorted order, and the elements contained
721 in each stream must all be comparable.
722
723 There may be repeated elements in the component streams or
724 across the streams; the repeated elements will all be repeated
725 in the merged iterator as well.
726
727 I believe that the heapq module at HEAD in the Python
728 sources has a method like this, but for now we roll our own.
729 """
730 iters = [iter(stream) for stream in streams]
731 heap = []
732 for index, it in enumerate(iters):
733 try:
734 heap.append((it.next(), index))
735 except StopIteration:
736 pass
737 heapq.heapify(heap)
738
739 while heap:
740 smallest_value, idx = heap[0]
741 yield smallest_value
742 try:
743 next_element = iters[idx].next()
744 heapq.heapreplace(heap, (next_element, idx))
745 except StopIteration:
746 heapq.heappop(heap)
747
748
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializeToString() on |cls|."""
  Encoder = encoder.Encoder

  def SerializeToString(self):
    out = Encoder()
    # ListFields() yields extension and non-extension fields together in
    # ascending field-number order, which is the order they are written in.
    for field_descriptor, field_value in self.ListFields():
      if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
        elements = field_value
      else:
        elements = (field_value,)
      for element in elements:
        _SerializeValueToEncoder(element, field_descriptor.number,
                                 field_descriptor, out)
    return out.ToString()
  cls.SerializeToString = SerializeToString
770
771
def _WireTypeForFieldType(field_type):
  """Returns the wire type expected for fields of |field_type|.

  Raises:
    message_mod.DecodeError: if |field_type| is not a known field type.
  """
  try:
    wire_type = _FIELD_TYPE_TO_WIRE_TYPE[field_type]
  except KeyError:
    raise message_mod.DecodeError('Unknown field type: %d' % field_type)
  return wire_type
778
779
def _RecursivelyMerge(field_number, field_type, decoder, message):
  """Decodes a group or nested message from |decoder| into |message|.

  Groups are read with the group protocol (which needs the field number to
  find the matching end tag); nested messages are read with the
  nested-message protocol.

  Args:
    field_number: The field number of message in its enclosing protocol
      buffer.
    field_type: The field type of message.  Must be either TYPE_MESSAGE or
      TYPE_GROUP.
    decoder: Decoder to read from.
    message: Message to deserialize into.

  Raises:
    message_mod.DecodeError: if |field_type| is neither TYPE_MESSAGE nor
      TYPE_GROUP.
  """
  if field_type == _FieldDescriptor.TYPE_GROUP:
    decoder.ReadGroupInto(field_number, message)
  elif field_type == _FieldDescriptor.TYPE_MESSAGE:
    decoder.ReadMessageInto(message)
  else:
    raise message_mod.DecodeError('Unexpected field type: %d' % field_type)
800
801
def _DeserializeScalarFromDecoder(field_type, decoder):
  """Reads one scalar value of the requested type from |decoder|.

  Args:
    field_type: A scalar (non-group, non-message) FieldDescriptor.TYPE_*
      constant.
    decoder: Decoder to read from.

  Returns:
    The decoded value.

  Raises:
    message_mod.DecodeError: if |field_type| has no registered reader.
  """
  try:
    method = _TYPE_TO_DESERIALIZE_METHOD[field_type]
  except KeyError:
    raise message_mod.DecodeError('Unrecognized field type: %d' % field_type)
  # Read outside the try, so a KeyError raised by the reader is not
  # misreported as an unrecognized field type.
  return method(decoder)
811
812
def _SkipField(field_number, wire_type, decoder):
  """Skips over the payload of one field with the given wire type.

  Args:
    field_number: Tag number of the field to skip.
    wire_type: Wire type of the field to skip.
    decoder: Decoder used to deserialize the message.  It must be positioned
      just after reading the tag and wire type of the field.

  Raises:
    message_mod.DecodeError: if |wire_type| is not a known wire type.
  """
  if wire_type == wire_format.WIRETYPE_END_GROUP:
    # Nothing to consume for an END_GROUP tag.
    return
  if wire_type == wire_format.WIRETYPE_VARINT:
    # NOTE(review): varints are skipped with ReadInt32(); confirm that this
    # consumes a full-width (up to 64-bit) varint.
    decoder.ReadInt32()
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    decoder.ReadFixed32()
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    decoder.ReadFixed64()
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    length = decoder.ReadInt32()
    decoder.SkipBytes(length)
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    _SkipGroup(field_number, decoder)
  else:
    raise message_mod.DecodeError('Unexpected wire type: %d' % wire_type)
836
837
def _SkipGroup(group_number, decoder):
  """Skips the rest of a group, up to and including its END_GROUP tag.

  Nested fields (including nested groups) are skipped via _SkipField until
  an END_GROUP tag carrying |group_number| is read.

  Args:
    group_number: Tag number of the group to skip.
    decoder: Decoder used to deserialize the message.  It must be positioned
      just after the group's START_GROUP tag.
  """
  field_number, wire_type = decoder.ReadFieldNumberAndWireType()
  while not (wire_type == wire_format.WIRETYPE_END_GROUP and
             field_number == group_number):
    _SkipField(field_number, wire_type, decoder)
    field_number, wire_type = decoder.ReadFieldNumberAndWireType()
852
853
854def _DeserializeMessageSetItem(message, decoder):
855 """Deserializes a message using the message set wire format.
856
857 Args:
858 message: Message to be parsed to.
859 decoder: The decoder to be used to deserialize encoded data. Note that the
860 decoder should be positioned just after reading the START_GROUP tag that
861 began the messageset item.
862 """
863 field_number, wire_type = decoder.ReadFieldNumberAndWireType()
864 if wire_type != wire_format.WIRETYPE_VARINT or field_number != 2:
865 raise message_mod.DecodeError(
866 'Incorrect message set wire format. '
867 'wire_type: %d, field_number: %d' % (wire_type, field_number))
868
869 type_id = decoder.ReadInt32()
870 field_number, wire_type = decoder.ReadFieldNumberAndWireType()
871 if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED or field_number != 3:
872 raise message_mod.DecodeError(
873 'Incorrect message set wire format. '
874 'wire_type: %d, field_number: %d' % (wire_type, field_number))
875
876 extension_dict = message.Extensions
877 extensions_by_number = extension_dict._AllExtensionsByNumber()
878 if type_id not in extensions_by_number:
879 _SkipField(field_number, wire_type, decoder)
880 return
881
882 field_descriptor = extensions_by_number[type_id]
883 value = extension_dict[field_descriptor]
884 decoder.ReadMessageInto(value)
885 # Read the END_GROUP tag.
886 field_number, wire_type = decoder.ReadFieldNumberAndWireType()
887 if wire_type != wire_format.WIRETYPE_END_GROUP or field_number != 1:
888 raise message_mod.DecodeError(
889 'Incorrect message set wire format. '
890 'wire_type: %d, field_number: %d' % (wire_type, field_number))
891
892
893def _DeserializeOneEntity(message_descriptor, message, decoder):
894 """Deserializes the next wire entity from decoder into message.
895 The next wire entity is either a scalar or a nested message,
896 and may also be an element in a repeated field (the wire encoding
897 is the same).
898
899 Args:
900 message_descriptor: A Descriptor instance describing all fields
901 in message.
902 message: The Message instance into which we're decoding our fields.
903 decoder: The Decoder we're using to deserialize encoded data.
904
905 Returns: The number of bytes read from decoder during this method.
906 """
907 initial_position = decoder.Position()
908 field_number, wire_type = decoder.ReadFieldNumberAndWireType()
909 extension_dict = message.Extensions
910 extensions_by_number = extension_dict._AllExtensionsByNumber()
911 if field_number in message_descriptor.fields_by_number:
912 # Non-extension field.
913 field_descriptor = message_descriptor.fields_by_number[field_number]
914 value = getattr(message, _PropertyName(field_descriptor.name))
915 def nonextension_setter_fn(scalar):
916 setattr(message, _PropertyName(field_descriptor.name), scalar)
917 scalar_setter_fn = nonextension_setter_fn
918 elif field_number in extensions_by_number:
919 # Extension field.
920 field_descriptor = extensions_by_number[field_number]
921 value = extension_dict[field_descriptor]
922 def extension_setter_fn(scalar):
923 extension_dict[field_descriptor] = scalar
924 scalar_setter_fn = extension_setter_fn
925 elif wire_type == wire_format.WIRETYPE_END_GROUP:
926 # We assume we're being parsed as the group that's ended.
927 return 0
928 elif (wire_type == wire_format.WIRETYPE_START_GROUP and
929 field_number == 1 and
930 message_descriptor.GetOptions().message_set_wire_format):
931 # A Message Set item.
932 _DeserializeMessageSetItem(message, decoder)
933 return decoder.Position() - initial_position
934 else:
935 _SkipField(field_number, wire_type, decoder)
936 return decoder.Position() - initial_position
937
938 # If we reach this point, we've identified the field as either
939 # hardcoded or extension, and set |field_descriptor|, |scalar_setter_fn|,
940 # and |value| appropriately. Now actually deserialize the thing.
941 #
942 # field_descriptor: Describes the field we're deserializing.
943 # value: The value currently stored in the field to deserialize.
944 # Used only if the field is composite and/or repeated.
945 # scalar_setter_fn: A function F such that F(scalar) will
946 # set a nonrepeated scalar value for this field. Used only
947 # if this field is a nonrepeated scalar.
948
949 field_number = field_descriptor.number
950 field_type = field_descriptor.type
951 expected_wire_type = _WireTypeForFieldType(field_type)
952 if wire_type != expected_wire_type:
953 # Need to fill in uninterpreted_bytes. Work for the next CL.
954 raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.')
955
956 property_name = _PropertyName(field_descriptor.name)
957 label = field_descriptor.label
958 cpp_type = field_descriptor.cpp_type
959
960 # Nonrepeated scalar. Just set the field directly.
961 if (label != _FieldDescriptor.LABEL_REPEATED
962 and cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
963 scalar_setter_fn(_DeserializeScalarFromDecoder(field_type, decoder))
964 return decoder.Position() - initial_position
965
966 # Nonrepeated composite. Recursively deserialize.
967 if label != _FieldDescriptor.LABEL_REPEATED:
968 composite = value
969 _RecursivelyMerge(field_number, field_type, decoder, composite)
970 return decoder.Position() - initial_position
971
972 # Now we know we're dealing with a repeated field of some kind.
973 element_list = value
974
975 if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
976 # Repeated scalar.
977 element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
978 return decoder.Position() - initial_position
979 else:
980 # Repeated composite.
981 composite = element_list.add()
982 _RecursivelyMerge(field_number, field_type, decoder, composite)
983 return decoder.Position() - initial_position
984
985
986def _FieldOrExtensionValues(message, field_or_extension):
987 """Retrieves the list of values for the specified field or extension.
988
989 The target field or extension can be optional, required or repeated, but it
990 must have value(s) set. The assumption is that the target field or extension
991 is set (e.g. _HasFieldOrExtension holds true).
992
993 Args:
994 message: Message which contains the target field or extension.
995 field_or_extension: Field or extension for which the list of values is
996 required. Must be an instance of FieldDescriptor.
997
998 Returns:
999 A list of values for the specified field or extension. This list will only
1000 contain a single element if the field is non-repeated.
1001 """
1002 if field_or_extension.is_extension:
1003 value = message.Extensions[field_or_extension]
1004 else:
1005 value = getattr(message, _ValueFieldName(field_or_extension.name))
1006 if field_or_extension.label != _FieldDescriptor.LABEL_REPEATED:
1007 return [value]
1008 else:
1009 # In this case value is a list or repeated values.
1010 return value
1011
1012
1013def _HasFieldOrExtension(message, field_or_extension):
1014 """Checks if a message has the specified field or extension set.
1015
1016 The field or extension specified can be optional, required or repeated. If
1017 it is repeated, this function returns True. Otherwise it checks the has bit
1018 of the field or extension.
1019
1020 Args:
1021 message: Message which contains the target field or extension.
1022 field_or_extension: Field or extension to check. This must be a
1023 FieldDescriptor instance.
1024
1025 Returns:
1026 True if the message has a value set for the specified field or extension,
1027 or if the field or extension is repeated.
1028 """
1029 if field_or_extension.label == _FieldDescriptor.LABEL_REPEATED:
1030 return True
1031 if field_or_extension.is_extension:
1032 return message.HasExtension(field_or_extension)
1033 else:
1034 return message.HasField(field_or_extension.name)
1035
1036
1037def _IsFieldOrExtensionInitialized(message, field):
1038 """Checks if a message field or extension is initialized.
1039
1040 Args:
1041 message: The message which contains the field or extension.
1042 field: Field or extension to check. This must be a FieldDescriptor instance.
1043
1044 Returns:
1045 True if the field/extension can be considered initialized.
1046 """
1047 # If the field is required and is not set, it isn't initialized.
1048 if field.label == _FieldDescriptor.LABEL_REQUIRED:
1049 if not _HasFieldOrExtension(message, field):
1050 return False
1051
1052 # If the field is optional and is not set, or if it
1053 # isn't a submessage then the field is initialized.
1054 if field.label == _FieldDescriptor.LABEL_OPTIONAL:
1055 if not _HasFieldOrExtension(message, field):
1056 return True
1057 if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
1058 return True
1059
1060 # The field is set and is either a single or a repeated submessage.
1061 messages = _FieldOrExtensionValues(message, field)
1062 # If all submessages in this field are initialized, the field is
1063 # considered initialized.
1064 for message in messages:
1065 if not message.IsInitialized():
1066 return False
1067 return True
1068
1069
1070def _AddMergeFromStringMethod(message_descriptor, cls):
1071 """Helper for _AddMessageMethods()."""
1072 Decoder = decoder.Decoder
1073 def MergeFromString(self, serialized):
1074 decoder = Decoder(serialized)
1075 byte_count = 0
1076 while not decoder.EndOfStream():
1077 bytes_read = _DeserializeOneEntity(message_descriptor, self, decoder)
1078 if not bytes_read:
1079 break
1080 byte_count += bytes_read
1081 return byte_count
1082 cls.MergeFromString = MergeFromString
1083
1084
1085def _AddIsInitializedMethod(message_descriptor, cls):
1086 """Adds the IsInitialized method to the protocol message class."""
1087 def IsInitialized(self):
1088 fields_and_extensions = []
1089 fields_and_extensions.extend(message_descriptor.fields)
1090 fields_and_extensions.extend(
1091 self.Extensions._AllExtensionsByNumber().values())
1092 for field_or_extension in fields_and_extensions:
1093 if not _IsFieldOrExtensionInitialized(self, field_or_extension):
1094 return False
1095 return True
1096 cls.IsInitialized = IsInitialized
1097
1098
1099def _AddMessageMethods(message_descriptor, cls):
1100 """Adds implementations of all Message methods to cls."""
1101
1102 # TODO(robinson): Add support for remaining Message methods.
1103
1104 _AddListFieldsMethod(message_descriptor, cls)
1105 _AddHasFieldMethod(cls)
1106 _AddClearFieldMethod(cls)
1107 _AddClearExtensionMethod(cls)
1108 _AddClearMethod(cls)
1109 _AddHasExtensionMethod(cls)
1110 _AddEqualsMethod(message_descriptor, cls)
1111 _AddSetListenerMethod(cls)
1112 _AddByteSizeMethod(message_descriptor, cls)
1113 _AddSerializeToStringMethod(message_descriptor, cls)
1114 _AddMergeFromStringMethod(message_descriptor, cls)
1115 _AddIsInitializedMethod(message_descriptor, cls)
1116
1117
1118def _AddPrivateHelperMethods(cls):
1119 """Adds implementation of private helper methods to cls."""
1120
1121 def MaybeCallTransitionToNonemptyCallback(self):
1122 """Calls self._listener.TransitionToNonempty() the first time this
1123 method is called. On all subsequent calls, this is a no-op.
1124 """
1125 if not self._called_transition_to_nonempty:
1126 self._listener.TransitionToNonempty()
1127 self._called_transition_to_nonempty = True
1128 cls._MaybeCallTransitionToNonemptyCallback = (
1129 MaybeCallTransitionToNonemptyCallback)
1130
1131 def MarkByteSizeDirty(self):
1132 """Sets the _cached_byte_size_dirty bit to true,
1133 and propagates this to our listener iff this was a state change.
1134 """
1135 if not self._cached_byte_size_dirty:
1136 self._cached_byte_size_dirty = True
1137 self._listener.ByteSizeDirty()
1138 cls._MarkByteSizeDirty = MarkByteSizeDirty
1139
1140
1141class _Listener(object):
1142
1143 """MessageListener implementation that a parent message registers with its
1144 child message.
1145
1146 In order to support semantics like:
1147
1148 foo.bar.baz = 23
1149 assert foo.HasField('bar')
1150
1151 ...child objects must have back references to their parents.
1152 This helper class is at the heart of this support.
1153 """
1154
1155 def __init__(self, parent_message, has_field_name):
1156 """Args:
1157 parent_message: The message whose _MaybeCallTransitionToNonemptyCallback()
1158 and _MarkByteSizeDirty() methods we should call when we receive
1159 TransitionToNonempty() and ByteSizeDirty() messages.
1160 has_field_name: The name of the "has" field that we should set in
1161 the parent message when we receive a TransitionToNonempty message,
1162 or None if there's no "has" field to set. (This will be the case
1163 for child objects in "repeated" fields).
1164 """
1165 # This listener establishes a back reference from a child (contained) object
1166 # to its parent (containing) object. We make this a weak reference to avoid
1167 # creating cyclic garbage when the client finishes with the 'parent' object
1168 # in the tree.
1169 if isinstance(parent_message, weakref.ProxyType):
1170 self._parent_message_weakref = parent_message
1171 else:
1172 self._parent_message_weakref = weakref.proxy(parent_message)
1173 self._has_field_name = has_field_name
1174
1175 def TransitionToNonempty(self):
1176 try:
1177 if self._has_field_name is not None:
1178 setattr(self._parent_message_weakref, self._has_field_name, True)
1179 # Propagate the signal to our parents iff this is the first field set.
1180 self._parent_message_weakref._MaybeCallTransitionToNonemptyCallback()
1181 except ReferenceError:
1182 # We can get here if a client has kept a reference to a child object,
1183 # and is now setting a field on it, but the child's parent has been
1184 # garbage-collected. This is not an error.
1185 pass
1186
1187 def ByteSizeDirty(self):
1188 try:
1189 self._parent_message_weakref._MarkByteSizeDirty()
1190 except ReferenceError:
1191 # Same as above.
1192 pass
1193
1194
1195# TODO(robinson): Move elsewhere?
1196# TODO(robinson): Provide a clear() method here in addition to ClearField()?
1197class _RepeatedScalarFieldContainer(object):
1198
1199 """Simple, type-checked, list-like container for holding repeated scalars.
1200 """
1201
1202 def __init__(self, message_listener, type_checker):
1203 """
1204 Args:
1205 message_listener: A MessageListener implementation.
1206 The _RepeatedScalarFieldContaininer will call this object's
1207 TransitionToNonempty() method when it transitions from being empty to
1208 being nonempty.
1209 type_checker: A _ValueChecker instance to run on elements inserted
1210 into this container.
1211 """
1212 self._message_listener = message_listener
1213 self._type_checker = type_checker
1214 self._values = []
1215
1216 def append(self, elem):
1217 self._type_checker.CheckValue(elem)
1218 self._values.append(elem)
1219 self._message_listener.ByteSizeDirty()
1220 if len(self._values) == 1:
1221 self._message_listener.TransitionToNonempty()
1222
1223 # List-like __getitem__() support also makes us iterable (via "iter(foo)"
1224 # or implicitly via "for i in mylist:") for free.
1225 def __getitem__(self, key):
1226 return self._values[key]
1227
1228 def __setitem__(self, key, value):
1229 # No need to call TransitionToNonempty(), since if we're able to
1230 # set the element at this index, we were already nonempty before
1231 # this method was called.
1232 self._message_listener.ByteSizeDirty()
1233 self._type_checker.CheckValue(value)
1234 self._values[key] = value
1235
1236 def __len__(self):
1237 return len(self._values)
1238
1239 def __eq__(self, other):
1240 if self is other:
1241 return True
1242 # Special case for the same type which should be common and fast.
1243 if isinstance(other, self.__class__):
1244 return other._values == self._values
1245 # We are presumably comparing against some other sequence type.
1246 return other == self._values
1247
1248 def __ne__(self, other):
1249 # Can't use != here since it would infinitely recurse.
1250 return not self == other
1251
1252
1253# TODO(robinson): Move elsewhere?
1254# TODO(robinson): Provide a clear() method here in addition to ClearField()?
1255# TODO(robinson): Unify common functionality with
1256# _RepeatedScalarFieldContaininer?
1257class _RepeatedCompositeFieldContainer(object):
1258
1259 """Simple, list-like container for holding repeated composite fields.
1260 """
1261
1262 def __init__(self, message_listener, message_descriptor):
1263 """Note that we pass in a descriptor instead of the generated directly,
1264 since at the time we construct a _RepeatedCompositeFieldContainer we
1265 haven't yet necessarily initialized the type that will be contained in the
1266 container.
1267
1268 Args:
1269 message_listener: A MessageListener implementation.
1270 The _RepeatedCompositeFieldContainer will call this object's
1271 TransitionToNonempty() method when it transitions from being empty to
1272 being nonempty.
1273 message_descriptor: A Descriptor instance describing the protocol type
1274 that should be present in this container. We'll use the
1275 _concrete_class field of this descriptor when the client calls add().
1276 """
1277 self._message_listener = message_listener
1278 self._message_descriptor = message_descriptor
1279 self._values = []
1280
1281 def add(self):
1282 new_element = self._message_descriptor._concrete_class()
1283 new_element._SetListener(self._message_listener)
1284 self._values.append(new_element)
1285 self._message_listener.ByteSizeDirty()
1286 self._message_listener.TransitionToNonempty()
1287 return new_element
1288
1289 # List-like __getitem__() support also makes us iterable (via "iter(foo)"
1290 # or implicitly via "for i in mylist:") for free.
1291 def __getitem__(self, key):
1292 return self._values[key]
1293
1294 def __len__(self):
1295 return len(self._values)
1296
1297 def __eq__(self, other):
1298 if self is other:
1299 return True
1300 if not isinstance(other, self.__class__):
1301 raise TypeError('Can only compare repeated composite fields against '
1302 'other repeated composite fields.')
1303 return self._values == other._values
1304
1305 def __ne__(self, other):
1306 # Can't use != here since it would infinitely recurse.
1307 return not self == other
1308
1309 # TODO(robinson): Implement, document, and test slicing support.
1310
1311
1312# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
1313# TODO(robinson): Unify error handling of "unknown extension" crap.
1314# TODO(robinson): There's so much similarity between the way that
1315# extensions behave and the way that normal fields behave that it would
1316# be really nice to unify more code. It's not immediately obvious
1317# how to do this, though, and I'd rather get the full functionality
1318# implemented (and, crucially, get all the tests and specs fleshed out
1319# and passing), and then come back to this thorny unification problem.
1320# TODO(robinson): Support iteritems()-style iteration over all
1321# extensions with the "has" bits turned on?
1322class _ExtensionDict(object):
1323
1324 """Dict-like container for supporting an indexable "Extensions"
1325 field on proto instances.
1326
1327 Note that in all cases we expect extension handles to be
1328 FieldDescriptors.
1329 """
1330
1331 class _ExtensionListener(object):
1332
1333 """Adapts an _ExtensionDict to behave as a MessageListener."""
1334
1335 def __init__(self, extension_dict, handle_id):
1336 self._extension_dict = extension_dict
1337 self._handle_id = handle_id
1338
1339 def TransitionToNonempty(self):
1340 self._extension_dict._SubmessageTransitionedToNonempty(self._handle_id)
1341
1342 def ByteSizeDirty(self):
1343 self._extension_dict._SubmessageByteSizeBecameDirty()
1344
1345 # TODO(robinson): Somewhere, we need to blow up if people
1346 # try to register two extensions with the same field number.
1347 # (And we need a test for this of course).
1348
1349 def __init__(self, extended_message, known_extensions):
1350 """extended_message: Message instance for which we are the Extensions dict.
1351 known_extensions: Iterable of known extension handles.
1352 These must be FieldDescriptors.
1353 """
1354 # We keep a weak reference to extended_message, since
1355 # it has a reference to this instance in turn.
1356 self._extended_message = weakref.proxy(extended_message)
1357 # We make a deep copy of known_extensions to avoid any
1358 # thread-safety concerns, since the argument passed in
1359 # is the global (class-level) dict of known extensions for
1360 # this type of message, which could be modified at any time
1361 # via a RegisterExtension() call.
1362 #
1363 # This dict maps from handle id to handle (a FieldDescriptor).
1364 #
1365 # XXX
1366 # TODO(robinson): This isn't good enough. The client could
1367 # instantiate an object in module A, then afterward import
1368 # module B and pass the instance to B.Foo(). If B imports
1369 # an extender of this proto and then tries to use it, B
1370 # will get a KeyError, even though the extension *is* registered
1371 # at the time of use.
1372 # XXX
1373 self._known_extensions = dict((id(e), e) for e in known_extensions)
1374 # Read lock around self._values, which may be modified by multiple
1375 # concurrent readers in the conceptually "const" __getitem__ method.
1376 # So, we grab this lock in every "read-only" method to ensure
1377 # that concurrent read access is safe without external locking.
1378 self._lock = threading.Lock()
1379 # Maps from extension handle ID to current value of that extension.
1380 self._values = {}
1381 # Maps from extension handle ID to a boolean "has" bit, but only
1382 # for non-repeated extension fields.
1383 keys = (id for id, extension in self._known_extensions.iteritems()
1384 if extension.label != _FieldDescriptor.LABEL_REPEATED)
1385 self._has_bits = dict.fromkeys(keys, False)
1386
1387 def __getitem__(self, extension_handle):
1388 """Returns the current value of the given extension handle."""
1389 # We don't care as much about keeping critical sections short in the
1390 # extension support, since it's presumably much less of a common case.
1391 self._lock.acquire()
1392 try:
1393 handle_id = id(extension_handle)
1394 if handle_id not in self._known_extensions:
1395 raise KeyError('Extension not known to this class')
1396 if handle_id not in self._values:
1397 self._AddMissingHandle(extension_handle, handle_id)
1398 return self._values[handle_id]
1399 finally:
1400 self._lock.release()
1401
1402 def __eq__(self, other):
1403 # We have to grab read locks since we're accessing _values
1404 # in a "const" method. See the comment in the constructor.
1405 if self is other:
1406 return True
1407 self._lock.acquire()
1408 try:
1409 other._lock.acquire()
1410 try:
1411 if self._has_bits != other._has_bits:
1412 return False
1413 # If there's a "has" bit, then only compare values where it is true.
1414 for k, v in self._values.iteritems():
1415 if self._has_bits.get(k, False) and v != other._values[k]:
1416 return False
1417 return True
1418 finally:
1419 other._lock.release()
1420 finally:
1421 self._lock.release()
1422
1423 def __ne__(self, other):
1424 return not self == other
1425
1426 # Note that this is only meaningful for non-repeated, scalar extension
1427 # fields. Note also that we may have to call
1428 # MaybeCallTransitionToNonemptyCallback() when we do successfully set a field
1429 # this way, to set any necssary "has" bits in the ancestors of the extended
1430 # message.
1431 def __setitem__(self, extension_handle, value):
1432 """If extension_handle specifies a non-repeated, scalar extension
1433 field, sets the value of that field.
1434 """
1435 handle_id = id(extension_handle)
1436 if handle_id not in self._known_extensions:
1437 raise KeyError('Extension not known to this class')
1438 field = extension_handle # Just shorten the name.
1439 if (field.label == _FieldDescriptor.LABEL_OPTIONAL
1440 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
1441 # It's slightly wasteful to lookup the type checker each time,
1442 # but we expect this to be a vanishingly uncommon case anyway.
1443 type_checker = _VALUE_CHECKERS[field.cpp_type]
1444 type_checker.CheckValue(value)
1445 self._values[handle_id] = value
1446 self._has_bits[handle_id] = True
1447 self._extended_message._MarkByteSizeDirty()
1448 self._extended_message._MaybeCallTransitionToNonemptyCallback()
1449 else:
1450 raise TypeError('Extension is repeated and/or a composite type.')
1451
1452 def _AddMissingHandle(self, extension_handle, handle_id):
1453 """Helper internal to ExtensionDict."""
1454 # Special handling for non-repeated message extensions, which (like
1455 # normal fields of this kind) are initialized lazily.
1456 # REQUIRES: _lock already held.
1457 cpp_type = extension_handle.cpp_type
1458 label = extension_handle.label
1459 if (cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
1460 and label != _FieldDescriptor.LABEL_REPEATED):
1461 self._AddMissingNonRepeatedCompositeHandle(extension_handle, handle_id)
1462 else:
1463 self._values[handle_id] = _DefaultValueForField(
1464 self._extended_message, extension_handle)
1465
1466 def _AddMissingNonRepeatedCompositeHandle(self, extension_handle, handle_id):
1467 """Helper internal to ExtensionDict."""
1468 # REQUIRES: _lock already held.
1469 value = extension_handle.message_type._concrete_class()
1470 value._SetListener(_ExtensionDict._ExtensionListener(self, handle_id))
1471 self._values[handle_id] = value
1472
1473 def _SubmessageTransitionedToNonempty(self, handle_id):
1474 """Called when a submessage with a given handle id first transitions to
1475 being nonempty. Called by _ExtensionListener.
1476 """
1477 assert handle_id in self._has_bits
1478 self._has_bits[handle_id] = True
1479 self._extended_message._MaybeCallTransitionToNonemptyCallback()
1480
1481 def _SubmessageByteSizeBecameDirty(self):
1482 """Called whenever a submessage's cached byte size becomes invalid
1483 (goes from being "clean" to being "dirty"). Called by _ExtensionListener.
1484 """
1485 self._extended_message._MarkByteSizeDirty()
1486
1487 # We may wish to widen the public interface of Message.Extensions
1488 # to expose some of this private functionality in the future.
1489 # For now, we make all this functionality module-private and just
1490 # implement what we need for serialization/deserialization,
1491 # HasField()/ClearField(), etc.
1492
1493 def _HasExtension(self, extension_handle):
1494 """Method for internal use by this module.
1495 Returns true iff we "have" this extension in the sense of the
1496 "has" bit being set.
1497 """
1498 handle_id = id(extension_handle)
1499 # Note that this is different from the other checks.
1500 if handle_id not in self._has_bits:
1501 raise KeyError('Extension not known to this class, or is repeated field.')
1502 return self._has_bits[handle_id]
1503
1504 # Intentionally pretty similar to ClearField() above.
1505 def _ClearExtension(self, extension_handle):
1506 """Method for internal use by this module.
1507 Clears the specified extension, unsetting its "has" bit.
1508 """
1509 handle_id = id(extension_handle)
1510 if handle_id not in self._known_extensions:
1511 raise KeyError('Extension not known to this class')
1512 default_value = _DefaultValueForField(self._extended_message,
1513 extension_handle)
1514 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1515 self._extended_message._MarkByteSizeDirty()
1516 else:
1517 cpp_type = extension_handle.cpp_type
1518 if cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1519 if handle_id in self._values:
1520 # Future modifications to this object shouldn't set any
1521 # "has" bits here.
1522 self._values[handle_id]._SetListener(None)
1523 if self._has_bits[handle_id]:
1524 self._has_bits[handle_id] = False
1525 self._extended_message._MarkByteSizeDirty()
1526 if handle_id in self._values:
1527 del self._values[handle_id]
1528
1529 def _ListSetExtensions(self):
1530 """Method for internal use by this module.
1531
1532 Returns an sequence of all extensions that are currently "set"
1533 in this extension dict. A "set" extension is a repeated extension,
1534 or a non-repeated extension with its "has" bit set.
1535
1536 The returned sequence contains (field_descriptor, value) pairs,
1537 where value is the current value of the extension with the given
1538 field descriptor.
1539
1540 The sequence values are in arbitrary order.
1541 """
1542 self._lock.acquire() # Read-only methods must lock around self._values.
1543 try:
1544 set_extensions = []
1545 for handle_id, value in self._values.iteritems():
1546 handle = self._known_extensions[handle_id]
1547 if (handle.label == _FieldDescriptor.LABEL_REPEATED
1548 or self._has_bits[handle_id]):
1549 set_extensions.append((handle, value))
1550 return set_extensions
1551 finally:
1552 self._lock.release()
1553
1554 def _AllExtensionsByNumber(self):
1555 """Method for internal use by this module.
1556
1557 Returns: A dict mapping field_number to (handle, field_descriptor),
1558 for *all* registered extensions for this dict.
1559 """
1560 # TODO(robinson): Precompute and store this away. Note that we'll have to
1561 # be careful when we move away from having _known_extensions as a
1562 # deep-copied member of this object.
1563 return dict((f.number, f) for f in self._known_extensions.itervalues())
1564
1565
1566# None of the typecheckers below make any attempt to guard against people
1567# subclassing builtin types and doing weird things. We're not trying to
1568# protect against malicious clients here, just people accidentally shooting
1569# themselves in the foot in obvious ways.
1570
1571class _TypeChecker(object):
1572
1573 """Type checker used to catch type errors as early as possible
1574 when the client is setting scalar fields in protocol messages.
1575 """
1576
1577 def __init__(self, *acceptable_types):
1578 self._acceptable_types = acceptable_types
1579
1580 def CheckValue(self, proposed_value):
1581 if not isinstance(proposed_value, self._acceptable_types):
1582 message = ('%.1024r has type %s, but expected one of: %s' %
1583 (proposed_value, type(proposed_value), self._acceptable_types))
1584 raise TypeError(message)
1585
1586
1587# _IntValueChecker and its subclasses perform integer type-checks
1588# and bounds-checks.
1589class _IntValueChecker(object):
1590
1591 """Checker used for integer fields. Performs type-check and range check."""
1592
1593 def CheckValue(self, proposed_value):
1594 if not isinstance(proposed_value, (int, long)):
1595 message = ('%.1024r has type %s, but expected one of: %s' %
1596 (proposed_value, type(proposed_value), (int, long)))
1597 raise TypeError(message)
1598 if not self._MIN <= proposed_value <= self._MAX:
1599 raise ValueError('Value out of range: %d' % proposed_value)
1600
1601class _Int32ValueChecker(_IntValueChecker):
1602 # We're sure to use ints instead of longs here since comparison may be more
1603 # efficient.
1604 _MIN = -2147483648
1605 _MAX = 2147483647
1606
1607class _Uint32ValueChecker(_IntValueChecker):
1608 _MIN = 0
1609 _MAX = (1 << 32) - 1
1610
1611class _Int64ValueChecker(_IntValueChecker):
1612 _MIN = -(1 << 63)
1613 _MAX = (1 << 63) - 1
1614
1615class _Uint64ValueChecker(_IntValueChecker):
1616 _MIN = 0
1617 _MAX = (1 << 64) - 1
1618
1619
1620# Type-checkers for all scalar CPPTYPEs.
1621_VALUE_CHECKERS = {
1622 _FieldDescriptor.CPPTYPE_INT32: _Int32ValueChecker(),
1623 _FieldDescriptor.CPPTYPE_INT64: _Int64ValueChecker(),
1624 _FieldDescriptor.CPPTYPE_UINT32: _Uint32ValueChecker(),
1625 _FieldDescriptor.CPPTYPE_UINT64: _Uint64ValueChecker(),
1626 _FieldDescriptor.CPPTYPE_DOUBLE: _TypeChecker(
1627 float, int, long),
1628 _FieldDescriptor.CPPTYPE_FLOAT: _TypeChecker(
1629 float, int, long),
1630 _FieldDescriptor.CPPTYPE_BOOL: _TypeChecker(bool, int),
1631 _FieldDescriptor.CPPTYPE_ENUM: _Int32ValueChecker(),
1632 _FieldDescriptor.CPPTYPE_STRING: _TypeChecker(str),
1633 }
1634
1635
# Maps each field type to a function F where F(field_num, value) returns the
# total number of bytes needed to serialize "value" as that field, counting
# the tag and any other overhead as well as the payload itself.
# Entries are grouped by the wire type each field type uses.
_TYPE_TO_BYTE_SIZE_FN = {
    # Varint wire type.
    _FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize,
    _FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize,
    _FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize,
    _FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize,
    _FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize,
    _FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize,
    _FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize,
    _FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize,
    # Fixed 64-bit wire type.
    _FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize,
    _FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize,
    _FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize,
    # Fixed 32-bit wire type.
    _FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize,
    _FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize,
    _FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize,
    # Length-delimited wire type.
    _FieldDescriptor.TYPE_STRING: wire_format.StringByteSize,
    _FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize,
    _FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize,
    # Start-group wire type.
    _FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize,
}
1660
# Maps each field type to an unbound Encoder method F; calling
# F(encoder, field_number, value) appends the serialized form of a value of
# that type to the encoder.  Entries are grouped by wire type.
_Encoder = encoder.Encoder
_TYPE_TO_SERIALIZE_METHOD = {
    # Varint wire type.
    _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32,
    _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64,
    _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32,
    _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64,
    _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32,
    _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64,
    _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBool,
    _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnum,
    # Fixed 64-bit wire type.
    _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDouble,
    _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64,
    _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64,
    # Fixed 32-bit wire type.
    _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloat,
    _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32,
    _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32,
    # Length-delimited wire type.
    _FieldDescriptor.TYPE_STRING: _Encoder.AppendString,
    _FieldDescriptor.TYPE_BYTES: _Encoder.AppendBytes,
    _FieldDescriptor.TYPE_MESSAGE: _Encoder.AppendMessage,
    # Start-group wire type.
    _FieldDescriptor.TYPE_GROUP: _Encoder.AppendGroup,
}
1685
# Maps each field type to the wire type that values of that type are expected
# to be encoded with.  Entries are grouped by wire type.
_FIELD_TYPE_TO_WIRE_TYPE = {
    # Varint-encoded scalars.
    _FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT,
    _FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT,
    # Fixed-width 64-bit scalars.
    _FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64,
    _FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64,
    _FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64,
    # Fixed-width 32-bit scalars.
    _FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32,
    _FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32,
    _FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32,
    # Length-prefixed payloads.
    _FieldDescriptor.TYPE_STRING: wire_format.WIRETYPE_LENGTH_DELIMITED,
    _FieldDescriptor.TYPE_BYTES: wire_format.WIRETYPE_LENGTH_DELIMITED,
    _FieldDescriptor.TYPE_MESSAGE: wire_format.WIRETYPE_LENGTH_DELIMITED,
    # Groups open with a START_GROUP tag.
    _FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP,
}
1710
# Maps each field type to an unbound Decoder method F; calling F(decoder)
# reads one field value of that type.  Entries are grouped by wire type.
#
# TYPE_MESSAGE and TYPE_GROUP are deliberately absent: nested messages and
# groups are handled by _RecursivelyMerge() instead.
_Decoder = decoder.Decoder
_TYPE_TO_DESERIALIZE_METHOD = {
    # Varint wire type.
    _FieldDescriptor.TYPE_INT32: _Decoder.ReadInt32,
    _FieldDescriptor.TYPE_INT64: _Decoder.ReadInt64,
    _FieldDescriptor.TYPE_UINT32: _Decoder.ReadUInt32,
    _FieldDescriptor.TYPE_UINT64: _Decoder.ReadUInt64,
    _FieldDescriptor.TYPE_SINT32: _Decoder.ReadSInt32,
    _FieldDescriptor.TYPE_SINT64: _Decoder.ReadSInt64,
    _FieldDescriptor.TYPE_BOOL: _Decoder.ReadBool,
    _FieldDescriptor.TYPE_ENUM: _Decoder.ReadEnum,
    # Fixed 64-bit wire type.
    _FieldDescriptor.TYPE_DOUBLE: _Decoder.ReadDouble,
    _FieldDescriptor.TYPE_FIXED64: _Decoder.ReadFixed64,
    _FieldDescriptor.TYPE_SFIXED64: _Decoder.ReadSFixed64,
    # Fixed 32-bit wire type.
    _FieldDescriptor.TYPE_FLOAT: _Decoder.ReadFloat,
    _FieldDescriptor.TYPE_FIXED32: _Decoder.ReadFixed32,
    _FieldDescriptor.TYPE_SFIXED32: _Decoder.ReadSFixed32,
    # Length-delimited wire type (scalars only).
    _FieldDescriptor.TYPE_STRING: _Decoder.ReadString,
    _FieldDescriptor.TYPE_BYTES: _Decoder.ReadBytes,
}