blob: c35742ce3fdf0d1c585e69eaae24259eff131b82 [file] [log] [blame]
temporal40ee5512008-07-10 02:12:20 +00001# Protocol Buffers - Google's data interchange format
kenton@google.com24bf56f2008-09-24 20:31:01 +00002# Copyright 2008 Google Inc. All rights reserved.
temporal40ee5512008-07-10 02:12:20 +00003# http://code.google.com/p/protobuf/
4#
kenton@google.com24bf56f2008-09-24 20:31:01 +00005# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
temporal40ee5512008-07-10 02:12:20 +00008#
kenton@google.com24bf56f2008-09-24 20:31:01 +00009# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
temporal40ee5512008-07-10 02:12:20 +000018#
kenton@google.com24bf56f2008-09-24 20:31:01 +000019# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
temporal40ee5512008-07-10 02:12:20 +000030
31# This code is meant to work on Python 2.4 and above only.
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject the message API (field properties, equality, byte-size
computation, serialization and parsing helpers) into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""

__author__ = 'robinson@google.com (Will Robinson)'
52
53import heapq
54import threading
55import weakref
56# We use "as" to avoid name collisions with variables.
kenton@google.com26bd9ee2008-11-21 00:06:27 +000057from google.protobuf.internal import containers
temporal40ee5512008-07-10 02:12:20 +000058from google.protobuf.internal import decoder
59from google.protobuf.internal import encoder
60from google.protobuf.internal import message_listener as message_listener_mod
temporal779f61c2008-08-13 03:15:00 +000061from google.protobuf.internal import type_checkers
temporal40ee5512008-07-10 02:12:20 +000062from google.protobuf.internal import wire_format
63from google.protobuf import descriptor as descriptor_mod
64from google.protobuf import message as message_mod
65
# Short module-local alias: FieldDescriptor's LABEL_*, TYPE_* and CPPTYPE_*
# constants are referenced throughout this file.
_FieldDescriptor = descriptor_mod.FieldDescriptor
67
68
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

    mydescriptor = Descriptor(.....)
    class MyProtoClass(Message):
      __metaclass__ = GeneratedProtocolMessageType
      DESCRIPTOR = mydescriptor
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    __slots__ only takes effect if it is already in the class dictionary when
    the class object is allocated, so the slot list (and the nested-extension
    class attributes) are injected here rather than in __init__.

    Args:
      name: Name of the class (passed through to type.__new__).
      bases: Base classes of the class we're constructing
        (normally just message.Message).
      dictionary: The class dictionary of the class we're constructing.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.  It is mutated in place.

    Returns:
      Newly-allocated class.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    _AddSlots(descriptor, dictionary)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    superclass = super(GeneratedProtocolMessageType, cls)
    # __new__ is a static method, so cls has to be passed explicitly here.
    return superclass.__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.

    We add enum getters, an __init__ method, implementations of all Message
    methods, and properties for all fields in the protocol type.

    Args:
      name: Name of the class (passed through to type.__init__).
      bases: Base classes of the class we're constructing
        (normally just message.Message).
      dictionary: The class dictionary of the class we're constructing.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    # We act as a "friend" class of the descriptor, setting its
    # _concrete_class attribute the first time we use a given descriptor to
    # initialize a concrete protocol message class.  Composite-field getters
    # instantiate message_type._concrete_class.
    concrete_class_attr_name = '_concrete_class'
    if not hasattr(descriptor, concrete_class_attr_name):
      setattr(descriptor, concrete_class_attr_name, cls)
    cls._known_extensions = []
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(cls)
    superclass = super(GeneratedProtocolMessageType, cls)
    # super(...) is already bound to cls, so cls must NOT be passed again.
    # Passing it (superclass.__init__(cls, name, bases, dictionary)) hands
    # type.__init__ four arguments, which newer interpreters reject with
    # "type.__init__() takes 1 or 3 arguments".
    superclass.__init__(name, bases, dictionary)
157
158
159# Stateless helpers for GeneratedProtocolMessageType below.
160# Outside clients should not access these directly.
161#
162# I opted not to make any of these methods on the metaclass, to make it more
163# clear that I'm not really using any state there and to keep clients from
164# thinking that they have direct access to these construction helpers.
165
166
def _PropertyName(proto_field_name):
  """Maps a .proto field name to the public property name on message classes.

  Clients read a field through this property and, for singular scalar fields,
  can also assign to it.  The mapping is currently the identity.

  Args:
    proto_field_name: The field name exactly as it appears (or would appear)
      in a .proto file.

  Returns:
    The name of the property on the generated message class.
  """
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz suggests the stdlib keyword module, e.g.:
  #   if keyword.iskeyword(proto_field_name):
  #     return proto_field_name + "_"
  property_name = proto_field_name
  return property_name
186
187
def _ValueFieldName(proto_field_name):
  """Names the private instance slot that stores a field's current value.

  Args:
    proto_field_name: The field name exactly as it appears (or would appear)
      in a .proto file.

  Returns:
    '_value_' followed by proto_field_name.
  """
  value_prefix = '_value_'
  return value_prefix + proto_field_name
197
198
def _HasFieldName(proto_field_name):
  """Names the private instance slot holding a field's boolean "has" flag.

  The flag records whether the field has been explicitly set.

  Args:
    proto_field_name: The field name exactly as it appears (or would appear)
      in a .proto file.

  Returns:
    '_has_' followed by proto_field_name.
  """
  has_prefix = '_has_'
  return has_prefix + proto_field_name
209
210
def _AddSlots(message_descriptor, dictionary):
  """Stores the __slots__ list for a message type in its class dictionary.

  The slots are: one value attribute per field, one "has" attribute per
  non-repeated field, and the fixed bookkeeping attributes every message
  instance uses.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary that receives the '__slots__' entry.
  """
  all_fields = message_descriptor.fields
  slot_names = [_ValueFieldName(field.name) for field in all_fields]
  # Repeated fields have no "has" flag.
  for field in all_fields:
    if field.label != _FieldDescriptor.LABEL_REPEATED:
      slot_names.append(_HasFieldName(field.name))
  # '__weakref__' keeps instances weak-referenceable despite __slots__.
  slot_names += ['Extensions',
                 '_cached_byte_size',
                 '_cached_byte_size_dirty',
                 '_called_transition_to_nonempty',
                 '_listener',
                 '_lock',
                 '__weakref__']
  dictionary['__slots__'] = slot_names
