blob: 3f426502f8f314ac0dd740ce4d7e66ee8149bb5a [file] [log] [blame]
liujisi@google.com33165fe2010-11-02 13:14:58 +00001# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# http://code.google.com/p/protobuf/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
"""Helpers that build protocol message classes at runtime from Descriptor
objects, backed by the protocol buffer C++ API.

Message data is held in C++ message objects created through the
_net_proto2___python extension module; the classes assembled here are Python
wrappers that delegate to those objects.
"""

__author__ = 'petar@google.com (Petar Petrov)'
36
37import operator
38from google.protobuf.internal import _net_proto2___python
39from google.protobuf import message
40
41
# Label / type enum values mirrored from the C++ extension module so the rest
# of this file can compare descriptor attributes against them.
(_LABEL_REPEATED,
 _LABEL_OPTIONAL,
 _CPPTYPE_MESSAGE,
 _TYPE_MESSAGE) = (_net_proto2___python.LABEL_REPEATED,
                   _net_proto2___python.LABEL_OPTIONAL,
                   _net_proto2___python.CPPTYPE_MESSAGE,
                   _net_proto2___python.TYPE_MESSAGE)
46
47
def GetDescriptorPool():
  """Returns a freshly created C++ DescriptorPool object.

  Every call asks the C++ extension module for a new pool.
  """
  new_pool = _net_proto2___python.NewCDescriptorPool()
  return new_pool


# Module-wide pool consulted by GetFieldDescriptor and GetExtensionDescriptor.
_pool = GetDescriptorPool()
54
55
def GetFieldDescriptor(full_field_name):
  """Looks up a C++ field descriptor in the module-wide pool.

  Args:
    full_field_name: Fully-qualified field name, as found in a Python field
      descriptor's full_name.

  Returns:
    Whatever the pool's FindFieldByName returns for that name.
  """
  found = _pool.FindFieldByName(full_field_name)
  return found
59
60
def BuildFile(content):
  """Registers a new proto file with the C++ descriptor machinery.

  Args:
    content: File description handed unchanged to the extension module's
      BuildFile (presumably a serialized FileDescriptorProto -- confirm).

  NOTE(review): this calls the extension module directly rather than going
  through _pool; confirm files built here are visible to _pool lookups.
  """
  build = _net_proto2___python.BuildFile
  build(content)
64
65
def GetExtensionDescriptor(full_extension_name):
  """Looks up a C++ extension field descriptor in the module-wide pool.

  Args:
    full_extension_name: Fully-qualified name of the extension field.

  Returns:
    Whatever the pool's FindExtensionByName returns for that name.
  """
  found = _pool.FindExtensionByName(full_extension_name)
  return found
69
70
def NewCMessage(full_message_name):
  """Creates an empty C++ protocol message of the named type.

  Args:
    full_message_name: Fully-qualified message type name.
  """
  cmessage = _net_proto2___python.NewCMessage(full_message_name)
  return cmessage
74
75
def ScalarProperty(cdescriptor):
  """Returns a Python property for the given singular scalar field.

  Reads and writes go straight to the C++ message through GetScalar and
  SetScalar; nothing is cached on the Python side.

  Args:
    cdescriptor: C++ field descriptor of the field.

  Returns:
    A property object to install on the message class.  Like the properties
    built for repeated fields, it carries a doc string naming the field.
  """

  def Getter(self):
    return self._cmsg.GetScalar(cdescriptor)

  def Setter(self, value):
    self._cmsg.SetScalar(cdescriptor, value)

  doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name
  return property(Getter, Setter, doc=doc)
86
87
def CompositeProperty(cdescriptor, message_type):
  """Returns a Python property for the given singular message field.

  The sub-message wrapper is created on first read and cached in the parent's
  _composite_fields.  The wrapper records the parent as its owner, because the
  parent holds the underlying C++ message; without that reference a wrapper
  obtained as `Outer().inner` would outlive the storage it points into.
  Assigning to the attribute raises AttributeError, as for repeated fields.

  Args:
    cdescriptor: C++ field descriptor of the field; its name keys the cache.
    message_type: Descriptor of the field's message type; its _concrete_class
      is instantiated (lazily, on first read) to wrap the sub-message.

  Returns:
    A property object to install on the message class.
  """

  def Getter(self):
    sub_message = self._composite_fields.get(cdescriptor.name, None)
    if sub_message is None:
      cmessage = self._cmsg.NewSubMessage(cdescriptor)
      # Same ownership pattern as RepeatedCompositeContainer.add.
      sub_message = message_type._concrete_class(
          __cmessage=cmessage, __owner=self)
      self._composite_fields[cdescriptor.name] = sub_message
    return sub_message

  def Setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % cdescriptor.name)

  doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name
  return property(Getter, Setter, doc=doc)
