blob: 5947f97a88d0083c4a1cbd1557a25bd98d298007 [file] [log] [blame]
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.
# http://code.google.com/p/protobuf/
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unittest for reflection.py, which also indirectly tests the output of the
pure-Python protocol compiler.
"""
__author__ = 'robinson@google.com (Will Robinson)'
import operator
import unittest
# TODO(robinson): When we split this test in two, only some of these imports
# will be necessary in each test.
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import reflection
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import wire_format
from google.protobuf.internal import test_util
from google.protobuf.internal import decoder
class RefectionTest(unittest.TestCase):
def testSimpleHasBits(self):
# Test a scalar.
proto = unittest_pb2.TestAllTypes()
self.assertTrue(not proto.HasField('optional_int32'))
self.assertEqual(0, proto.optional_int32)
# HasField() shouldn't be true if all we've done is
# read the default value.
self.assertTrue(not proto.HasField('optional_int32'))
proto.optional_int32 = 1
# Setting a value however *should* set the "has" bit.
self.assertTrue(proto.HasField('optional_int32'))
proto.ClearField('optional_int32')
# And clearing that value should unset the "has" bit.
self.assertTrue(not proto.HasField('optional_int32'))
def testHasBitsWithSinglyNestedScalar(self):
# Helper used to test foreign messages and groups.
#
# composite_field_name should be the name of a non-repeated
# composite (i.e., foreign or group) field in TestAllTypes,
# and scalar_field_name should be the name of an integer-valued
# scalar field within that composite.
#
# I never thought I'd miss C++ macros and templates so much. :(
# This helper is semantically just:
#
# assert proto.composite_field.scalar_field == 0
# assert not proto.composite_field.HasField('scalar_field')
# assert not proto.HasField('composite_field')
#
# proto.composite_field.scalar_field = 10
# old_composite_field = proto.composite_field
#
# assert proto.composite_field.scalar_field == 10
# assert proto.composite_field.HasField('scalar_field')
# assert proto.HasField('composite_field')
#
# proto.ClearField('composite_field')
#
# assert not proto.composite_field.HasField('scalar_field')
# assert not proto.HasField('composite_field')
# assert proto.composite_field.scalar_field == 0
#
# # Now ensure that ClearField('composite_field') disconnected
# # the old field object from the object tree...
# assert old_composite_field is not proto.composite_field
# old_composite_field.scalar_field = 20
# assert not proto.composite_field.HasField('scalar_field')
# assert not proto.HasField('composite_field')
def TestCompositeHasBits(composite_field_name, scalar_field_name):
proto = unittest_pb2.TestAllTypes()
# First, check that we can get the scalar value, and see that it's the
# default (0), but that proto.HasField('omposite') and
# proto.composite.HasField('scalar') will still return False.
composite_field = getattr(proto, composite_field_name)
original_scalar_value = getattr(composite_field, scalar_field_name)
self.assertEqual(0, original_scalar_value)
# Assert that the composite object does not "have" the scalar.
self.assertTrue(not composite_field.HasField(scalar_field_name))
# Assert that proto does not "have" the composite field.
self.assertTrue(not proto.HasField(composite_field_name))
# Now set the scalar within the composite field. Ensure that the setting
# is reflected, and that proto.HasField('composite') and
# proto.composite.HasField('scalar') now both return True.
new_val = 20
setattr(composite_field, scalar_field_name, new_val)
self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
# Hold on to a reference to the current composite_field object.
old_composite_field = composite_field
# Assert that the has methods now return true.
self.assertTrue(composite_field.HasField(scalar_field_name))
self.assertTrue(proto.HasField(composite_field_name))
# Now call the clear method...
proto.ClearField(composite_field_name)
# ...and ensure that the "has" bits are all back to False...
composite_field = getattr(proto, composite_field_name)
self.assertTrue(not composite_field.HasField(scalar_field_name))
self.assertTrue(not proto.HasField(composite_field_name))
# ...and ensure that the scalar field has returned to its default.
self.assertEqual(0, getattr(composite_field, scalar_field_name))
# Finally, ensure that modifications to the old composite field object
# don't have any effect on the parent.
#
# (NOTE that when we clear the composite field in the parent, we actually
# don't recursively clear down the tree. Instead, we just disconnect the
# cleared composite from the tree.)
self.assertTrue(old_composite_field is not composite_field)
setattr(old_composite_field, scalar_field_name, new_val)
self.assertTrue(not composite_field.HasField(scalar_field_name))
self.assertTrue(not proto.HasField(composite_field_name))
self.assertEqual(0, getattr(composite_field, scalar_field_name))
# Test simple, single-level nesting when we set a scalar.
TestCompositeHasBits('optionalgroup', 'a')
TestCompositeHasBits('optional_nested_message', 'bb')
TestCompositeHasBits('optional_foreign_message', 'c')
TestCompositeHasBits('optional_import_message', 'd')
def testReferencesToNestedMessage(self):
proto = unittest_pb2.TestAllTypes()
nested = proto.optional_nested_message
del proto
# A previous version had a bug where this would raise an exception when
# hitting a now-dead weak reference.
nested.bb = 23
def testDisconnectingNestedMessageBeforeSettingField(self):
proto = unittest_pb2.TestAllTypes()
nested = proto.optional_nested_message
proto.ClearField('optional_nested_message') # Should disconnect from parent
self.assertTrue(nested is not proto.optional_nested_message)
nested.bb = 23
self.assertTrue(not proto.HasField('optional_nested_message'))
self.assertEqual(0, proto.optional_nested_message.bb)
def testHasBitsWhenModifyingRepeatedFields(self):
# Test nesting when we add an element to a repeated field in a submessage.
proto = unittest_pb2.TestNestedMessageHasBits()
proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
self.assertEqual(
[5], proto.optional_nested_message.nestedmessage_repeated_int32)
self.assertTrue(proto.HasField('optional_nested_message'))
# Do the same test, but with a repeated composite field within the
# submessage.
proto.ClearField('optional_nested_message')
self.assertTrue(not proto.HasField('optional_nested_message'))
proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
self.assertTrue(proto.HasField('optional_nested_message'))
def testHasBitsForManyLevelsOfNesting(self):
# Test nesting many levels deep.
recursive_proto = unittest_pb2.TestMutualRecursionA()
self.assertTrue(not recursive_proto.HasField('bb'))
self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
self.assertTrue(not recursive_proto.HasField('bb'))
recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
self.assertTrue(recursive_proto.HasField('bb'))
self.assertTrue(recursive_proto.bb.HasField('a'))
self.assertTrue(recursive_proto.bb.a.HasField('bb'))
self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a'))
self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
def testSingularListFields(self):
proto = unittest_pb2.TestAllTypes()
proto.optional_fixed32 = 1
proto.optional_int32 = 5
proto.optional_string = 'foo'
self.assertEqual(
[ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
(proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
(proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
proto.ListFields())
def testRepeatedListFields(self):
proto = unittest_pb2.TestAllTypes()
proto.repeated_fixed32.append(1)
proto.repeated_int32.append(5)
proto.repeated_int32.append(11)
proto.repeated_string.append('foo')
proto.repeated_string.append('bar')
proto.repeated_string.append('baz')
proto.optional_int32 = 21
self.assertEqual(
[ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21),
(proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]),
(proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
(proto.DESCRIPTOR.fields_by_name['repeated_string' ],
['foo', 'bar', 'baz']) ],
proto.ListFields())
def testSingularListExtensions(self):
proto = unittest_pb2.TestAllExtensions()
proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
proto.Extensions[unittest_pb2.optional_int32_extension ] = 5
proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
self.assertEqual(
[ (unittest_pb2.optional_int32_extension , 5),
(unittest_pb2.optional_fixed32_extension, 1),
(unittest_pb2.optional_string_extension , 'foo') ],
proto.ListFields())
def testRepeatedListExtensions(self):
proto = unittest_pb2.TestAllExtensions()
proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
proto.Extensions[unittest_pb2.repeated_int32_extension ].append(5)
proto.Extensions[unittest_pb2.repeated_int32_extension ].append(11)
proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
proto.Extensions[unittest_pb2.optional_int32_extension ] = 21
self.assertEqual(
[ (unittest_pb2.optional_int32_extension , 21),
(unittest_pb2.repeated_int32_extension , [5, 11]),
(unittest_pb2.repeated_fixed32_extension, [1]),
(unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
proto.ListFields())
def testListFieldsAndExtensions(self):
proto = unittest_pb2.TestFieldOrderings()
test_util.SetAllFieldsAndExtensions(proto)
unittest_pb2.my_extension_int
self.assertEqual(
[ (proto.DESCRIPTOR.fields_by_name['my_int' ], 1),
(unittest_pb2.my_extension_int , 23),
(proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
(unittest_pb2.my_extension_string , 'bar'),
(proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
proto.ListFields())
def testDefaultValues(self):
proto = unittest_pb2.TestAllTypes()
self.assertEqual(0, proto.optional_int32)
self.assertEqual(0, proto.optional_int64)
self.assertEqual(0, proto.optional_uint32)
self.assertEqual(0, proto.optional_uint64)
self.assertEqual(0, proto.optional_sint32)
self.assertEqual(0, proto.optional_sint64)
self.assertEqual(0, proto.optional_fixed32)
self.assertEqual(0, proto.optional_fixed64)
self.assertEqual(0, proto.optional_sfixed32)
self.assertEqual(0, proto.optional_sfixed64)
self.assertEqual(0.0, proto.optional_float)
self.assertEqual(0.0, proto.optional_double)
self.assertEqual(False, proto.optional_bool)
self.assertEqual('', proto.optional_string)
self.assertEqual('', proto.optional_bytes)
self.assertEqual(41, proto.default_int32)
self.assertEqual(42, proto.default_int64)
self.assertEqual(43, proto.default_uint32)
self.assertEqual(44, proto.default_uint64)
self.assertEqual(-45, proto.default_sint32)
self.assertEqual(46, proto.default_sint64)
self.assertEqual(47, proto.default_fixed32)
self.assertEqual(48, proto.default_fixed64)
self.assertEqual(49, proto.default_sfixed32)
self.assertEqual(-50, proto.default_sfixed64)
self.assertEqual(51.5, proto.default_float)
self.assertEqual(52e3, proto.default_double)
self.assertEqual(True, proto.default_bool)
self.assertEqual('hello', proto.default_string)
self.assertEqual('world', proto.default_bytes)
self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
self.assertEqual(unittest_import_pb2.IMPORT_BAR,
proto.default_import_enum)
def testHasFieldWithUnknownFieldName(self):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
def testClearFieldWithUnknownFieldName(self):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
def testDisallowedAssignments(self):
# It's illegal to assign values directly to repeated fields
# or to nonrepeated composite fields. Ensure that this fails.
proto = unittest_pb2.TestAllTypes()
# Repeated fields.
self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
# Lists shouldn't work, either.
self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
# Composite fields.
self.assertRaises(AttributeError, setattr, proto,
'optional_nested_message', 23)
# proto.nonexistent_field = 23 should fail as well.
self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
# TODO(robinson): Add type-safety check for enums.
def testSingleScalarTypeSafety(self):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
def testSingleScalarBoundsChecking(self):
def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
pb = unittest_pb2.TestAllTypes()
setattr(pb, field_name, expected_min)
setattr(pb, field_name, expected_max)
self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)
TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1)
def testRepeatedScalarTypeSafety(self):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
self.assertRaises(TypeError, proto.repeated_string, 10)
self.assertRaises(TypeError, proto.repeated_bytes, 10)
proto.repeated_int32.append(10)
proto.repeated_int32[0] = 23
self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
def testSingleScalarGettersAndSetters(self):
proto = unittest_pb2.TestAllTypes()
self.assertEqual(0, proto.optional_int32)
proto.optional_int32 = 1
self.assertEqual(1, proto.optional_int32)
# TODO(robinson): Test all other scalar field types.
def testSingleScalarClearField(self):
proto = unittest_pb2.TestAllTypes()
# Should be allowed to clear something that's not there (a no-op).
proto.ClearField('optional_int32')
proto.optional_int32 = 1
self.assertTrue(proto.HasField('optional_int32'))
proto.ClearField('optional_int32')
self.assertEqual(0, proto.optional_int32)
self.assertTrue(not proto.HasField('optional_int32'))
# TODO(robinson): Test all other scalar field types.
def testEnums(self):
proto = unittest_pb2.TestAllTypes()
self.assertEqual(1, proto.FOO)
self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
self.assertEqual(2, proto.BAR)
self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
self.assertEqual(3, proto.BAZ)
self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
def testRepeatedScalars(self):
proto = unittest_pb2.TestAllTypes()
self.assertTrue(not proto.repeated_int32)
self.assertEqual(0, len(proto.repeated_int32))
proto.repeated_int32.append(5);
proto.repeated_int32.append(10);
self.assertTrue(proto.repeated_int32)
self.assertEqual(2, len(proto.repeated_int32))
self.assertEqual([5, 10], proto.repeated_int32)
self.assertEqual(5, proto.repeated_int32[0])
self.assertEqual(10, proto.repeated_int32[-1])
# Test out-of-bounds indices.
self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
# Test incorrect types passed to __getitem__.
self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)
# Test that we can use the field as an iterator.
result = []
for i in proto.repeated_int32:
result.append(i)
self.assertEqual([5, 10], result)
# Test clearing.
proto.ClearField('repeated_int32')
self.assertTrue(not proto.repeated_int32)
self.assertEqual(0, len(proto.repeated_int32))
def testRepeatedComposites(self):
proto = unittest_pb2.TestAllTypes()
self.assertTrue(not proto.repeated_nested_message)
self.assertEqual(0, len(proto.repeated_nested_message))
m0 = proto.repeated_nested_message.add()
m1 = proto.repeated_nested_message.add()
self.assertTrue(proto.repeated_nested_message)
self.assertEqual(2, len(proto.repeated_nested_message))
self.assertTrue(m0 is proto.repeated_nested_message[0])
self.assertTrue(m1 is proto.repeated_nested_message[1])
self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))
# Test out-of-bounds indices.
self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
1234)
self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
-1234)
# Test incorrect types passed to __getitem__.
self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
'foo')
self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
None)
# Test that we can use the field as an iterator.
result = []
for i in proto.repeated_nested_message:
result.append(i)
self.assertEqual(2, len(result))
self.assertTrue(m0 is result[0])
self.assertTrue(m1 is result[1])
# Test clearing.
proto.ClearField('repeated_nested_message')
self.assertTrue(not proto.repeated_nested_message)
self.assertEqual(0, len(proto.repeated_nested_message))
def testHandWrittenReflection(self):
# TODO(robinson): We probably need a better way to specify
# protocol types by hand. But then again, this isn't something
# we expect many people to do. Hmm.
FieldDescriptor = descriptor.FieldDescriptor
foo_field_descriptor = FieldDescriptor(
name='foo_field', full_name='MyProto.foo_field',
index=0, number=1, type=FieldDescriptor.TYPE_INT64,
cpp_type=FieldDescriptor.CPPTYPE_INT64,
label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
containing_type=None, message_type=None, enum_type=None,
is_extension=False, extension_scope=None,
options=descriptor_pb2.FieldOptions())
mydescriptor = descriptor.Descriptor(
name='MyProto', full_name='MyProto', filename='ignored',
containing_type=None, nested_types=[], enum_types=[],
fields=[foo_field_descriptor], extensions=[],
options=descriptor_pb2.MessageOptions())
class MyProtoClass(message.Message):
DESCRIPTOR = mydescriptor
__metaclass__ = reflection.GeneratedProtocolMessageType
myproto_instance = MyProtoClass()
self.assertEqual(0, myproto_instance.foo_field)
self.assertTrue(not myproto_instance.HasField('foo_field'))
myproto_instance.foo_field = 23
self.assertEqual(23, myproto_instance.foo_field)
self.assertTrue(myproto_instance.HasField('foo_field'))
def testTopLevelExtensionsForOptionalScalar(self):
extendee_proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.optional_int32_extension
self.assertTrue(not extendee_proto.HasExtension(extension))
self.assertEqual(0, extendee_proto.Extensions[extension])
# As with normal scalar fields, just doing a read doesn't actually set the
# "has" bit.
self.assertTrue(not extendee_proto.HasExtension(extension))
# Actually set the thing.
extendee_proto.Extensions[extension] = 23
self.assertEqual(23, extendee_proto.Extensions[extension])
self.assertTrue(extendee_proto.HasExtension(extension))
# Ensure that clearing works as well.
extendee_proto.ClearExtension(extension)
self.assertEqual(0, extendee_proto.Extensions[extension])
self.assertTrue(not extendee_proto.HasExtension(extension))
def testTopLevelExtensionsForRepeatedScalar(self):
extendee_proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.repeated_string_extension
self.assertEqual(0, len(extendee_proto.Extensions[extension]))
extendee_proto.Extensions[extension].append('foo')
self.assertEqual(['foo'], extendee_proto.Extensions[extension])
string_list = extendee_proto.Extensions[extension]
extendee_proto.ClearExtension(extension)
self.assertEqual(0, len(extendee_proto.Extensions[extension]))
self.assertTrue(string_list is not extendee_proto.Extensions[extension])
# Shouldn't be allowed to do Extensions[extension] = 'a'
self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
extension, 'a')
def testTopLevelExtensionsForOptionalMessage(self):
extendee_proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.optional_foreign_message_extension
self.assertTrue(not extendee_proto.HasExtension(extension))
self.assertEqual(0, extendee_proto.Extensions[extension].c)
# As with normal (non-extension) fields, merely reading from the
# thing shouldn't set the "has" bit.
self.assertTrue(not extendee_proto.HasExtension(extension))
extendee_proto.Extensions[extension].c = 23
self.assertEqual(23, extendee_proto.Extensions[extension].c)
self.assertTrue(extendee_proto.HasExtension(extension))
# Save a reference here.
foreign_message = extendee_proto.Extensions[extension]
extendee_proto.ClearExtension(extension)
self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
# Setting a field on foreign_message now shouldn't set
# any "has" bits on extendee_proto.
foreign_message.c = 42
self.assertEqual(42, foreign_message.c)
self.assertTrue(foreign_message.HasField('c'))
self.assertTrue(not extendee_proto.HasExtension(extension))
# Shouldn't be allowed to do Extensions[extension] = 'a'
self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
extension, 'a')
def testTopLevelExtensionsForRepeatedMessage(self):
extendee_proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.repeatedgroup_extension
self.assertEqual(0, len(extendee_proto.Extensions[extension]))
group = extendee_proto.Extensions[extension].add()
group.a = 23
self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
group.a = 42
self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
group_list = extendee_proto.Extensions[extension]
extendee_proto.ClearExtension(extension)
self.assertEqual(0, len(extendee_proto.Extensions[extension]))
self.assertTrue(group_list is not extendee_proto.Extensions[extension])
# Shouldn't be allowed to do Extensions[extension] = 'a'
self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
extension, 'a')
def testNestedExtensions(self):
extendee_proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.TestRequired.single
# We just test the non-repeated case.
self.assertTrue(not extendee_proto.HasExtension(extension))
required = extendee_proto.Extensions[extension]
self.assertEqual(0, required.a)
self.assertTrue(not extendee_proto.HasExtension(extension))
required.a = 23
self.assertEqual(23, extendee_proto.Extensions[extension].a)
self.assertTrue(extendee_proto.HasExtension(extension))
extendee_proto.ClearExtension(extension)
self.assertTrue(required is not extendee_proto.Extensions[extension])
self.assertTrue(not extendee_proto.HasExtension(extension))
# If message A directly contains message B, and
# a.HasField('b') is currently False, then mutating any
# extension in B should change a.HasField('b') to True
# (and so on up the object tree).
def testHasBitsForAncestorsOfExtendedMessage(self):
# Optional scalar extension.
toplevel = more_extensions_pb2.TopLevelMessage()
self.assertTrue(not toplevel.HasField('submessage'))
self.assertEqual(0, toplevel.submessage.Extensions[
more_extensions_pb2.optional_int_extension])
self.assertTrue(not toplevel.HasField('submessage'))
toplevel.submessage.Extensions[
more_extensions_pb2.optional_int_extension] = 23
self.assertEqual(23, toplevel.submessage.Extensions[
more_extensions_pb2.optional_int_extension])
self.assertTrue(toplevel.HasField('submessage'))
# Repeated scalar extension.
toplevel = more_extensions_pb2.TopLevelMessage()
self.assertTrue(not toplevel.HasField('submessage'))
self.assertEqual([], toplevel.submessage.Extensions[
more_extensions_pb2.repeated_int_extension])
self.assertTrue(not toplevel.HasField('submessage'))
toplevel.submessage.Extensions[
more_extensions_pb2.repeated_int_extension].append(23)
self.assertEqual([23], toplevel.submessage.Extensions[
more_extensions_pb2.repeated_int_extension])
self.assertTrue(toplevel.HasField('submessage'))
# Optional message extension.
toplevel = more_extensions_pb2.TopLevelMessage()
self.assertTrue(not toplevel.HasField('submessage'))
self.assertEqual(0, toplevel.submessage.Extensions[
more_extensions_pb2.optional_message_extension].foreign_message_int)
self.assertTrue(not toplevel.HasField('submessage'))
toplevel.submessage.Extensions[
more_extensions_pb2.optional_message_extension].foreign_message_int = 23
self.assertEqual(23, toplevel.submessage.Extensions[
more_extensions_pb2.optional_message_extension].foreign_message_int)
self.assertTrue(toplevel.HasField('submessage'))
# Repeated message extension.
toplevel = more_extensions_pb2.TopLevelMessage()
self.assertTrue(not toplevel.HasField('submessage'))
self.assertEqual(0, len(toplevel.submessage.Extensions[
more_extensions_pb2.repeated_message_extension]))
self.assertTrue(not toplevel.HasField('submessage'))
foreign = toplevel.submessage.Extensions[
more_extensions_pb2.repeated_message_extension].add()
self.assertTrue(foreign is toplevel.submessage.Extensions[
more_extensions_pb2.repeated_message_extension][0])
self.assertTrue(toplevel.HasField('submessage'))
def testDisconnectionAfterClearingEmptyMessage(self):
toplevel = more_extensions_pb2.TopLevelMessage()
extendee_proto = toplevel.submessage
extension = more_extensions_pb2.optional_message_extension
extension_proto = extendee_proto.Extensions[extension]
extendee_proto.ClearExtension(extension)
extension_proto.foreign_message_int = 23
self.assertTrue(not toplevel.HasField('submessage'))
self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
def testExtensionFailureModes(self):
extendee_proto = unittest_pb2.TestAllExtensions()
# Try non-extension-handle arguments to HasExtension,
# ClearExtension(), and Extensions[]...
self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
# Try something that *is* an extension handle, just not for
# this message...
unknown_handle = more_extensions_pb2.optional_int_extension
self.assertRaises(KeyError, extendee_proto.HasExtension,
unknown_handle)
self.assertRaises(KeyError, extendee_proto.ClearExtension,
unknown_handle)
self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
unknown_handle)
self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
unknown_handle, 5)
# Try call HasExtension() with a valid handle, but for a
# *repeated* field. (Just as with non-extension repeated
# fields, Has*() isn't supported for extension repeated fields).
self.assertRaises(KeyError, extendee_proto.HasExtension,
unittest_pb2.repeated_string_extension)
def testCopyFrom(self):
# TODO(robinson): Implement.
pass
def testClear(self):
proto = unittest_pb2.TestAllTypes()
test_util.SetAllFields(proto)
# Clear the message.
proto.Clear()
self.assertEquals(proto.ByteSize(), 0)
empty_proto = unittest_pb2.TestAllTypes()
self.assertEquals(proto, empty_proto)
# Test if extensions which were set are cleared.
proto = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(proto)
# Clear the message.
proto.Clear()
self.assertEquals(proto.ByteSize(), 0)
empty_proto = unittest_pb2.TestAllExtensions()
self.assertEquals(proto, empty_proto)
def testIsInitialized(self):
# Trivial cases - all optional fields and extensions.
proto = unittest_pb2.TestAllTypes()
self.assertTrue(proto.IsInitialized())
proto = unittest_pb2.TestAllExtensions()
self.assertTrue(proto.IsInitialized())
# The case of uninitialized required fields.
proto = unittest_pb2.TestRequired()
self.assertFalse(proto.IsInitialized())
proto.a = proto.b = proto.c = 2
self.assertTrue(proto.IsInitialized())
# The case of uninitialized submessage.
proto = unittest_pb2.TestRequiredForeign()
self.assertTrue(proto.IsInitialized())
proto.optional_message.a = 1
self.assertFalse(proto.IsInitialized())
proto.optional_message.b = 0
proto.optional_message.c = 0
self.assertTrue(proto.IsInitialized())
# Uninitialized repeated submessage.
message1 = proto.repeated_message.add()
self.assertFalse(proto.IsInitialized())
message1.a = message1.b = message1.c = 0
self.assertTrue(proto.IsInitialized())
# Uninitialized repeated group in an extension.
proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.TestRequired.multi
message1 = proto.Extensions[extension].add()
message2 = proto.Extensions[extension].add()
self.assertFalse(proto.IsInitialized())
message1.a = 1
message1.b = 1
message1.c = 1
self.assertFalse(proto.IsInitialized())
message2.a = 2
message2.b = 2
message2.c = 2
self.assertTrue(proto.IsInitialized())
# Uninitialized nonrepeated message in an extension.
proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.TestRequired.single
proto.Extensions[extension].a = 1
self.assertFalse(proto.IsInitialized())
proto.Extensions[extension].b = 2
proto.Extensions[extension].c = 3
self.assertTrue(proto.IsInitialized())
# Since we had so many tests for protocol buffer equality, we broke these out
# into separate TestCase classes.
class TestAllTypesEqualityTest(unittest.TestCase):
def setUp(self):
self.first_proto = unittest_pb2.TestAllTypes()
self.second_proto = unittest_pb2.TestAllTypes()
def testSelfEquality(self):
self.assertEqual(self.first_proto, self.first_proto)
def testEmptyProtosEqual(self):
self.assertEqual(self.first_proto, self.second_proto)
class FullProtosEqualityTest(unittest.TestCase):
"""Equality tests using completely-full protos as a starting point."""
def setUp(self):
self.first_proto = unittest_pb2.TestAllTypes()
self.second_proto = unittest_pb2.TestAllTypes()
test_util.SetAllFields(self.first_proto)
test_util.SetAllFields(self.second_proto)
def testAllFieldsFilledEquality(self):
self.assertEqual(self.first_proto, self.second_proto)
def testNonRepeatedScalar(self):
# Nonrepeated scalar field change should cause inequality.
self.first_proto.optional_int32 += 1
self.assertNotEqual(self.first_proto, self.second_proto)
# ...as should clearing a field.
self.first_proto.ClearField('optional_int32')
self.assertNotEqual(self.first_proto, self.second_proto)
def testNonRepeatedComposite(self):
# Change a nonrepeated composite field.
self.first_proto.optional_nested_message.bb += 1
self.assertNotEqual(self.first_proto, self.second_proto)
self.first_proto.optional_nested_message.bb -= 1
self.assertEqual(self.first_proto, self.second_proto)
# Clear a field in the nested message.
self.first_proto.optional_nested_message.ClearField('bb')
self.assertNotEqual(self.first_proto, self.second_proto)
self.first_proto.optional_nested_message.bb = (
self.second_proto.optional_nested_message.bb)
self.assertEqual(self.first_proto, self.second_proto)
# Remove the nested message entirely.
self.first_proto.ClearField('optional_nested_message')
self.assertNotEqual(self.first_proto, self.second_proto)
def testRepeatedScalar(self):
# Change a repeated scalar field.
self.first_proto.repeated_int32.append(5)
self.assertNotEqual(self.first_proto, self.second_proto)
self.first_proto.ClearField('repeated_int32')
self.assertNotEqual(self.first_proto, self.second_proto)
def testRepeatedComposite(self):
# Change value within a repeated composite field.
self.first_proto.repeated_nested_message[0].bb += 1
self.assertNotEqual(self.first_proto, self.second_proto)
self.first_proto.repeated_nested_message[0].bb -= 1
self.assertEqual(self.first_proto, self.second_proto)
# Add a value to a repeated composite field.
self.first_proto.repeated_nested_message.add()
self.assertNotEqual(self.first_proto, self.second_proto)
self.second_proto.repeated_nested_message.add()
self.assertEqual(self.first_proto, self.second_proto)
def testNonRepeatedScalarHasBits(self):
# Ensure that we test "has" bits as well as value for
# nonrepeated scalar field.
self.first_proto.ClearField('optional_int32')
self.second_proto.optional_int32 = 0
self.assertNotEqual(self.first_proto, self.second_proto)
def testNonRepeatedCompositeHasBits(self):
# Ensure that we test "has" bits as well as value for
# nonrepeated composite field.
self.first_proto.ClearField('optional_nested_message')
self.second_proto.optional_nested_message.ClearField('bb')
self.assertNotEqual(self.first_proto, self.second_proto)
# TODO(robinson): Replace next two lines with method
# to set the "has" bit without changing the value,
# if/when such a method exists.
self.first_proto.optional_nested_message.bb = 0
self.first_proto.optional_nested_message.ClearField('bb')
self.assertEqual(self.first_proto, self.second_proto)
class ExtensionEqualityTest(unittest.TestCase):
def testExtensionEquality(self):
first_proto = unittest_pb2.TestAllExtensions()
second_proto = unittest_pb2.TestAllExtensions()
self.assertEqual(first_proto, second_proto)
test_util.SetAllExtensions(first_proto)
self.assertNotEqual(first_proto, second_proto)
test_util.SetAllExtensions(second_proto)
self.assertEqual(first_proto, second_proto)
# Ensure that we check value equality.
first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
self.assertNotEqual(first_proto, second_proto)
first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
self.assertEqual(first_proto, second_proto)
# Ensure that we also look at "has" bits.
first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
self.assertNotEqual(first_proto, second_proto)
first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
self.assertEqual(first_proto, second_proto)
# Ensure that differences in cached values
# don't matter if "has" bits are both false.
first_proto = unittest_pb2.TestAllExtensions()
second_proto = unittest_pb2.TestAllExtensions()
self.assertEqual(
0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
self.assertEqual(first_proto, second_proto)
class MutualRecursionEqualityTest(unittest.TestCase):
def testEqualityWithMutualRecursion(self):
first_proto = unittest_pb2.TestMutualRecursionA()
second_proto = unittest_pb2.TestMutualRecursionA()
self.assertEqual(first_proto, second_proto)
first_proto.bb.a.bb.optional_int32 = 23
self.assertNotEqual(first_proto, second_proto)
second_proto.bb.a.bb.optional_int32 = 23
self.assertEqual(first_proto, second_proto)
class ByteSizeTest(unittest.TestCase):
def setUp(self):
self.proto = unittest_pb2.TestAllTypes()
self.extended_proto = more_extensions_pb2.ExtendedMessage()
def Size(self):
return self.proto.ByteSize()
def testEmptyMessage(self):
self.assertEqual(0, self.proto.ByteSize())
def testVarints(self):
def Test(i, expected_varint_size):
self.proto.Clear()
self.proto.optional_int64 = i
# Add one to the varint size for the tag info
# for tag 1.
self.assertEqual(expected_varint_size + 1, self.Size())
Test(0, 1)
Test(1, 1)
for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)):
Test((1 << i) - 1, num_bytes)
Test(-1, 10)
Test(-2, 10)
Test(-(1 << 63), 10)
def testStrings(self):
self.proto.optional_string = ''
# Need one byte for tag info (tag #14), and one byte for length.
self.assertEqual(2, self.Size())
self.proto.optional_string = 'abc'
# Need one byte for tag info (tag #14), and one byte for length.
self.assertEqual(2 + len(self.proto.optional_string), self.Size())
self.proto.optional_string = 'x' * 128
# Need one byte for tag info (tag #14), and TWO bytes for length.
self.assertEqual(3 + len(self.proto.optional_string), self.Size())
def testOtherNumerics(self):
self.proto.optional_fixed32 = 1234
# One byte for tag and 4 bytes for fixed32.
self.assertEqual(5, self.Size())
self.proto = unittest_pb2.TestAllTypes()
self.proto.optional_fixed64 = 1234
# One byte for tag and 8 bytes for fixed64.
self.assertEqual(9, self.Size())
self.proto = unittest_pb2.TestAllTypes()
self.proto.optional_float = 1.234
# One byte for tag and 4 bytes for float.
self.assertEqual(5, self.Size())
self.proto = unittest_pb2.TestAllTypes()
self.proto.optional_double = 1.234
# One byte for tag and 8 bytes for float.
self.assertEqual(9, self.Size())
self.proto = unittest_pb2.TestAllTypes()
self.proto.optional_sint32 = 64
# One byte for tag and 2 bytes for zig-zag-encoded 64.
self.assertEqual(3, self.Size())
self.proto = unittest_pb2.TestAllTypes()
def testComposites(self):
# 3 bytes.
self.proto.optional_nested_message.bb = (1 << 14)
# Plus one byte for bb tag.
# Plus 1 byte for optional_nested_message serialized size.
# Plus two bytes for optional_nested_message tag.
self.assertEqual(3 + 1 + 1 + 2, self.Size())
def testGroups(self):
# 4 bytes.
self.proto.optionalgroup.a = (1 << 21)
# Plus two bytes for |a| tag.
# Plus 2 * two bytes for START_GROUP and END_GROUP tags.
self.assertEqual(4 + 2 + 2*2, self.Size())
def testRepeatedScalars(self):
self.proto.repeated_int32.append(10) # 1 byte.
self.proto.repeated_int32.append(128) # 2 bytes.
# Also need 2 bytes for each entry for tag.
self.assertEqual(1 + 2 + 2*2, self.Size())
def testRepeatedComposites(self):
# Empty message. 2 bytes tag plus 1 byte length.
foreign_message_0 = self.proto.repeated_nested_message.add()
# 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
foreign_message_1 = self.proto.repeated_nested_message.add()
foreign_message_1.bb = 7
self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
def testRepeatedGroups(self):
# 2-byte START_GROUP plus 2-byte END_GROUP.
group_0 = self.proto.repeatedgroup.add()
# 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
# plus 2-byte END_GROUP.
group_1 = self.proto.repeatedgroup.add()
group_1.a = 7
self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())
def testExtensions(self):
proto = unittest_pb2.TestAllExtensions()
self.assertEqual(0, proto.ByteSize())
extension = unittest_pb2.optional_int32_extension # Field #1, 1 byte.
proto.Extensions[extension] = 23
# 1 byte for tag, 1 byte for value.
self.assertEqual(2, proto.ByteSize())
def testCacheInvalidationForNonrepeatedScalar(self):
# Test non-extension.
self.proto.optional_int32 = 1
self.assertEqual(2, self.proto.ByteSize())
self.proto.optional_int32 = 128
self.assertEqual(3, self.proto.ByteSize())
self.proto.ClearField('optional_int32')
self.assertEqual(0, self.proto.ByteSize())
# Test within extension.
extension = more_extensions_pb2.optional_int_extension
self.extended_proto.Extensions[extension] = 1
self.assertEqual(2, self.extended_proto.ByteSize())
self.extended_proto.Extensions[extension] = 128
self.assertEqual(3, self.extended_proto.ByteSize())
self.extended_proto.ClearExtension(extension)
self.assertEqual(0, self.extended_proto.ByteSize())
def testCacheInvalidationForRepeatedScalar(self):
# Test non-extension.
self.proto.repeated_int32.append(1)
self.assertEqual(3, self.proto.ByteSize())
self.proto.repeated_int32.append(1)
self.assertEqual(6, self.proto.ByteSize())
self.proto.repeated_int32[1] = 128
self.assertEqual(7, self.proto.ByteSize())
self.proto.ClearField('repeated_int32')
self.assertEqual(0, self.proto.ByteSize())
# Test within extension.
extension = more_extensions_pb2.repeated_int_extension
repeated = self.extended_proto.Extensions[extension]
repeated.append(1)
self.assertEqual(2, self.extended_proto.ByteSize())
repeated.append(1)
self.assertEqual(4, self.extended_proto.ByteSize())
repeated[1] = 128
self.assertEqual(5, self.extended_proto.ByteSize())
self.extended_proto.ClearExtension(extension)
self.assertEqual(0, self.extended_proto.ByteSize())
def testCacheInvalidationForNonrepeatedMessage(self):
# Test non-extension.
self.proto.optional_foreign_message.c = 1
self.assertEqual(5, self.proto.ByteSize())
self.proto.optional_foreign_message.c = 128
self.assertEqual(6, self.proto.ByteSize())
self.proto.optional_foreign_message.ClearField('c')
self.assertEqual(3, self.proto.ByteSize())
self.proto.ClearField('optional_foreign_message')
self.assertEqual(0, self.proto.ByteSize())
child = self.proto.optional_foreign_message
self.proto.ClearField('optional_foreign_message')
child.c = 128
self.assertEqual(0, self.proto.ByteSize())
# Test within extension.
extension = more_extensions_pb2.optional_message_extension
child = self.extended_proto.Extensions[extension]
self.assertEqual(0, self.extended_proto.ByteSize())
child.foreign_message_int = 1
self.assertEqual(4, self.extended_proto.ByteSize())
child.foreign_message_int = 128
self.assertEqual(5, self.extended_proto.ByteSize())
self.extended_proto.ClearExtension(extension)
self.assertEqual(0, self.extended_proto.ByteSize())
def testCacheInvalidationForRepeatedMessage(self):
# Test non-extension.
child0 = self.proto.repeated_foreign_message.add()
self.assertEqual(3, self.proto.ByteSize())
self.proto.repeated_foreign_message.add()
self.assertEqual(6, self.proto.ByteSize())
child0.c = 1
self.assertEqual(8, self.proto.ByteSize())
self.proto.ClearField('repeated_foreign_message')
self.assertEqual(0, self.proto.ByteSize())
# Test within extension.
extension = more_extensions_pb2.repeated_message_extension
child_list = self.extended_proto.Extensions[extension]
child0 = child_list.add()
self.assertEqual(2, self.extended_proto.ByteSize())
child_list.add()
self.assertEqual(4, self.extended_proto.ByteSize())
child0.foreign_message_int = 1
self.assertEqual(6, self.extended_proto.ByteSize())
child0.ClearField('foreign_message_int')
self.assertEqual(4, self.extended_proto.ByteSize())
self.extended_proto.ClearExtension(extension)
self.assertEqual(0, self.extended_proto.ByteSize())
# TODO(robinson): We need cross-language serialization consistency tests.
# Issues to be sure to cover include:
# * Handling of unrecognized tags ("uninterpreted_bytes").
# * Handling of MessageSets.
# * Consistent ordering of tags in the wire format,
# including ordering between extensions and non-extension
# fields.
# * Consistent serialization of negative numbers, especially
# negative int32s.
# * Handling of empty submessages (with and without "has"
# bits set).
class SerializationTest(unittest.TestCase):
def testSerializeEmtpyMessage(self):
first_proto = unittest_pb2.TestAllTypes()
second_proto = unittest_pb2.TestAllTypes()
serialized = first_proto.SerializeToString()
self.assertEqual(first_proto.ByteSize(), len(serialized))
second_proto.MergeFromString(serialized)
self.assertEqual(first_proto, second_proto)
def testSerializeAllFields(self):
first_proto = unittest_pb2.TestAllTypes()
second_proto = unittest_pb2.TestAllTypes()
test_util.SetAllFields(first_proto)
serialized = first_proto.SerializeToString()
self.assertEqual(first_proto.ByteSize(), len(serialized))
second_proto.MergeFromString(serialized)
self.assertEqual(first_proto, second_proto)
def testSerializeAllExtensions(self):
first_proto = unittest_pb2.TestAllExtensions()
second_proto = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(first_proto)
serialized = first_proto.SerializeToString()
second_proto.MergeFromString(serialized)
self.assertEqual(first_proto, second_proto)
def testCanonicalSerializationOrder(self):
proto = more_messages_pb2.OutOfOrderFields()
# These are also their tag numbers. Even though we're setting these in
# reverse-tag order AND they're listed in reverse tag-order in the .proto
# file, they should nonetheless be serialized in tag order.
proto.optional_sint32 = 5
proto.Extensions[more_messages_pb2.optional_uint64] = 4
proto.optional_uint32 = 3
proto.Extensions[more_messages_pb2.optional_int64] = 2
proto.optional_int32 = 1
serialized = proto.SerializeToString()
self.assertEqual(proto.ByteSize(), len(serialized))
d = decoder.Decoder(serialized)
ReadTag = d.ReadFieldNumberAndWireType
self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
self.assertEqual(1, d.ReadInt32())
self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
self.assertEqual(2, d.ReadInt64())
self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
self.assertEqual(3, d.ReadUInt32())
self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
self.assertEqual(4, d.ReadUInt64())
self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
self.assertEqual(5, d.ReadSInt32())
def testCanonicalSerializationOrderSameAsCpp(self):
# Copy of the same test we use for C++.
proto = unittest_pb2.TestFieldOrderings()
test_util.SetAllFieldsAndExtensions(proto)
serialized = proto.SerializeToString()
test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)
def testMergeFromStringWhenFieldsAlreadySet(self):
first_proto = unittest_pb2.TestAllTypes()
first_proto.repeated_string.append('foobar')
first_proto.optional_int32 = 23
first_proto.optional_nested_message.bb = 42
serialized = first_proto.SerializeToString()
second_proto = unittest_pb2.TestAllTypes()
second_proto.repeated_string.append('baz')
second_proto.optional_int32 = 100
second_proto.optional_nested_message.bb = 999
second_proto.MergeFromString(serialized)
# Ensure that we append to repeated fields.
self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
# Ensure that we overwrite nonrepeatd scalars.
self.assertEqual(23, second_proto.optional_int32)
# Ensure that we recursively call MergeFromString() on
# submessages.
self.assertEqual(42, second_proto.optional_nested_message.bb)
def testMessageSetWireFormat(self):
proto = unittest_mset_pb2.TestMessageSet()
extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
extension_message2 = unittest_mset_pb2.TestMessageSetExtension2
extension1 = extension_message1.message_set_extension
extension2 = extension_message2.message_set_extension
proto.Extensions[extension1].i = 123
proto.Extensions[extension2].str = 'foo'
# Serialize using the MessageSet wire format (this is specified in the
# .proto file).
serialized = proto.SerializeToString()
raw = unittest_mset_pb2.RawMessageSet()
self.assertEqual(False,
raw.DESCRIPTOR.GetOptions().message_set_wire_format)
raw.MergeFromString(serialized)
self.assertEqual(2, len(raw.item))
message1 = unittest_mset_pb2.TestMessageSetExtension1()
message1.MergeFromString(raw.item[0].message)
self.assertEqual(123, message1.i)
message2 = unittest_mset_pb2.TestMessageSetExtension2()
message2.MergeFromString(raw.item[1].message)
self.assertEqual('foo', message2.str)
# Deserialize using the MessageSet wire format.
proto2 = unittest_mset_pb2.TestMessageSet()
proto2.MergeFromString(serialized)
self.assertEqual(123, proto2.Extensions[extension1].i)
self.assertEqual('foo', proto2.Extensions[extension2].str)
# Check byte size.
self.assertEqual(proto2.ByteSize(), len(serialized))
self.assertEqual(proto.ByteSize(), len(serialized))
def testMessageSetWireFormatUnknownExtension(self):
# Create a message using the message set wire format with an unknown
# message.
raw = unittest_mset_pb2.RawMessageSet()
# Add an item.
item = raw.item.add()
item.type_id = 1545008
extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
message1 = unittest_mset_pb2.TestMessageSetExtension1()
message1.i = 12345
item.message = message1.SerializeToString()
# Add a second, unknown extension.
item = raw.item.add()
item.type_id = 1545009
extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
message1 = unittest_mset_pb2.TestMessageSetExtension1()
message1.i = 12346
item.message = message1.SerializeToString()
# Add another unknown extension.
item = raw.item.add()
item.type_id = 1545010
message1 = unittest_mset_pb2.TestMessageSetExtension2()
message1.str = 'foo'
item.message = message1.SerializeToString()
serialized = raw.SerializeToString()
# Parse message using the message set wire format.
proto = unittest_mset_pb2.TestMessageSet()
proto.MergeFromString(serialized)
# Check that the message parsed well.
extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
extension1 = extension_message1.message_set_extension
self.assertEquals(12345, proto.Extensions[extension1].i)
def testUnknownFields(self):
proto = unittest_pb2.TestAllTypes()
test_util.SetAllFields(proto)
serialized = proto.SerializeToString()
# The empty message should be parsable with all of the fields
# unknown.
proto2 = unittest_pb2.TestEmptyMessage()
# Parsing this message should succeed.
proto2.MergeFromString(serialized)
class OptionsTest(unittest.TestCase):
def testMessageOptions(self):
proto = unittest_mset_pb2.TestMessageSet()
self.assertEqual(True,
proto.DESCRIPTOR.GetOptions().message_set_wire_format)
proto = unittest_pb2.TestAllTypes()
self.assertEqual(False,
proto.DESCRIPTOR.GetOptions().message_set_wire_format)
class UtilityTest(unittest.TestCase):
def testImergeSorted(self):
ImergeSorted = reflection._ImergeSorted
# Various types of emptiness.
self.assertEqual([], list(ImergeSorted()))
self.assertEqual([], list(ImergeSorted([])))
self.assertEqual([], list(ImergeSorted([], [])))
# One nonempty list.
self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3])))
self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3], [])))
self.assertEqual([1, 2, 3], list(ImergeSorted([], [1, 2, 3])))
# Merging some nonempty lists together.
self.assertEqual([1, 2, 3], list(ImergeSorted([1, 3], [2])))
self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2])))
self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2], [])))
# Elements repeated across component iterators.
self.assertEqual([1, 2, 2, 3, 3],
list(ImergeSorted([1, 2], [3], [2, 3])))
# Elements repeated within an iterator.
self.assertEqual([1, 2, 2, 3, 3],
list(ImergeSorted([1, 2, 2], [3], [3])))
if __name__ == '__main__':
unittest.main()