229
230
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes every extension declared inside this message as a class attribute.

  Args:
    descriptor: Descriptor of the message type being built.  Its
      extensions_by_name mapping supplies the extensions.
    dictionary: Class dictionary that receives one entry per extension.

  Raises:
    ValueError: if an extension name is already present in dictionary.  This
      used to be an assert, which is skipped under "python -O" and would let
      the extension silently overwrite the existing class attribute.
  """
  # items() instead of iteritems(): equivalent here, and it also works on
  # Python 3 dictionaries.
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    if extension_name in dictionary:
      raise ValueError('Extension "%s" collides with an existing class '
                       'attribute of the same name.' % extension_name)
    dictionary[extension_name] = extension_field
236
237
def _AddEnumValues(descriptor, cls):
  """Copies every enum constant declared in this message onto the class.

  For each enum type nested in the message, each of its values becomes a
  class attribute named after the value and bound to the value's number.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  # Lazy generator: names/numbers are read one at a time, interleaved with
  # the setattr calls.
  enum_constants = ((value.name, value.number)
                    for enum_type in descriptor.enum_types
                    for value in enum_type.values)
  for constant_name, constant_number in enum_constants:
    setattr(cls, constant_name, constant_number)
248
249
def _DefaultValueForField(message, field):
  """Builds the value a field holds on a freshly constructed or cleared message.

  Args:
    message: Message instance containing this field, or a weakref proxy
      of same.
    field: FieldDescriptor object for this field.

  Returns:
    For repeated fields, a new empty container whose listener refers back to
    |message|.  Otherwise field.default_value, which is None for singular
    message fields (those are created lazily by their property getter).

  Raises:
    ValueError: if a repeated field's descriptor has a non-empty default.
  """
  # TODO(robinson): Only the repeated fields need a reference to 'message' (so
  # that they can set the 'has' bit on the containing Message when someone
  # append()s a value). We could special-case this, and avoid an extra
  # function call on __init__() and Clear() for non-repeated fields.

  # TODO(robinson): Find a better place for the default value assertion in this
  # function. No need to repeat them every time the client calls Clear('foo').
  # (We should probably just assert these things once and as early as possible,
  # by tightening checking in the descriptor classes.)
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      assert field.default_value is None
    return field.default_value

  if field.default_value != []:
    raise ValueError('Repeated field default value not empty list: %s' % (
        field.default_value))
  listener = _Listener(message, None)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # message_type._concrete_class may not be set yet (it depends on the
    # order in which classes are initialized), so pass the descriptor.
    return containers.RepeatedCompositeFieldContainer(
        listener, field.message_type)
  element_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
  return containers.RepeatedScalarFieldContainer(listener, element_checker)
288
289
def _AddInitMethod(message_descriptor, cls):
  """Installs the generated, argument-less __init__ on cls.

  The constructor resets the byte-size cache, installs a NullMessageListener,
  allocates a per-instance lock, stores each field's default value (clearing
  the "has" flag of every non-repeated field) and creates the Extensions
  dictionary.
  """
  field_descriptors = message_descriptor.fields

  def init(self):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = False
    self._listener = message_listener_mod.NullMessageListener()
    self._called_transition_to_nonempty = False
    # TODO(robinson): We should only create a lock if we really need one
    # in this class.
    self._lock = threading.Lock()
    for field in field_descriptors:
      initial_value = _DefaultValueForField(self, field)
      setattr(self, _ValueFieldName(field.name), initial_value)
      if field.label != _FieldDescriptor.LABEL_REPEATED:
        setattr(self, _HasFieldName(field.name), False)
    self.Extensions = _ExtensionDict(self, cls._known_extensions)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
312
313
def _AddPropertiesForFields(descriptor, cls):
  """Adds one public property to cls per field of the message type.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class we're constructing.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)
318
319
def _AddPropertiesForField(field, cls):
  """Adds the public property for one field, choosing the builder by field kind.

  Repeated fields, singular message/group fields and singular scalar fields
  each get a differently behaved property; only singular scalars are
  directly assignable.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Fails loudly if new cpp types are added that may need special handling
  # here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
340
341
def _AddPropertiesForRepeatedField(field, cls):
  """Adds the public property for a "repeated" protocol message field.

  Reading the property returns the field's container (the object stored by
  the generated __init__, i.e. a containers.RepeatedScalarFieldContainer or
  containers.RepeatedCompositeFieldContainer).  Assigning to the property
  raises AttributeError; callers modify the container instead.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  field_name = field.name
  value_attr = _ValueFieldName(field_name)

  def getter(self):
    return getattr(self, value_attr)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % field_name

  # Exists only to raise a more helpful error than the default one.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % field_name
  setattr(cls, _PropertyName(field_name), property(getter, setter, doc=doc))
373
374
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a read/write public property for a singular scalar field.

  The setter type-checks the new value, then sets the field's "has" flag,
  marks the cached byte size dirty, fires the transition-to-nonempty callback
  if needed, and finally stores the value.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  field_name = field.name
  value_attr = _ValueFieldName(field_name)
  has_attr = _HasFieldName(field_name)
  checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)

  def getter(self):
    return getattr(self, value_attr)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % field_name

  def setter(self, new_value):
    # CheckValue runs first, so if it raises, nothing on the message has
    # been modified yet.
    checker.CheckValue(new_value)
    setattr(self, has_attr, True)
    self._MarkByteSizeDirty()
    self._MaybeCallTransitionToNonemptyCallback()
    setattr(self, value_attr, new_value)
  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % field_name

  doc = 'Magic attribute generated for "%s" proto field.' % field_name
  setattr(cls, _PropertyName(field_name), property(getter, setter, doc=doc))
408
409
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds the public property for a singular "group" or "message" field.

  The child message is created on first read and attached to this message
  through a _Listener.  Assigning to the property raises AttributeError;
  callers modify the child message instead.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  field_name = field.name
  value_attr = _ValueFieldName(field_name)
  has_attr = _HasFieldName(field_name)
  child_descriptor = field.message_type

  def getter(self):
    # Fast path: the child already exists, no locking needed.
    child = getattr(self, value_attr)
    if child is not None:
      return child
    # Slow path (double-checked locking): re-read under the lock so that
    # concurrent first reads do not each create their own child.
    # TODO(robinson): Appropriately scary note about double-checked locking.
    self._lock.acquire()
    try:
      child = getattr(self, value_attr)
      if child is None:
        child = child_descriptor._concrete_class()
        child._SetListener(_Listener(self, has_attr))
        setattr(self, value_attr, child)
    finally:
      self._lock.release()
    return child
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % field_name

  # Exists only to raise a more helpful error than the default one.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % field_name
  setattr(cls, _PropertyName(field_name), property(getter, setter, doc=doc))
456
457
def _AddStaticMethods(cls):
  """Adds the RegisterExtension static method to cls.

  RegisterExtension(extension_handle) points the handle's containing_type at
  cls.DESCRIPTOR and appends the handle to cls._known_extensions.
  """
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    extension_handle.containing_type = cls.DESCRIPTOR
    cls._known_extensions.append(extension_handle)
  setattr(cls, 'RegisterExtension', staticmethod(RegisterExtension))
464
465
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields() on cls.

  ListFields() returns a list of (FieldDescriptor, value) pairs for every set
  field, in ascending field-number order.  Set fields are non-empty repeated
  fields, singular fields whose "has" flag is true, and set extensions.
  """
  # Non-extension fields are sorted once, here, when the class is built.
  # Extensions are sorted on each ListFields() call, though we could do
  # better if we have to.
  # The per-field attribute names are also computed once and reused on every
  # call.  This is a list, not a one-shot iterator, because every
  # ListFields() call walks it again.
  fields = sorted(message_descriptor.fields, key=lambda f: f.number)
  triplets = [(_HasFieldName(f.name), _ValueFieldName(f.name), f)
              for f in fields]

  def ListFields(self):
    # We need to list all extension and non-extension fields
    # together, in sorted order by field number.

    # Step 0: Get an iterator over all "set" non-extension fields, sorted by
    # field number.  It yields (field_number, field_descriptor, value).
    def SortedSetFieldsIter():
      # Note that triplets is already sorted by field number.
      for has_field_name, value_field_name, field_descriptor in triplets:
        if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
          value = getattr(self, value_field_name)
          if len(value) > 0:
            yield (field_descriptor.number, field_descriptor, value)
        elif getattr(self, has_field_name):
          value = getattr(self, value_field_name)
          yield (field_descriptor.number, field_descriptor, value)
    sorted_fields = SortedSetFieldsIter()

    # Step 1: Get the "set" extension fields, sorted by field number, as
    # (field_number, field_descriptor, value) tuples too.
    # TODO(robinson): It's not necessary to repeat this with each
    # serialization call. We can do better.
    sorted_extension_fields = sorted(
        [(f.number, f, v) for f, v in self.Extensions._ListSetExtensions()])

    # Step 2: Merge the extension and non-extension streams, keeping them in
    # sorted order.
    all_set_fields = _ImergeSorted(sorted_fields, sorted_extension_fields)

    # Step 3: Strip off the field numbers and return.
    return [field[1:] for field in all_set_fields]

  cls.ListFields = ListFields