100
101
class RepeatedScalarContainer(object):
  """Container for repeated scalar fields.

  List-like view over one repeated scalar field of a C++ message.  append,
  extend and __delitem__ call into the C++ message directly; insert, remove,
  __setitem__ and sort copy the field into a Python list, change the copy and
  write it back with AssignRepeatedScalar.
  """

  __slots__ = ['_message', '_cfield_descriptor', '_cmsg']

  def __init__(self, msg, cfield_descriptor):
    # NOTE(review): _message is never read in this class; presumably it is
    # kept so the owning message outlives this container -- confirm.
    self._message = msg
    self._cmsg = msg._cmsg
    self._cfield_descriptor = cfield_descriptor

  def append(self, value):
    """Appends value to the end of the field."""
    self._cmsg.AddRepeatedScalar(
        self._cfield_descriptor, value)

  def extend(self, sequence):
    """Appends every element of sequence, one at a time."""
    for element in sequence:
      self.append(element)

  def insert(self, key, value):
    """Inserts value before index key, with list.insert semantics."""
    values = self[slice(None, None, None)]
    values.insert(key, value)
    self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)

  def remove(self, value):
    """Removes the first occurrence of value, with list.remove semantics."""
    values = self[slice(None, None, None)]
    values.remove(value)
    self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)

  def __setitem__(self, key, value):
    values = self[slice(None, None, None)]
    values[key] = value
    self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)

  def __getitem__(self, key):
    return self._cmsg.GetRepeatedScalar(self._cfield_descriptor, key)

  def __delitem__(self, key):
    self._cmsg.DeleteRepeatedField(self._cfield_descriptor, key)

  def __len__(self):
    return len(self[slice(None, None, None)])

  def __eq__(self, other):
    if self is other:
      return True
    if not operator.isSequenceType(other):
      raise TypeError(
          'Can only compare repeated scalar fields against sequences.')
    # We are presumably comparing against some other sequence type.
    return other == self[slice(None, None, None)]

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    raise TypeError('unhashable object')

  def sort(self, *args, **kwargs):
    """Sorts the field in place.

    Accepts the same arguments as list.sort (for example key= and reverse=),
    so the container behaves like a list.  A comparison function may still be
    passed positionally, or by the keyword sort_function that this method
    previously took; that keyword is forwarded to list.sort as cmp.
    """
    if 'sort_function' in kwargs:
      kwargs['cmp'] = kwargs.pop('sort_function')
    values = self[slice(None, None, None)]
    values.sort(*args, **kwargs)
    self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)
163
164
def RepeatedScalarProperty(cdescriptor):
  """Builds the class-level property for one repeated scalar field.

  Reading the attribute yields a RepeatedScalarContainer, built on first use
  and remembered in the instance's _composite_fields.  Writing to the
  attribute is refused with AttributeError.
  """

  def Getter(self):
    cached = self._composite_fields.get(cdescriptor.name)
    if cached is not None:
      return cached
    fresh = RepeatedScalarContainer(self, cdescriptor)
    self._composite_fields[cdescriptor.name] = fresh
    return fresh

  def Setter(self, new_value):
    message_text = ('Assignment not allowed to repeated field '
                    '"%s" in protocol message object.' % cdescriptor.name)
    raise AttributeError(message_text)

  return property(
      fget=Getter,
      fset=Setter,
      doc='Magic attribute generated for "%s" proto field.' % cdescriptor.name)