514
def _AddHasFieldMethod(cls):
  """Helper for _AddMessageMethods(): installs HasField() on cls.

  HasField(field_name) returns the stored "has" flag of a singular field.
  A name with no such attribute raises ValueError.  That includes unknown
  names and repeated fields, which have no "has" slot.
  """
  def HasField(self, field_name):
    has_attr = _HasFieldName(field_name)
    try:
      return getattr(self, has_attr)
    except AttributeError:
      raise ValueError('Protocol message has no "%s" field.' % field_name)
  cls.HasField = HasField
523
524
def _AddClearFieldMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearField() on cls.

  ClearField(field_name) stores a fresh _DefaultValueForField() value in the
  field.  A singular message field's existing child is first detached (its
  listener is reset).  The byte-size cache is marked dirty for repeated
  fields always, and for singular fields only if their "has" flag was set.

  Raises (from the installed method):
    ValueError: if the message type has no field named field_name.
  """
  def ClearField(self, field_name):
    try:
      field = self.DESCRIPTOR.fields_by_name[field_name]
    except KeyError:
      raise ValueError('Protocol message has no "%s" field.' % field_name)
    value_attr = _ValueFieldName(field.name)
    has_attr = _HasFieldName(field.name)
    fresh_value = _DefaultValueForField(self, field)

    if field.label == _FieldDescriptor.LABEL_REPEATED:
      self._MarkByteSizeDirty()
    else:
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        old_child = getattr(self, value_attr)
        if old_child is not None:
          # Snip the old object out of the object tree: it gets a null
          # listener instead of one pointing at this message.
          old_child._SetListener(None)
      if getattr(self, has_attr):
        setattr(self, has_attr, False)
        # Dirty ourself and our parents only if state actually changes.
        self._MarkByteSizeDirty()
    setattr(self, value_attr, fresh_value)
  cls.ClearField = ClearField
551
552
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls.

  The installed method delegates to the message's Extensions dictionary.
  """
  def ClearExtension(self, extension_handle):
    extensions = self.Extensions
    extensions._ClearExtension(extension_handle)
  cls.ClearExtension = ClearExtension
558
559
def _AddClearMethod(cls):
  """Helper for _AddMessageMethods(): installs Clear() on cls.

  Clear() calls ClearField() for every declared field, in descriptor order.
  It then calls ClearExtension() for every extension reported as set.
  """
  def Clear(self):
    # Regular fields first...
    for field_descriptor in self.DESCRIPTOR.fields:
      self.ClearField(field_descriptor.name)
    # ...then extensions.  Each entry of _ListSetExtensions() is a
    # (handle, value) pair; only the handle is needed.
    for set_extension in self.Extensions._ListSetExtensions():
      self.ClearExtension(set_extension[0])
  cls.Clear = Clear
572
573
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls.

  HasExtension(extension_handle) returns whatever the message's Extensions
  dictionary reports for that handle.
  """
  def HasExtension(self, extension_handle):
    extensions = self.Extensions
    return extensions._HasExtension(extension_handle)
  cls.HasExtension = HasExtension
579
580
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls.

  Two messages are equal when:
    * they are messages of the same type,
    * every singular field has the same "has" flag and, when set, the same
      value,
    * every repeated field compares equal, and
    * their extensions compare equal.
  Comparing against anything else returns False rather than raising.
  """
  def __eq__(self, other):
    if self is other:
      return True

    # Without this guard, comparing with None or an unrelated object raised
    # AttributeError (it has no HasField), and comparing with a message of a
    # different type could raise ValueError from HasField().
    if (not isinstance(other, message_mod.Message) or
        other.DESCRIPTOR != self.DESCRIPTOR):
      return False

    # Compare all fields contained directly in this message.
    for field_descriptor in message_descriptor.fields:
      label = field_descriptor.label
      property_name = _PropertyName(field_descriptor.name)
      # Non-repeated field equality requires matching "has" bits as well
      # as having an equal value.
      if label != _FieldDescriptor.LABEL_REPEATED:
        self_has = self.HasField(property_name)
        other_has = other.HasField(property_name)
        if self_has != other_has:
          return False
        if not self_has:
          # If the "has" bit for this field is False, we must stop here.
          # Otherwise we will recurse forever on recursively-defined protos.
          continue
      if getattr(self, property_name) != getattr(other, property_name):
        return False

    # Compare the extensions present in both messages.
    return self.Extensions == other.Extensions
  cls.__eq__ = __eq__
608
609
def _AddSetListenerMethod(cls):
  """Helper for _AddMessageMethods(): installs _SetListener() on cls.

  Passing None installs a fresh NullMessageListener instead of storing None,
  so self._listener is never None.
  """
  def SetListener(self, listener):
    if listener is not None:
      self._listener = listener
      return
    self._listener = message_listener_mod.NullMessageListener()
  cls._SetListener = SetListener
618
619
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    message_mod.EncodeError: if field_type has no registered size function.
  """
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  # Called outside the try so that a KeyError raised by the size function
  # itself is not misreported as an unrecognized field type.
  return fn(field_number, value)
638
639
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ByteSize() on cls.

  ByteSize() returns the serialized size of the message, tags included.  The
  result is cached in _cached_byte_size.  It is recomputed only while
  _cached_byte_size_dirty is set, and the flag is cleared afterwards.
  """

  def BytesForField(field, value):
    """Returns the number of bytes required to serialize a single field.

    The field may be repeated or not, composite or not.

    Args:
      field: A FieldDescriptor describing the field of interest.
      value: The field's current value (a container for repeated fields).

    Returns:
      The number of bytes required to serialize value, including space for
      tags and any other necessary information.
    """
    if _MessageSetField(field):
      return wire_format.MessageSetItemByteSize(field.number, value)

    field_number, field_type = field.number, field.type

    # Repeated fields.
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      elements = value
    else:
      elements = [value]

    return sum(_BytesForNonRepeatedElement(element, field_number, field_type)
               for element in elements)

  # (has_attr, value_attr, field) triples, computed once per class.  This
  # must be a list rather than a one-shot iterator: ByteSize() walks it on
  # every recomputation.
  field_attrs = [(_HasFieldName(f.name), _ValueFieldName(f.name), f)
                 for f in message_descriptor.fields]

  def ByteSize(self):
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    size = 0
    # Hardcoded fields first.  Repeated fields are always measured (an empty
    # container adds 0); singular fields only when their "has" flag is set.
    for has_field_name, value_field_name, field in field_attrs:
      if (field.label == _FieldDescriptor.LABEL_REPEATED
          or getattr(self, has_field_name)):
        size += BytesForField(field, getattr(self, value_field_name))
    # Extensions next.
    for field, value in self.Extensions._ListSetExtensions():
      size += BytesForField(field, value)

    self._cached_byte_size = size
    self._cached_byte_size_dirty = False
    return size
  cls.ByteSize = ByteSize
695
696
def _MessageSetField(field_descriptor):
  """Tells whether a field is serialized using the MessageSet wire format.

  This holds for singular, message-typed extensions whose containing type
  sets the message_set_wire_format option.

  Args:
    field_descriptor: Descriptor of the field.

  Returns:
    True if the field should be serialized using the message set wire format,
    false otherwise.
  """
  if not field_descriptor.is_extension:
    return False
  if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    return False
  if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
    return False
  options = field_descriptor.containing_type.GetOptions()
  return options.message_set_wire_format
711
712
def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder):
  """Appends the serialization of a single value to encoder.

  Args:
    value: Value to serialize.
    field_number: Field number of this value.
    field_descriptor: Descriptor of the field to serialize.
    encoder: encoder.Encoder object to which we should serialize this value.

  Raises:
    message_mod.EncodeError: if the field type has no serialize method.
  """
  if _MessageSetField(field_descriptor):
    encoder.AppendMessageSetItem(field_number, value)
    return

  try:
    method = type_checkers.TYPE_TO_SERIALIZE_METHOD[field_descriptor.type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' %
                                  field_descriptor.type)
  # Called outside the try: a KeyError raised while encoding must propagate
  # as-is instead of being reported as an unrecognized field type.
  method(encoder, field_number, value)
732
733
def _ImergeSorted(*streams):
  """Lazily merges several individually sorted iterables into one sorted stream.

  Each element in streams must be an iterable that yields its elements in
  sorted order, and the elements across all streams must be mutually
  comparable.  Repeated elements, within a stream or across streams, are all
  kept.  Equal elements are emitted in the order their streams were passed.

  (Newer Python versions provide heapq.merge for this.  It is hand-rolled
  here because this file targets Python 2.4.)
  """
  def _Pull(iterator):
    # Returns (True, element), or (False, None) once iterator is exhausted.
    for element in iterator:
      return True, element
    return False, None

  iterators = [iter(stream) for stream in streams]
  # Heap entries are (element, stream position); the position breaks ties.
  heap = []
  for position, iterator in enumerate(iterators):
    found, element = _Pull(iterator)
    if found:
      heap.append((element, position))
  heapq.heapify(heap)

  while heap:
    smallest, position = heap[0]
    yield smallest
    found, element = _Pull(iterators[position])
    if found:
      heapq.heapreplace(heap, (element, position))
    else:
      heapq.heappop(heap)
764
765
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializeToString() on cls.

  SerializeToString() first checks that the message has all of its required
  fields set.  If not, it raises message_mod.EncodeError with the reported
  problems joined by newlines.  Otherwise it returns
  SerializePartialToString().
  """
  def SerializeToString(self):
    problems = []
    if _InternalIsInitialized(self, problems):
      return self.SerializePartialToString()
    raise message_mod.EncodeError('\n'.join(problems))
  cls.SerializeToString = SerializeToString
776
777
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs SerializePartialToString().

  The installed method serializes every field and extension returned by
  ListFields() (already in field-number order) without checking required
  fields, and returns the encoder's output string.
  """
  Encoder = encoder.Encoder

  def SerializePartialToString(self):
    # Named 'out' so the module-level 'encoder' is not shadowed.
    out = Encoder()
    for field_descriptor, field_value in self.ListFields():
      if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
        elements = field_value
      else:
        elements = [field_value]
      for element in elements:
        _SerializeValueToEncoder(element, field_descriptor.number,
                                 field_descriptor, out)
    return out.ToString()
  cls.SerializePartialToString = SerializePartialToString
temporal40ee5512008-07-10 02:12:20 +0000796
797
def _WireTypeForFieldType(field_type):
  """Maps a FieldDescriptor.TYPE_* value to the wire type used to encode it.

  Raises:
    message_mod.DecodeError: if field_type has no entry in
      type_checkers.FIELD_TYPE_TO_WIRE_TYPE.
  """
  # Wire types are small integers, so None can only mean "no entry".
  wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE.get(field_type)
  if wire_type is None:
    raise message_mod.DecodeError('Unknown field type: %d' % field_type)
  return wire_type
804
805
def _RecursivelyMerge(field_number, field_type, decoder, message):
  """Reads a nested message or group from decoder and merges it into message.

  Groups are decoded with the group protocol, nested messages with the
  length-delimited nested-message protocol.

  Args:
    field_number: The field number of message in its enclosing protocol buffer.
    field_type: The field type of message. Must be either TYPE_MESSAGE
      or TYPE_GROUP.
    decoder: Decoder to read from.
    message: Message to deserialize into.

  Raises:
    message_mod.DecodeError: if field_type is neither TYPE_MESSAGE nor
      TYPE_GROUP.
  """
  if field_type == _FieldDescriptor.TYPE_GROUP:
    # The field number is passed along, presumably so the decoder can
    # recognise the END_GROUP tag that closes this group.
    decoder.ReadGroupInto(field_number, message)
  elif field_type == _FieldDescriptor.TYPE_MESSAGE:
    decoder.ReadMessageInto(message)
  else:
    raise message_mod.DecodeError('Unexpected field type: %d' % field_type)