181
182
class RepeatedCompositeContainer(object):
  """Container for repeated composite (message) fields.

  Each element is exposed as an instance of the field's message class that
  wraps the C++ sub-message and records the owning message as its owner.
  """

  __slots__ = ['_message', '_subclass', '_cfield_descriptor', '_cmsg']

  def __init__(self, msg, cfield_descriptor, subclass):
    self._message = msg
    self._cmsg = msg._cmsg
    self._subclass = subclass
    self._cfield_descriptor = cfield_descriptor

  def add(self, **kwargs):
    """Appends a new element, initialised from kwargs, and returns it."""
    cmessage = self._cmsg.AddMessage(self._cfield_descriptor)
    return self._subclass(__cmessage=cmessage, __owner=self._message, **kwargs)

  def extend(self, elem_seq):
    """Extends by appending the given sequence of elements of the same type
    as this one, copying each individual message.
    """
    for elem in elem_seq:
      self.add().MergeFrom(elem)

  def MergeFrom(self, other):
    """Appends a copy of every element of another repeated composite field."""
    for elem in other[:]:
      self.add().MergeFrom(elem)

  def __getitem__(self, key):
    cmessages = self._cmsg.GetRepeatedMessage(
        self._cfield_descriptor, key)
    subclass = self._subclass
    if not isinstance(cmessages, list):
      return subclass(__cmessage=cmessages, __owner=self._message)

    return [subclass(__cmessage=m, __owner=self._message) for m in cmessages]

  def __delitem__(self, key):
    self._cmsg.DeleteRepeatedField(
        self._cfield_descriptor, key)

  def __len__(self):
    return self._cmsg.FieldLength(self._cfield_descriptor)

  def __eq__(self, other):
    """Compares the current instance with another one."""
    if self is other:
      return True
    if not isinstance(other, self.__class__):
      raise TypeError('Can only compare repeated composite fields against '
                      'other repeated composite fields.')
    messages = self[slice(None, None, None)]
    other_messages = other[slice(None, None, None)]
    return messages == other_messages

  def __ne__(self, other):
    # Python 2 does not derive != from __eq__; without this, two equal
    # containers would still compare unequal.
    return not self == other

  def __hash__(self):
    raise TypeError('unhashable object')

  def sort(self, sort_function=None):
    """Sorts the elements in place, stably, via element swaps in C++.

    Args:
      sort_function: cmp-style function taking two element messages and
        returning a negative, zero or positive number.  When omitted, the
        built-in cmp is used (looked up only at call time).
    """
    if sort_function is None:
      sort_function = cmp
    elements = [self[i] for i in range(len(self))]

    class _OrderKey(object):
      # Adapts the three-way sort_function to the `<` comparisons that the
      # sort performs on keys.
      __slots__ = ('index',)

      def __init__(self, index):
        self.index = index

      def __lt__(self, other):
        return sort_function(elements[self.index], elements[other.index]) < 0

    # order[i] is the original index of the element that belongs at i.
    order = sorted(range(len(elements)), key=_OrderKey)

    # position[k]: where original element k currently sits.
    # occupant[p]: original index of the element currently at position p.
    position = list(range(len(elements)))
    occupant = list(range(len(elements)))
    for target, wanted in enumerate(order):
      source = position[wanted]
      if source == target:
        continue
      self._cmsg.SwapRepeatedFieldElements(
          self._cfield_descriptor, target, source)
      displaced = occupant[target]
      occupant[target], occupant[source] = wanted, displaced
      position[wanted], position[displaced] = target, source
260
261
def RepeatedCompositeProperty(cdescriptor, message_type):
  """Builds the class-level property for one repeated message field.

  Reading the attribute yields a RepeatedCompositeContainer, built on first
  use (only then is message_type._concrete_class consulted) and remembered in
  the instance's _composite_fields.  Writing to the attribute is refused with
  AttributeError.
  """

  def Getter(self):
    cached = self._composite_fields.get(cdescriptor.name)
    if cached is not None:
      return cached
    fresh = RepeatedCompositeContainer(
        self, cdescriptor, message_type._concrete_class)
    self._composite_fields[cdescriptor.name] = fresh
    return fresh

  def Setter(self, new_value):
    message_text = ('Assignment not allowed to repeated field '
                    '"%s" in protocol message object.' % cdescriptor.name)
    raise AttributeError(message_text)

  return property(
      fget=Getter,
      fset=Setter,
      doc='Magic attribute generated for "%s" proto field.' % cdescriptor.name)