826
827
def _DeserializeScalarFromDecoder(field_type, decoder):
  """Deserializes a scalar of the requested type from decoder.

  Args:
    field_type: A scalar (non-group, non-message) FieldDescriptor.TYPE_*
      constant.
    decoder: Decoder positioned at the start of the scalar's payload.

  Returns:
    The decoded value.

  Raises:
    message_mod.DecodeError: if field_type has no deserialization method.
  """
  try:
    method = type_checkers.TYPE_TO_DESERIALIZE_METHOD[field_type]
  except KeyError:
    raise message_mod.DecodeError('Unrecognized field type: %d' % field_type)
  # Decode outside the try: a KeyError raised while decoding is a genuine
  # error and must not be misreported as an unrecognized field type.
  return method(decoder)
837
838
def _SkipField(field_number, wire_type, decoder):
  """Skips over the payload of a field with the specified wire type.

  Args:
    field_number: Tag number of the field to skip (used to find the end of a
      group).
    wire_type: Wire type of the field to skip.
    decoder: Decoder used to deserialize the message. It must be positioned
      just after reading the tag and wire type of the field.

  Raises:
    message_mod.DecodeError: if wire_type is not a recognized wire type.
  """
  if wire_type == wire_format.WIRETYPE_VARINT:
    decoder.ReadUInt64()
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    decoder.ReadFixed32()
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    decoder.ReadFixed64()
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    length = decoder.ReadInt32()
    decoder.SkipBytes(length)
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    _SkipGroup(field_number, decoder)
  elif wire_type != wire_format.WIRETYPE_END_GROUP:
    # END_GROUP carries no payload, so only truly unknown types get here.
    raise message_mod.DecodeError('Unexpected wire type: %d' % wire_type)
862
863
def _SkipGroup(group_number, decoder):
  """Discards the rest of a group, including its matching END_GROUP tag.

  Fields inside the group (including nested groups) are skipped one by one
  until an END_GROUP tag carrying group_number is read.

  Args:
    group_number: Tag number of the group to skip.
    decoder: Decoder used to deserialize the message. It must be positioned
      just after the group's START_GROUP tag.
  """
  while True:
    field_number, wire_type = decoder.ReadFieldNumberAndWireType()
    closes_this_group = (wire_type == wire_format.WIRETYPE_END_GROUP and
                         field_number == group_number)
    if closes_this_group:
      break
    _SkipField(field_number, wire_type, decoder)
878
879
def _DeserializeMessageSetItem(message, decoder):
  """Deserializes one item of a message using the message set wire format.

  An item is a group (field number 1) holding, in order:
    - a varint type_id (field 2),
    - a length-delimited message payload (field 3),
    - the group's END_GROUP tag.
  If type_id is the number of an extension known to message, the payload is
  merged into that extension; otherwise the payload is skipped.  In both
  cases the closing END_GROUP tag is consumed and validated, so the caller
  resumes at the next top-level tag.

  Args:
    message: Message to be parsed to.
    decoder: The decoder to be used to deserialize encoded data. Note that the
      decoder should be positioned just after reading the START_GROUP tag that
      began the messageset item.

  Raises:
    message_mod.DecodeError: if the item does not have the layout above.
  """
  field_number, wire_type = decoder.ReadFieldNumberAndWireType()
  if wire_type != wire_format.WIRETYPE_VARINT or field_number != 2:
    raise message_mod.DecodeError(
        'Incorrect message set wire format. '
        'wire_type: %d, field_number: %d' % (wire_type, field_number))

  type_id = decoder.ReadInt32()
  field_number, wire_type = decoder.ReadFieldNumberAndWireType()
  if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED or field_number != 3:
    raise message_mod.DecodeError(
        'Incorrect message set wire format. '
        'wire_type: %d, field_number: %d' % (wire_type, field_number))

  extension_dict = message.Extensions
  extensions_by_number = extension_dict._AllExtensionsByNumber()
  if type_id in extensions_by_number:
    field_descriptor = extensions_by_number[type_id]
    value = extension_dict[field_descriptor]
    decoder.ReadMessageInto(value)
  else:
    # Unknown type_id: discard the payload.  Do NOT return early here.  The
    # item's END_GROUP tag would be left unread, and the caller's parse loop
    # would typically treat it as the end of the message and stop parsing
    # everything after this item.
    _SkipField(field_number, wire_type, decoder)

  # Read the END_GROUP tag.
  field_number, wire_type = decoder.ReadFieldNumberAndWireType()
  if wire_type != wire_format.WIRETYPE_END_GROUP or field_number != 1:
    raise message_mod.DecodeError(
        'Incorrect message set wire format. '
        'wire_type: %d, field_number: %d' % (wire_type, field_number))
917
918
def _DeserializeOneEntity(message_descriptor, message, decoder):
  """Deserializes the next wire entity from decoder into message.

  The next wire entity is either a scalar or a nested message, and may also
  be an element in a repeated field (the wire encoding is the same).  Tags
  for unknown fields are skipped, and message-set items are delegated to
  _DeserializeMessageSetItem().

  Args:
    message_descriptor: A Descriptor instance describing all fields
      in message.
    message: The Message instance into which we're decoding our fields.
    decoder: The Decoder we're using to deserialize encoded data.

  Returns:
    The number of bytes read from decoder during this call.  0 is returned
    when an END_GROUP tag not belonging to a known field is read (we assume
    the group being parsed has ended).

  Raises:
    RuntimeError: if the wire type read does not match the field's declared
      type (not handled yet).
  """
  initial_position = decoder.Position()
  field_number, wire_type = decoder.ReadFieldNumberAndWireType()

  # Identify the field and set up:
  #   field_descriptor: describes the field being deserialized.
  #   value: the value currently stored in the field (used only if the field
  #     is composite and/or repeated).
  #   scalar_setter_fn: F such that F(scalar) sets a non-repeated scalar value
  #     for this field.
  if field_number in message_descriptor.fields_by_number:
    # Non-extension field.
    field_descriptor = message_descriptor.fields_by_number[field_number]
    property_name = _PropertyName(field_descriptor.name)
    value = getattr(message, property_name)
    def scalar_setter_fn(scalar):
      setattr(message, property_name, scalar)
  else:
    # Only build the number -> extension table when it is needed:
    # _AllExtensionsByNumber() constructs a new dict on every call.
    extension_dict = message.Extensions
    extensions_by_number = extension_dict._AllExtensionsByNumber()
    if field_number in extensions_by_number:
      # Extension field.
      field_descriptor = extensions_by_number[field_number]
      value = extension_dict[field_descriptor]
      def scalar_setter_fn(scalar):
        extension_dict[field_descriptor] = scalar
    elif wire_type == wire_format.WIRETYPE_END_GROUP:
      # We assume we're being parsed as the group that's ended.
      return 0
    elif (wire_type == wire_format.WIRETYPE_START_GROUP and
          field_number == 1 and
          message_descriptor.GetOptions().message_set_wire_format):
      # A Message Set item.
      _DeserializeMessageSetItem(message, decoder)
      return decoder.Position() - initial_position
    else:
      # Unknown field: skip its payload.
      _SkipField(field_number, wire_type, decoder)
      return decoder.Position() - initial_position

  field_type = field_descriptor.type
  if wire_type != _WireTypeForFieldType(field_type):
    # Need to fill in uninterpreted_bytes. Work for the next CL.
    raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.')

  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  is_composite = field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE

  if not is_repeated and not is_composite:
    # Nonrepeated scalar: set the field directly.
    scalar_setter_fn(_DeserializeScalarFromDecoder(field_type, decoder))
  elif not is_repeated:
    # Nonrepeated composite: merge into the existing submessage.
    _RecursivelyMerge(field_number, field_type, decoder, value)
  elif not is_composite:
    # Repeated scalar: value is the element list.
    value.append(_DeserializeScalarFromDecoder(field_type, decoder))
  else:
    # Repeated composite: merge into a newly added element.
    _RecursivelyMerge(field_number, field_type, decoder, value.add())
  return decoder.Position() - initial_position