279
280
class ExtensionDict(object):
  """Extension dictionary added to each protocol message.

  Keys are extension FieldDescriptor objects.  Singular scalar extensions are
  read from and written to the C++ message directly.  Repeated and
  message-typed extensions are wrapped on first access and the wrapper is
  cached in self._values.
  """

  def __init__(self, msg):
    self._message = msg
    self._cmsg = msg._cmsg
    self._values = {}

  @staticmethod
  def _CheckIsExtension(extension):
    """Raises KeyError unless extension is a descriptor.FieldDescriptor."""
    # Imported at call time rather than at module level (presumably to avoid
    # an import cycle -- confirm).
    from google.protobuf import descriptor
    if not isinstance(extension, descriptor.FieldDescriptor):
      raise KeyError('Bad extension %r.' % (extension,))

  def __setitem__(self, extension, value):
    """Sets a singular scalar extension.

    Raises:
      KeyError: extension is not a FieldDescriptor.
      TypeError: extension is not optional or is message-typed.
    """
    self._CheckIsExtension(extension)
    cdescriptor = extension._cdescriptor
    if (cdescriptor.label != _LABEL_OPTIONAL or
        cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
      raise TypeError('Extension %r is repeated and/or a composite type.' % (
          extension.full_name,))
    self._cmsg.SetScalar(cdescriptor, value)
    self._values[extension] = value

  def __getitem__(self, extension):
    """Returns a scalar extension's value, or the cached wrapper otherwise."""
    self._CheckIsExtension(extension)
    cdescriptor = extension._cdescriptor
    if (cdescriptor.label != _LABEL_REPEATED and
        cdescriptor.cpp_type != _CPPTYPE_MESSAGE):
      return self._cmsg.GetScalar(cdescriptor)

    ext = self._values.get(extension, None)
    if ext is None:
      ext = self._CreateNewHandle(extension)
      self._values[extension] = ext
    return ext

  def ClearExtension(self, extension):
    """Clears the extension in C++ and drops any cached Python value."""
    self._CheckIsExtension(extension)
    self._cmsg.ClearFieldByDescriptor(extension._cdescriptor)
    self._values.pop(extension, None)

  def HasExtension(self, extension):
    """Returns whether the C++ message has the extension set."""
    self._CheckIsExtension(extension)
    return self._cmsg.HasFieldByDescriptor(extension._cdescriptor)

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if the owning message class has no
      extension registered under that name.
    """
    return self._message._extensions_by_name.get(name, None)

  def _CreateNewHandle(self, extension):
    """Builds the Python wrapper for a repeated or message-typed extension."""
    cdescriptor = extension._cdescriptor
    if cdescriptor.label == _LABEL_REPEATED:
      if cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
        return RepeatedCompositeContainer(
            self._message, cdescriptor, extension.message_type._concrete_class)
      return RepeatedScalarContainer(self._message, cdescriptor)
    if cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
      cmessage = self._cmsg.NewSubMessage(cdescriptor)
      # Record the owning message so it stays alive as long as the wrapper:
      # it holds the underlying C++ message (same pattern as
      # RepeatedCompositeContainer.add).
      return extension.message_type._concrete_class(
          __cmessage=cmessage, __owner=self._message)
    # __getitem__ handles singular scalar extensions itself, so this point is
    # unreachable; raise explicitly so it still fails under `python -O`.
    raise AssertionError(
        'Extension %r has no handle type.' % (extension.full_name,))
360
361
def NewMessage(message_descriptor, dictionary):
  """Populates the namespace dict of a new protocol message *class*.

  Nested extensions are added first, so their name-collision assertion only
  sees what the caller already put into `dictionary`; enum values and the
  '__descriptors' / '__slots__' entries follow.
  """
  for populate in (_AddClassAttributesForNestedExtensions,
                   _AddEnumValues,
                   _AddDescriptors):
    populate(message_descriptor, dictionary)
367
368
def InitMessage(message_descriptor, cls):
  """Finishes setting up a freshly created message class.

  Gives cls an empty _extensions_by_name registry, then installs __init__
  and the field properties, the message methods, and the extension
  field-number constants, in that order.
  """
  cls._extensions_by_name = {}
  for install in (_AddInitMethod,
                  _AddMessageMethods,
                  _AddPropertiesForExtensions):
    install(message_descriptor, cls)
375
376
def _AddDescriptors(message_descriptor, dictionary):
  """Adds the '__descriptors' and '__slots__' entries to a class namespace.

  '__descriptors' maps each field name to its C++ field descriptor, looked up
  by the field's full name.  '__slots__' lists those field names followed by
  the per-instance attributes every message uses.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class namespace dict being populated.
  """
  cdescriptors = dictionary['__descriptors'] = {}
  for field in message_descriptor.fields:
    cdescriptors[field.name] = GetFieldDescriptor(field.full_name)

  instance_attrs = ['_cmsg', '_owner', '_composite_fields', 'Extensions']
  dictionary['__slots__'] = list(cdescriptors) + instance_attrs
391
392
393def _AddEnumValues(message_descriptor, dictionary):
394 """Sets class-level attributes for all enum fields defined in this message.
395
396 Args:
397 message_descriptor: Descriptor object for this message type.
398 dictionary: Class dictionary that should be populated.
399 """
400 for enum_type in message_descriptor.enum_types:
401 for enum_value in enum_type.values:
402 dictionary[enum_value.name] = enum_value.number
403
404
405def _AddClassAttributesForNestedExtensions(message_descriptor, dictionary):
406 """Adds class attributes for the nested extensions."""
407 extension_dict = message_descriptor.extensions_by_name
408 for extension_name, extension_field in extension_dict.iteritems():
409 assert extension_name not in dictionary
410 dictionary[extension_name] = extension_field
411
412
def _AddInitMethod(message_descriptor, cls):
  """Installs per-field properties, field-number constants and __init__.

  Properties are built once per class rather than per instance: their getters
  and setters receive the message instance as `self`, so one property object
  serves every instance.  That keeps message construction fast and avoids
  allocating property objects for each message.

  Args:
    message_descriptor: Descriptor whose fields get properties and
      <NAME>_FIELD_NUMBER constants.
    cls: The message class; its '__descriptors' attribute maps field names to
      C++ field descriptors.
  """
  for field in message_descriptor.fields:
    field_cdescriptor = cls.__descriptors[field.name]
    if field.label != _LABEL_REPEATED:
      if field.cpp_type == _CPPTYPE_MESSAGE:
        prop = CompositeProperty(field_cdescriptor, field.message_type)
      else:
        prop = ScalarProperty(field_cdescriptor)
    elif field.cpp_type == _CPPTYPE_MESSAGE:
      prop = RepeatedCompositeProperty(field_cdescriptor, field.message_type)
    else:
      prop = RepeatedScalarProperty(field_cdescriptor)
    setattr(cls, field.name, prop)
    # e.g. a field named foo_bar gets FOO_BAR_FIELD_NUMBER.
    setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  def Init(self, **kwargs):
    """Message constructor."""
    # Wrap a C++ message handed in by the caller, or create a new one.
    cmessage = kwargs.pop('__cmessage', None)
    if cmessage is None:
      cmessage = NewCMessage(message_descriptor.full_name)
    self._cmsg = cmessage

    # Keep a reference to the owner, as the owner keeps a reference to the
    # underlying protocol buffer message.
    owner = kwargs.pop('__owner', None)
    if owner is not None:
      self._owner = owner

    self.Extensions = ExtensionDict(self)
    self._composite_fields = {}

    # Remaining keyword arguments initialise fields by name.
    for field_name, field_value in kwargs.iteritems():
      field_cdescriptor = self.__descriptors.get(field_name, None)
      if field_cdescriptor is None:
        raise ValueError('Protocol message has no "%s" field.' % field_name)
      if field_cdescriptor.label != _LABEL_REPEATED:
        if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
          getattr(self, field_name).MergeFrom(field_value)
        else:
          setattr(self, field_name, field_value)
        continue
      if field_cdescriptor.cpp_type != _CPPTYPE_MESSAGE:
        getattr(self, field_name).extend(field_value)
        continue
      # Repeated message: one new element per item, merged from that item.
      for item in field_value:
        getattr(self, field_name).add().MergeFrom(item)

  Init.__module__ = None
  Init.__doc__ = None
  cls.__init__ = Init
474
475
def _IsMessageSetExtension(field):
  """Checks if a field is a message set extension.

  True only for an optional, message-typed extension whose containing type
  has options with message_set_wire_format set, and whose message type is
  the same as its extension scope.

  The conditions are evaluated left to right with short-circuiting, so
  GetOptions() is only called when has_options is true.  The result is the
  last operand evaluated, which is not necessarily a bool.
  """
  return (field.is_extension and
          field.containing_type.has_options and
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == _TYPE_MESSAGE and
          field.message_type == field.extension_scope and
          field.label == _LABEL_OPTIONAL)
484
485
486def _AddMessageMethods(message_descriptor, cls):
487 """Adds the methods to a protocol message class."""
488 if message_descriptor.is_extendable:
489
490 def ClearExtension(self, extension):
491 self.Extensions.ClearExtension(extension)
492
493 def HasExtension(self, extension):
494 return self.Extensions.HasExtension(extension)
495
496 def HasField(self, field_name):
497 return self._cmsg.HasField(field_name)
498
499 def ClearField(self, field_name):
500 if field_name in self._composite_fields:
501 del self._composite_fields[field_name]
502 self._cmsg.ClearField(field_name)
503
504 def Clear(self):
505 return self._cmsg.Clear()
506
507 def IsInitialized(self, errors=None):
508 if self._cmsg.IsInitialized():
509 return True
510 if errors is not None:
511 errors.extend(self.FindInitializationErrors());
512 return False
513
514 def SerializeToString(self):
515 if not self.IsInitialized():
516 raise message.EncodeError(
517 'Message is missing required fields: ' +
518 ','.join(self.FindInitializationErrors()))
519 return self._cmsg.SerializeToString()
520
521 def SerializePartialToString(self):
522 return self._cmsg.SerializePartialToString()
523
524 def ParseFromString(self, serialized):
525 self.Clear()
526 self.MergeFromString(serialized)
527
528 def MergeFromString(self, serialized):
529 byte_size = self._cmsg.MergeFromString(serialized)
530 if byte_size < 0:
531 raise message.DecodeError('Unable to merge from string.')
532 return byte_size
533
534 def MergeFrom(self, msg):
535 if not isinstance(msg, cls):
536 raise TypeError(
537 "Parameter to MergeFrom() must be instance of same class.")
538 self._cmsg.MergeFrom(msg._cmsg)
539
540 def CopyFrom(self, msg):
541 self._cmsg.CopyFrom(msg._cmsg)
542
543 def ByteSize(self):
544 return self._cmsg.ByteSize()
545
546 def SetInParent(self):
547 return self._cmsg.SetInParent()
548
549 def ListFields(self):
550 all_fields = []
551 field_list = self._cmsg.ListFields()
552 fields_by_name = cls.DESCRIPTOR.fields_by_name
553 for is_extension, field_name in field_list:
554 if is_extension:
555 extension = cls._extensions_by_name[field_name]
556 all_fields.append((extension, self.Extensions[extension]))
557 else:
558 field_descriptor = fields_by_name[field_name]
559 all_fields.append(
560 (field_descriptor, getattr(self, field_name)))
561 all_fields.sort(key=lambda item: item[0].number)
562 return all_fields
563
564 def FindInitializationErrors(self):
565 return self._cmsg.FindInitializationErrors()
566
567 def __str__(self):
568 return self._cmsg.DebugString()
569
570 def __eq__(self, other):
571 if self is other:
572 return True
573 if not isinstance(other, self.__class__):
574 return False
575 return self.ListFields() == other.ListFields()
576
577 def __ne__(self, other):
578 return not self == other
579
580 def __hash__(self):
581 raise TypeError('unhashable object')
582
583 def __unicode__(self):
584 return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
585
586 # Attach the local methods to the message class.
587 for key, value in locals().copy().iteritems():
588 if key not in ('key', 'value', '__builtins__', '__name__', '__doc__'):
589 setattr(cls, key, value)
590
591 # Static methods:
592
593 def RegisterExtension(extension_handle):
594 extension_handle.containing_type = cls.DESCRIPTOR
595 cls._extensions_by_name[extension_handle.full_name] = extension_handle
596
597 if _IsMessageSetExtension(extension_handle):
598 # MessageSet extension. Also register under type name.
599 cls._extensions_by_name[
600 extension_handle.message_type.full_name] = extension_handle
601 cls.RegisterExtension = staticmethod(RegisterExtension)
602
603 def FromString(string):
604 msg = cls()
605 msg.MergeFromString(string)
606 return msg
607 cls.FromString = staticmethod(FromString)
608
609
610
611def _AddPropertiesForExtensions(message_descriptor, cls):
612 """Adds properties for all fields in this protocol message type."""
613 extension_dict = message_descriptor.extensions_by_name
614 for extension_name, extension_field in extension_dict.iteritems():
615 constant_name = extension_name.upper() + '_FIELD_NUMBER'
616 setattr(cls, constant_name, extension_field.number)