1010
1011
def _FieldOrExtensionValues(message, field_or_extension):
  """Returns the value(s) stored in a field or extension, as a sequence.

  The field or extension may be optional, required or repeated, but it is
  expected to be set (i.e. _HasFieldOrExtension() holds).

  Args:
    message: Message which contains the target field or extension.
    field_or_extension: FieldDescriptor of the field or extension to read.

  Returns:
    For a repeated field, the repeated value container itself; otherwise a
    one-element list holding the current value.
  """
  if field_or_extension.is_extension:
    current = message.Extensions[field_or_extension]
  else:
    current = getattr(message, _ValueFieldName(field_or_extension.name))
  if field_or_extension.label == _FieldDescriptor.LABEL_REPEATED:
    # Already a sequence of values.
    return current
  return [current]
1037
1038
def _HasFieldOrExtension(message, field_or_extension):
  """Reports whether a field or extension counts as set on message.

  Repeated fields and extensions have no "has" bit and are always reported as
  set.  For others, the message's HasExtension()/HasField() answer is used.

  Args:
    message: Message which contains the target field or extension.
    field_or_extension: FieldDescriptor of the field or extension to check.

  Returns:
    True if the value is set, or if the field/extension is repeated.
  """
  field = field_or_extension
  is_repeated = field.label == _FieldDescriptor.LABEL_REPEATED
  if is_repeated:
    return True
  elif field.is_extension:
    return message.HasExtension(field)
  else:
    return message.HasField(field.name)
1061
1062
def _IsFieldOrExtensionInitialized(message, field, errors=None):
  """Checks if a message field or extension is initialized.

  A field is initialized unless it is a required field that is not set, or it
  holds a submessage (or repeated submessages) that is not itself initialized.

  Args:
    message: The message which contains the field or extension.
    field: Field or extension to check. This must be a FieldDescriptor instance.
    errors: If not None, a list; a description of each unset required field
      encountered is appended to it.

  Returns:
    True if the field/extension can be considered initialized.
  """
  # If the field is required and is not set, it isn't initialized.
  if field.label == _FieldDescriptor.LABEL_REQUIRED:
    if not _HasFieldOrExtension(message, field):
      if errors is not None:
        errors.append('Required field %s is not set.' % field.full_name)
      return False

  # An optional field that is not set is trivially initialized.
  if field.label == _FieldDescriptor.LABEL_OPTIONAL:
    if not _HasFieldOrExtension(message, field):
      return True

  # A set field that isn't a submessage has nothing further to check.
  if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
    return True

  # The field is set and is either a single or a repeated submessage; it is
  # initialized iff every submessage is.  The loop variable is deliberately
  # not named `message`, so the parameter is never rebound.
  for submessage in _FieldOrExtensionValues(message, field):
    if not _InternalIsInitialized(submessage, errors):
      return False
  return True
1097
1098
def _InternalIsInitialized(message, errors=None):
  """Reports whether every required field of message (recursively) is set.

  All fields declared by the message's descriptor are checked, plus every
  extension currently set on it.  Checking stops at the first field that is
  not initialized.

  Args:
    message: The message to check.
    errors: If not None, descriptions of unset required fields found are
      appended to this list.

  Returns:
    True iff the specified message has all required fields set.
  """
  candidates = list(message.DESCRIPTOR.fields)
  # _ListSetExtensions() yields (descriptor, value) pairs; only the
  # descriptor is needed here.
  for extension_descriptor, unused_value in (
      message.Extensions._ListSetExtensions()):
    candidates.append(extension_descriptor)
  for candidate in candidates:
    if not _IsFieldOrExtensionInitialized(message, candidate, errors):
      return False
  return True
1117
1118
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Installs MergeFromString() on cls.  Helper for _AddMessageMethods().

  The installed method parses its argument entity by entity into self and
  returns the number of bytes it consumed.
  """
  Decoder = decoder.Decoder

  def MergeFromString(self, serialized):
    stream = Decoder(serialized)
    total = 0
    while not stream.EndOfStream():
      consumed = _DeserializeOneEntity(message_descriptor, self, stream)
      if not consumed:
        # _DeserializeOneEntity reports 0 for a bare END_GROUP tag: stop.
        break
      total += consumed
    return total

  cls.MergeFromString = MergeFromString
1132
1133
def _AddIsInitializedMethod(cls):
  """Adds the IsInitialized method to the protocol message class.

  _InternalIsInitialized(message, errors=None) is installed directly: once
  bound, `message` plays the role of `self`.
  """
  setattr(cls, 'IsInitialized', _InternalIsInitialized)
1137
1138
def _MergeFieldOrExtension(destination_msg, field, value):
  """Merges value, the contents of `field` in another message, into
  destination_msg.

  Submessages are merged, repeated values are appended, and singular scalars
  are overwritten.
  """
  property_name = _PropertyName(field.name)
  repeated = field.label == _FieldDescriptor.LABEL_REPEATED
  composite = field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE

  # Fetch the live container/submessage we will merge into.  Singular scalar
  # extensions have none; they are assigned directly below.
  if not field.is_extension:
    destination = getattr(destination_msg, property_name)
  elif repeated or composite:
    destination = destination_msg.Extensions[field]

  if composite and repeated:
    for item in value:
      destination.add().MergeFrom(item)
  elif composite:
    destination.MergeFrom(value)
  elif repeated:
    for item in value:
      destination.append(item)
  elif field.is_extension:
    destination_msg.Extensions[field] = value
  else:
    setattr(destination_msg, property_name, value)
1170
1171
def _AddMergeFromMethod(cls):
  """Installs MergeFrom(), which merges every set field of msg into self."""

  def MergeFrom(self, msg):
    # Merging a message into itself is not supported.
    assert msg is not self
    for field_descriptor, field_value in msg.ListFields():
      _MergeFieldOrExtension(self, field_descriptor, field_value)

  cls.MergeFrom = MergeFrom
temporal40ee5512008-07-10 02:12:20 +00001178
1179
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  Some installers need the message descriptor, others only the class; each
  entry below records which.  They run in the order listed.
  """
  installers = (
      (_AddListFieldsMethod, True),
      (_AddHasFieldMethod, False),
      (_AddClearFieldMethod, False),
      (_AddClearExtensionMethod, False),
      (_AddClearMethod, False),
      (_AddHasExtensionMethod, False),
      (_AddEqualsMethod, True),
      (_AddSetListenerMethod, False),
      (_AddByteSizeMethod, True),
      (_AddSerializeToStringMethod, True),
      (_AddSerializePartialToStringMethod, True),
      (_AddMergeFromStringMethod, True),
      (_AddIsInitializedMethod, False),
      (_AddMergeFromMethod, False),
  )
  for install, needs_descriptor in installers:
    if needs_descriptor:
      install(message_descriptor, cls)
    else:
      install(cls)
temporal40ee5512008-07-10 02:12:20 +00001196
1197
def _AddPrivateHelperMethods(cls):
  """Adds implementation of private helper methods to cls."""

  def MaybeCallTransitionToNonemptyCallback(self):
    """Notifies self._listener via TransitionToNonempty() on the first call
    only; later calls do nothing.
    """
    if self._called_transition_to_nonempty:
      return
    self._listener.TransitionToNonempty()
    self._called_transition_to_nonempty = True

  def MarkByteSizeDirty(self):
    """Marks the cached byte size dirty and, only if it was clean before,
    tells self._listener via ByteSizeDirty().
    """
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener.ByteSizeDirty()

  cls._MaybeCallTransitionToNonemptyCallback = (
      MaybeCallTransitionToNonemptyCallback)
  cls._MarkByteSizeDirty = MarkByteSizeDirty
1219
1220
class _Listener(object):

  """MessageListener that a child message uses to notify its parent.

  It exists to support semantics like:

    foo.bar.baz = 23
    assert foo.HasField('bar')

  Setting a field deep inside a child must be able to mark the parent as
  having that child set, so each child holds one of these, which carries a
  back reference to the parent.
  """

  def __init__(self, parent_message, has_field_name):
    """Args:
      parent_message: The message whose _MaybeCallTransitionToNonemptyCallback()
        and _MarkByteSizeDirty() methods are called when this listener
        receives TransitionToNonempty() and ByteSizeDirty() respectively.
        May already be a weakref.proxy.
      has_field_name: Name of the attribute on the parent that is set to True
        on TransitionToNonempty(), or None if there is no such attribute
        (as for child objects in "repeated" fields).
    """
    # Hold the parent only weakly.  The parent references the child, so a
    # strong back reference would create cyclic garbage once the client
    # drops the parent.
    if not isinstance(parent_message, weakref.ProxyType):
      parent_message = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_message
    self._has_field_name = has_field_name

  def TransitionToNonempty(self):
    parent = self._parent_message_weakref
    try:
      if self._has_field_name is not None:
        setattr(parent, self._has_field_name, True)
      # Propagate the signal to our parents iff this is the first field set.
      parent._MaybeCallTransitionToNonemptyCallback()
    except ReferenceError:
      # A client kept a reference to a child object and is now setting a
      # field on it, but the child's parent has been garbage-collected.
      # This is not an error.
      pass

  def ByteSizeDirty(self):
    try:
      self._parent_message_weakref._MarkByteSizeDirty()
    except ReferenceError:
      # Parent already collected; nothing to notify (see above).
      pass
1273
1274
temporal40ee5512008-07-10 02:12:20 +00001275# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
1276# TODO(robinson): Unify error handling of "unknown extension" crap.
1277# TODO(robinson): There's so much similarity between the way that
1278# extensions behave and the way that normal fields behave that it would
1279# be really nice to unify more code. It's not immediately obvious
1280# how to do this, though, and I'd rather get the full functionality
1281# implemented (and, crucially, get all the tests and specs fleshed out
1282# and passing), and then come back to this thorny unification problem.
1283# TODO(robinson): Support iteritems()-style iteration over all
1284# extensions with the "has" bits turned on?
1285class _ExtensionDict(object):
1286
1287 """Dict-like container for supporting an indexable "Extensions"
1288 field on proto instances.
1289
1290 Note that in all cases we expect extension handles to be
1291 FieldDescriptors.
1292 """
1293
1294 class _ExtensionListener(object):
1295
1296 """Adapts an _ExtensionDict to behave as a MessageListener."""
1297
1298 def __init__(self, extension_dict, handle_id):
1299 self._extension_dict = extension_dict
1300 self._handle_id = handle_id
1301
1302 def TransitionToNonempty(self):
1303 self._extension_dict._SubmessageTransitionedToNonempty(self._handle_id)
1304
1305 def ByteSizeDirty(self):
1306 self._extension_dict._SubmessageByteSizeBecameDirty()
1307
1308 # TODO(robinson): Somewhere, we need to blow up if people
1309 # try to register two extensions with the same field number.
1310 # (And we need a test for this of course).
1311
1312 def __init__(self, extended_message, known_extensions):
1313 """extended_message: Message instance for which we are the Extensions dict.
1314 known_extensions: Iterable of known extension handles.
1315 These must be FieldDescriptors.
1316 """
1317 # We keep a weak reference to extended_message, since
1318 # it has a reference to this instance in turn.
1319 self._extended_message = weakref.proxy(extended_message)
1320 # We make a deep copy of known_extensions to avoid any
1321 # thread-safety concerns, since the argument passed in
1322 # is the global (class-level) dict of known extensions for
1323 # this type of message, which could be modified at any time
1324 # via a RegisterExtension() call.
1325 #
1326 # This dict maps from handle id to handle (a FieldDescriptor).
1327 #
1328 # XXX
1329 # TODO(robinson): This isn't good enough. The client could
1330 # instantiate an object in module A, then afterward import
1331 # module B and pass the instance to B.Foo(). If B imports
1332 # an extender of this proto and then tries to use it, B
1333 # will get a KeyError, even though the extension *is* registered
1334 # at the time of use.
1335 # XXX
1336 self._known_extensions = dict((id(e), e) for e in known_extensions)
1337 # Read lock around self._values, which may be modified by multiple
1338 # concurrent readers in the conceptually "const" __getitem__ method.
1339 # So, we grab this lock in every "read-only" method to ensure
1340 # that concurrent read access is safe without external locking.
1341 self._lock = threading.Lock()
1342 # Maps from extension handle ID to current value of that extension.
1343 self._values = {}
1344 # Maps from extension handle ID to a boolean "has" bit, but only
1345 # for non-repeated extension fields.
1346 keys = (id for id, extension in self._known_extensions.iteritems()
1347 if extension.label != _FieldDescriptor.LABEL_REPEATED)
1348 self._has_bits = dict.fromkeys(keys, False)
1349
1350 def __getitem__(self, extension_handle):
1351 """Returns the current value of the given extension handle."""
1352 # We don't care as much about keeping critical sections short in the
1353 # extension support, since it's presumably much less of a common case.
1354 self._lock.acquire()
1355 try:
1356 handle_id = id(extension_handle)
1357 if handle_id not in self._known_extensions:
1358 raise KeyError('Extension not known to this class')
1359 if handle_id not in self._values:
1360 self._AddMissingHandle(extension_handle, handle_id)
1361 return self._values[handle_id]
1362 finally:
1363 self._lock.release()
1364
1365 def __eq__(self, other):
1366 # We have to grab read locks since we're accessing _values
1367 # in a "const" method. See the comment in the constructor.
1368 if self is other:
1369 return True
1370 self._lock.acquire()
1371 try:
1372 other._lock.acquire()
1373 try:
1374 if self._has_bits != other._has_bits:
1375 return False
1376 # If there's a "has" bit, then only compare values where it is true.
1377 for k, v in self._values.iteritems():
1378 if self._has_bits.get(k, False) and v != other._values[k]:
1379 return False
1380 return True
1381 finally:
1382 other._lock.release()
1383 finally:
1384 self._lock.release()
1385
1386 def __ne__(self, other):
1387 return not self == other
1388
1389 # Note that this is only meaningful for non-repeated, scalar extension
1390 # fields. Note also that we may have to call
1391 # MaybeCallTransitionToNonemptyCallback() when we do successfully set a field
1392 # this way, to set any necssary "has" bits in the ancestors of the extended
1393 # message.
1394 def __setitem__(self, extension_handle, value):
1395 """If extension_handle specifies a non-repeated, scalar extension
1396 field, sets the value of that field.
1397 """
1398 handle_id = id(extension_handle)
1399 if handle_id not in self._known_extensions:
1400 raise KeyError('Extension not known to this class')
1401 field = extension_handle # Just shorten the name.
1402 if (field.label == _FieldDescriptor.LABEL_OPTIONAL
1403 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
1404 # It's slightly wasteful to lookup the type checker each time,
1405 # but we expect this to be a vanishingly uncommon case anyway.
kenton@google.com24bf56f2008-09-24 20:31:01 +00001406 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
temporal40ee5512008-07-10 02:12:20 +00001407 type_checker.CheckValue(value)
1408 self._values[handle_id] = value
1409 self._has_bits[handle_id] = True
1410 self._extended_message._MarkByteSizeDirty()
1411 self._extended_message._MaybeCallTransitionToNonemptyCallback()
1412 else:
1413 raise TypeError('Extension is repeated and/or a composite type.')
1414
1415 def _AddMissingHandle(self, extension_handle, handle_id):
1416 """Helper internal to ExtensionDict."""
1417 # Special handling for non-repeated message extensions, which (like
1418 # normal fields of this kind) are initialized lazily.
1419 # REQUIRES: _lock already held.
1420 cpp_type = extension_handle.cpp_type
1421 label = extension_handle.label
1422 if (cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
1423 and label != _FieldDescriptor.LABEL_REPEATED):
1424 self._AddMissingNonRepeatedCompositeHandle(extension_handle, handle_id)
1425 else:
1426 self._values[handle_id] = _DefaultValueForField(
1427 self._extended_message, extension_handle)
1428
1429 def _AddMissingNonRepeatedCompositeHandle(self, extension_handle, handle_id):
1430 """Helper internal to ExtensionDict."""
1431 # REQUIRES: _lock already held.
1432 value = extension_handle.message_type._concrete_class()
1433 value._SetListener(_ExtensionDict._ExtensionListener(self, handle_id))
1434 self._values[handle_id] = value
1435
1436 def _SubmessageTransitionedToNonempty(self, handle_id):
1437 """Called when a submessage with a given handle id first transitions to
1438 being nonempty. Called by _ExtensionListener.
1439 """
1440 assert handle_id in self._has_bits
1441 self._has_bits[handle_id] = True
1442 self._extended_message._MaybeCallTransitionToNonemptyCallback()
1443
1444 def _SubmessageByteSizeBecameDirty(self):
1445 """Called whenever a submessage's cached byte size becomes invalid
1446 (goes from being "clean" to being "dirty"). Called by _ExtensionListener.
1447 """
1448 self._extended_message._MarkByteSizeDirty()
1449
1450 # We may wish to widen the public interface of Message.Extensions
1451 # to expose some of this private functionality in the future.
1452 # For now, we make all this functionality module-private and just
1453 # implement what we need for serialization/deserialization,
1454 # HasField()/ClearField(), etc.
1455
1456 def _HasExtension(self, extension_handle):
1457 """Method for internal use by this module.
1458 Returns true iff we "have" this extension in the sense of the
1459 "has" bit being set.
1460 """
1461 handle_id = id(extension_handle)
1462 # Note that this is different from the other checks.
1463 if handle_id not in self._has_bits:
1464 raise KeyError('Extension not known to this class, or is repeated field.')
1465 return self._has_bits[handle_id]
1466
1467 # Intentionally pretty similar to ClearField() above.
1468 def _ClearExtension(self, extension_handle):
1469 """Method for internal use by this module.
1470 Clears the specified extension, unsetting its "has" bit.
1471 """
1472 handle_id = id(extension_handle)
1473 if handle_id not in self._known_extensions:
1474 raise KeyError('Extension not known to this class')
1475 default_value = _DefaultValueForField(self._extended_message,
1476 extension_handle)
1477 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1478 self._extended_message._MarkByteSizeDirty()
1479 else:
1480 cpp_type = extension_handle.cpp_type
1481 if cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1482 if handle_id in self._values:
1483 # Future modifications to this object shouldn't set any
1484 # "has" bits here.
1485 self._values[handle_id]._SetListener(None)
1486 if self._has_bits[handle_id]:
1487 self._has_bits[handle_id] = False
1488 self._extended_message._MarkByteSizeDirty()
1489 if handle_id in self._values:
1490 del self._values[handle_id]
1491
1492 def _ListSetExtensions(self):
1493 """Method for internal use by this module.
1494
1495 Returns an sequence of all extensions that are currently "set"
1496 in this extension dict. A "set" extension is a repeated extension,
1497 or a non-repeated extension with its "has" bit set.
1498
1499 The returned sequence contains (field_descriptor, value) pairs,
1500 where value is the current value of the extension with the given
1501 field descriptor.
1502
1503 The sequence values are in arbitrary order.
1504 """
1505 self._lock.acquire() # Read-only methods must lock around self._values.
1506 try:
1507 set_extensions = []
1508 for handle_id, value in self._values.iteritems():
1509 handle = self._known_extensions[handle_id]
1510 if (handle.label == _FieldDescriptor.LABEL_REPEATED
1511 or self._has_bits[handle_id]):
1512 set_extensions.append((handle, value))
1513 return set_extensions
1514 finally:
1515 self._lock.release()
1516
1517 def _AllExtensionsByNumber(self):
1518 """Method for internal use by this module.
1519
1520 Returns: A dict mapping field_number to (handle, field_descriptor),
1521 for *all* registered extensions for this dict.
1522 """
1523 # TODO(robinson): Precompute and store this away. Note that we'll have to
1524 # be careful when we move away from having _known_extensions as a
1525 # deep-copied member of this object.
1526 return dict((f.number, f) for f in self._known_extensions.itervalues())