Feng Xiao | e841bac | 2015-12-11 17:09:20 -0800 | [diff] [blame] | 1 | #! /usr/bin/env python |
| 2 | # |
| 3 | # Protocol Buffers - Google's data interchange format |
| 4 | # Copyright 2008 Google Inc. All rights reserved. |
| 5 | # https://developers.google.com/protocol-buffers/ |
| 6 | # |
| 7 | # Redistribution and use in source and binary forms, with or without |
| 8 | # modification, are permitted provided that the following conditions are |
| 9 | # met: |
| 10 | # |
| 11 | # * Redistributions of source code must retain the above copyright |
| 12 | # notice, this list of conditions and the following disclaimer. |
| 13 | # * Redistributions in binary form must reproduce the above |
| 14 | # copyright notice, this list of conditions and the following disclaimer |
| 15 | # in the documentation and/or other materials provided with the |
| 16 | # distribution. |
| 17 | # * Neither the name of Google Inc. nor the names of its |
| 18 | # contributors may be used to endorse or promote products derived from |
| 19 | # this software without specific prior written permission. |
| 20 | # |
| 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
| 22 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
| 23 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
| 24 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
| 25 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
| 26 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
| 27 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
| 28 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
| 29 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| 30 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 31 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 32 | |
| 33 | """Test for google.protobuf.internal.well_known_types.""" |
| 34 | |
| 35 | __author__ = 'jieluo@google.com (Jie Luo)' |
| 36 | |
| 37 | from datetime import datetime |
| 38 | |
Feng Xiao | a0b8fd5 | 2015-12-21 13:43:13 -0800 | [diff] [blame] | 39 | try: |
| 40 | import unittest2 as unittest |
| 41 | except ImportError: |
| 42 | import unittest |
| 43 | |
CH Albach | 5477f8c | 2016-01-29 18:10:50 -0800 | [diff] [blame] | 44 | from google.protobuf import any_pb2 |
Feng Xiao | e841bac | 2015-12-11 17:09:20 -0800 | [diff] [blame] | 45 | from google.protobuf import duration_pb2 |
| 46 | from google.protobuf import field_mask_pb2 |
CH Albach | 5477f8c | 2016-01-29 18:10:50 -0800 | [diff] [blame] | 47 | from google.protobuf import struct_pb2 |
Feng Xiao | e841bac | 2015-12-11 17:09:20 -0800 | [diff] [blame] | 48 | from google.protobuf import timestamp_pb2 |
Feng Xiao | e841bac | 2015-12-11 17:09:20 -0800 | [diff] [blame] | 49 | from google.protobuf import unittest_pb2 |
CH Albach | 5477f8c | 2016-01-29 18:10:50 -0800 | [diff] [blame] | 50 | from google.protobuf.internal import any_test_pb2 |
Feng Xiao | e841bac | 2015-12-11 17:09:20 -0800 | [diff] [blame] | 51 | from google.protobuf.internal import test_util |
| 52 | from google.protobuf.internal import well_known_types |
| 53 | from google.protobuf import descriptor |
CH Albach | 5477f8c | 2016-01-29 18:10:50 -0800 | [diff] [blame] | 54 | from google.protobuf import text_format |
Feng Xiao | e841bac | 2015-12-11 17:09:20 -0800 | [diff] [blame] | 55 | |
| 56 | |
class TimeUtilTestBase(unittest.TestCase):
  """Shared helpers for JSON round-trip checks on well-known time types."""

  def CheckTimestampConversion(self, message, text):
    """Asserts `message` serializes to `text` and that `text` parses back equal."""
    self.assertEqual(text, message.ToJsonString())
    reparsed = timestamp_pb2.Timestamp()
    reparsed.FromJsonString(text)
    self.assertEqual(message, reparsed)

  def CheckDurationConversion(self, message, text):
    """Asserts `message` serializes to `text` and that `text` parses back equal."""
    self.assertEqual(text, message.ToJsonString())
    reparsed = duration_pb2.Duration()
    reparsed.FromJsonString(text)
    self.assertEqual(message, reparsed)
| 70 | |
| 71 | |
class TimeUtilTest(TimeUtilTestBase):
  """Tests JSON, integer, and datetime conversions for Timestamp/Duration."""

  def testTimestampSerializeAndParse(self):
    """Round-trips Timestamp values through their RFC 3339 JSON form."""
    message = timestamp_pb2.Timestamp()
    # Generated output should contain 3, 6, or 9 fractional digits.
    message.seconds = 0
    message.nanos = 0
    self.CheckTimestampConversion(message, '1970-01-01T00:00:00Z')
    message.nanos = 10000000
    self.CheckTimestampConversion(message, '1970-01-01T00:00:00.010Z')
    message.nanos = 10000
    self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000010Z')
    message.nanos = 10
    self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000000010Z')
    # Test min timestamps.
    message.seconds = -62135596800
    message.nanos = 0
    self.CheckTimestampConversion(message, '0001-01-01T00:00:00Z')
    # Test max timestamps.
    message.seconds = 253402300799
    message.nanos = 999999999
    self.CheckTimestampConversion(message, '9999-12-31T23:59:59.999999999Z')
    # Test negative timestamps.
    message.seconds = -1
    self.CheckTimestampConversion(message, '1969-12-31T23:59:59.999999999Z')

    # Parsing accepts any number of fractional digits as long as they fit
    # into nano precision.
    message.FromJsonString('1970-01-01T00:00:00.1Z')
    self.assertEqual(0, message.seconds)
    self.assertEqual(100000000, message.nanos)
    # Parsing accepts offsets.
    message.FromJsonString('1970-01-01T00:00:00-08:00')
    self.assertEqual(8 * 3600, message.seconds)
    self.assertEqual(0, message.nanos)

  def testDurationSerializeAndParse(self):
    """Round-trips Duration values through their "<seconds>s" JSON form."""
    message = duration_pb2.Duration()
    # Generated output should contain 3, 6, or 9 fractional digits.
    message.seconds = 0
    message.nanos = 0
    self.CheckDurationConversion(message, '0s')
    message.nanos = 10000000
    self.CheckDurationConversion(message, '0.010s')
    message.nanos = 10000
    self.CheckDurationConversion(message, '0.000010s')
    message.nanos = 10
    self.CheckDurationConversion(message, '0.000000010s')

    # Test min and max
    message.seconds = 315576000000
    message.nanos = 999999999
    self.CheckDurationConversion(message, '315576000000.999999999s')
    message.seconds = -315576000000
    message.nanos = -999999999
    self.CheckDurationConversion(message, '-315576000000.999999999s')

    # Parsing accepts any number of fractional digits as long as they fit
    # into nano precision.
    message.FromJsonString('0.1s')
    self.assertEqual(100000000, message.nanos)
    message.FromJsonString('0.0000001s')
    self.assertEqual(100, message.nanos)

  def testTimestampIntegerConversion(self):
    """Converts Timestamp to/from nanos, micros, millis, and seconds."""
    message = timestamp_pb2.Timestamp()
    message.FromNanoseconds(1)
    self.assertEqual('1970-01-01T00:00:00.000000001Z',
                     message.ToJsonString())
    self.assertEqual(1, message.ToNanoseconds())

    message.FromNanoseconds(-1)
    self.assertEqual('1969-12-31T23:59:59.999999999Z',
                     message.ToJsonString())
    self.assertEqual(-1, message.ToNanoseconds())

    message.FromMicroseconds(1)
    self.assertEqual('1970-01-01T00:00:00.000001Z',
                     message.ToJsonString())
    self.assertEqual(1, message.ToMicroseconds())

    message.FromMicroseconds(-1)
    self.assertEqual('1969-12-31T23:59:59.999999Z',
                     message.ToJsonString())
    self.assertEqual(-1, message.ToMicroseconds())

    message.FromMilliseconds(1)
    self.assertEqual('1970-01-01T00:00:00.001Z',
                     message.ToJsonString())
    self.assertEqual(1, message.ToMilliseconds())

    message.FromMilliseconds(-1)
    self.assertEqual('1969-12-31T23:59:59.999Z',
                     message.ToJsonString())
    self.assertEqual(-1, message.ToMilliseconds())

    message.FromSeconds(1)
    self.assertEqual('1970-01-01T00:00:01Z',
                     message.ToJsonString())
    self.assertEqual(1, message.ToSeconds())

    message.FromSeconds(-1)
    self.assertEqual('1969-12-31T23:59:59Z',
                     message.ToJsonString())
    self.assertEqual(-1, message.ToSeconds())

    message.FromNanoseconds(1999)
    self.assertEqual(1, message.ToMicroseconds())
    # For negative values, Timestamp will be rounded down.
    # For example, "1969-12-31T23:59:59.5Z" (i.e., -0.5s) rounded to seconds
    # will be "1969-12-31T23:59:59Z" (i.e., -1s) rather than
    # "1970-01-01T00:00:00Z" (i.e., 0s).
    message.FromNanoseconds(-1999)
    self.assertEqual(-2, message.ToMicroseconds())

  def testDurationIntegerConversion(self):
    """Converts Duration to/from nanos, micros, millis, and seconds."""
    message = duration_pb2.Duration()
    message.FromNanoseconds(1)
    self.assertEqual('0.000000001s',
                     message.ToJsonString())
    self.assertEqual(1, message.ToNanoseconds())

    message.FromNanoseconds(-1)
    self.assertEqual('-0.000000001s',
                     message.ToJsonString())
    self.assertEqual(-1, message.ToNanoseconds())

    message.FromMicroseconds(1)
    self.assertEqual('0.000001s',
                     message.ToJsonString())
    self.assertEqual(1, message.ToMicroseconds())

    message.FromMicroseconds(-1)
    self.assertEqual('-0.000001s',
                     message.ToJsonString())
    self.assertEqual(-1, message.ToMicroseconds())

    message.FromMilliseconds(1)
    self.assertEqual('0.001s',
                     message.ToJsonString())
    self.assertEqual(1, message.ToMilliseconds())

    message.FromMilliseconds(-1)
    self.assertEqual('-0.001s',
                     message.ToJsonString())
    self.assertEqual(-1, message.ToMilliseconds())

    message.FromSeconds(1)
    self.assertEqual('1s', message.ToJsonString())
    self.assertEqual(1, message.ToSeconds())

    message.FromSeconds(-1)
    self.assertEqual('-1s',
                     message.ToJsonString())
    self.assertEqual(-1, message.ToSeconds())

    # Test truncation behavior.
    message.FromNanoseconds(1999)
    self.assertEqual(1, message.ToMicroseconds())

    # For negative values, Duration will be rounded towards 0.
    message.FromNanoseconds(-1999)
    self.assertEqual(-1, message.ToMicroseconds())

  def testDatetimeConversion(self):
    """Converts Timestamp to/from naive datetime.datetime objects."""
    message = timestamp_pb2.Timestamp()
    dt = datetime(1970, 1, 1)
    message.FromDatetime(dt)
    self.assertEqual(dt, message.ToDatetime())

    message.FromMilliseconds(1999)
    self.assertEqual(datetime(1970, 1, 1, 0, 0, 1, 999000),
                     message.ToDatetime())

  def testTimedeltaConversion(self):
    """Converts Duration to/from datetime.timedelta objects."""
    message = duration_pb2.Duration()
    message.FromNanoseconds(1999999999)
    td = message.ToTimedelta()
    self.assertEqual(1, td.seconds)
    self.assertEqual(999999, td.microseconds)

    message.FromNanoseconds(-1999999999)
    td = message.ToTimedelta()
    self.assertEqual(-1, td.days)
    self.assertEqual(86398, td.seconds)
    self.assertEqual(1, td.microseconds)

    message.FromMicroseconds(-1)
    td = message.ToTimedelta()
    self.assertEqual(-1, td.days)
    self.assertEqual(86399, td.seconds)
    self.assertEqual(999999, td.microseconds)
    converted_message = duration_pb2.Duration()
    converted_message.FromTimedelta(td)
    self.assertEqual(message, converted_message)

  def testInvalidTimestamp(self):
    """Rejects out-of-range and malformed timestamp strings/values."""
    message = timestamp_pb2.Timestamp()
    self.assertRaisesRegexp(
        ValueError,
        'time data \'10000-01-01T00:00:00\' does not match'
        ' format \'%Y-%m-%dT%H:%M:%S\'',
        message.FromJsonString, '10000-01-01T00:00:00.00Z')
    self.assertRaisesRegexp(
        well_known_types.ParseError,
        'nanos 0123456789012 more than 9 fractional digits.',
        message.FromJsonString,
        '1970-01-01T00:00:00.0123456789012Z')
    self.assertRaisesRegexp(
        well_known_types.ParseError,
        (r'Invalid timezone offset value: \+08.'),
        message.FromJsonString,
        '1972-01-01T01:00:00.01+08',)
    self.assertRaisesRegexp(
        ValueError,
        'year is out of range',
        message.FromJsonString,
        '0000-01-01T00:00:00Z')
    message.seconds = 253402300800
    self.assertRaisesRegexp(
        OverflowError,
        'date value out of range',
        message.ToJsonString)

  def testInvalidDuration(self):
    """Rejects malformed duration strings."""
    message = duration_pb2.Duration()
    self.assertRaisesRegexp(
        well_known_types.ParseError,
        'Duration must end with letter "s": 1.',
        message.FromJsonString, '1')
    self.assertRaisesRegexp(
        well_known_types.ParseError,
        'Couldn\'t parse duration: 1...2s.',
        message.FromJsonString, '1...2s')
| 306 | |
| 307 | |
class FieldMaskTest(unittest.TestCase):
  """Tests the FieldMask well-known type helper methods."""

  def testStringFormat(self):
    """FieldMask serializes to and parses from a comma-separated path list."""
    mask = field_mask_pb2.FieldMask()
    self.assertEqual('', mask.ToJsonString())
    mask.paths.append('foo')
    self.assertEqual('foo', mask.ToJsonString())
    mask.paths.append('bar')
    self.assertEqual('foo,bar', mask.ToJsonString())

    # Parsing replaces the current paths entirely.
    mask.FromJsonString('')
    self.assertEqual('', mask.ToJsonString())
    mask.FromJsonString('foo')
    self.assertEqual(['foo'], mask.paths)
    mask.FromJsonString('foo,bar')
    self.assertEqual(['foo', 'bar'], mask.paths)

  def testDescriptorToFieldMask(self):
    """A mask built from a descriptor covers all fields and validates paths."""
    mask = field_mask_pb2.FieldMask()
    msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    mask.AllFieldsFromDescriptor(msg_descriptor)
    # TestAllTypes currently declares 75 fields; update if the proto changes.
    self.assertEqual(75, len(mask.paths))
    self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
    for field in msg_descriptor.fields:
      self.assertTrue(field.name in mask.paths)
    # Sub-paths into singular message fields are valid...
    mask.paths.append('optional_nested_message.bb')
    self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
    # ...but sub-paths into repeated message fields are not.
    mask.paths.append('repeated_nested_message.bb')
    self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))

  def testCanonicalFrom(self):
    """Canonical form sorts paths, dedups, and drops covered sub-paths."""
    mask = field_mask_pb2.FieldMask()
    out_mask = field_mask_pb2.FieldMask()
    # Paths will be sorted.
    mask.FromJsonString('baz.quz,bar,foo')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('bar,baz.quz,foo', out_mask.ToJsonString())
    # Duplicated paths will be removed.
    mask.FromJsonString('foo,bar,foo')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('bar,foo', out_mask.ToJsonString())
    # Sub-paths of other paths will be removed.
    mask.FromJsonString('foo.b1,bar.b1,foo.b2,bar')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString())

    # Test more deeply nested cases.
    mask.FromJsonString(
        'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('foo.bar.baz1,foo.bar.baz2',
                     out_mask.ToJsonString())
    mask.FromJsonString(
        'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('foo.bar.baz1,foo.bar.baz2',
                     out_mask.ToJsonString())
    mask.FromJsonString(
        'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('foo.bar', out_mask.ToJsonString())
    mask.FromJsonString(
        'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('foo', out_mask.ToJsonString())

  def testUnion(self):
    """Union merges two masks, deduping and collapsing covered paths."""
    mask1 = field_mask_pb2.FieldMask()
    mask2 = field_mask_pb2.FieldMask()
    out_mask = field_mask_pb2.FieldMask()
    mask1.FromJsonString('foo,baz')
    mask2.FromJsonString('bar,quz')
    out_mask.Union(mask1, mask2)
    self.assertEqual('bar,baz,foo,quz', out_mask.ToJsonString())
    # Overlap with duplicated paths.
    mask1.FromJsonString('foo,baz.bb')
    mask2.FromJsonString('baz.bb,quz')
    out_mask.Union(mask1, mask2)
    self.assertEqual('baz.bb,foo,quz', out_mask.ToJsonString())
    # Overlap with paths covering some other paths.
    mask1.FromJsonString('foo.bar.baz,quz')
    mask2.FromJsonString('foo.bar,bar')
    out_mask.Union(mask1, mask2)
    self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString())

  def testIntersect(self):
    """Intersect keeps only paths covered by both masks (deepest wins)."""
    mask1 = field_mask_pb2.FieldMask()
    mask2 = field_mask_pb2.FieldMask()
    out_mask = field_mask_pb2.FieldMask()
    # Test cases without overlapping.
    mask1.FromJsonString('foo,baz')
    mask2.FromJsonString('bar,quz')
    out_mask.Intersect(mask1, mask2)
    self.assertEqual('', out_mask.ToJsonString())
    # Overlap with duplicated paths.
    mask1.FromJsonString('foo,baz.bb')
    mask2.FromJsonString('baz.bb,quz')
    out_mask.Intersect(mask1, mask2)
    self.assertEqual('baz.bb', out_mask.ToJsonString())
    # Overlap with paths covering some other paths: intersection keeps the
    # more specific path regardless of argument order.
    mask1.FromJsonString('foo.bar.baz,quz')
    mask2.FromJsonString('foo.bar,bar')
    out_mask.Intersect(mask1, mask2)
    self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
    mask1.FromJsonString('foo.bar,bar')
    mask2.FromJsonString('foo.bar.baz,quz')
    out_mask.Intersect(mask1, mask2)
    self.assertEqual('foo.bar.baz', out_mask.ToJsonString())

  def testMergeMessage(self):
    """MergeMessage copies only masked fields, honoring merge options.

    Note: the sequence of assertions below is order-dependent; each step
    builds on the message state left by the previous one.
    """
    # Test merge one field.
    src = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(src)
    for field in src.DESCRIPTOR.fields:
      if field.containing_oneof:
        continue
      field_name = field.name
      dst = unittest_pb2.TestAllTypes()
      # Only set one path to mask.
      mask = field_mask_pb2.FieldMask()
      mask.paths.append(field_name)
      mask.MergeMessage(src, dst)
      # The expected result message.
      msg = unittest_pb2.TestAllTypes()
      if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
        repeated_src = getattr(src, field_name)
        repeated_msg = getattr(msg, field_name)
        if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
          for item in repeated_src:
            repeated_msg.add().CopyFrom(item)
        else:
          repeated_msg.extend(repeated_src)
      elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
        getattr(msg, field_name).CopyFrom(getattr(src, field_name))
      else:
        setattr(msg, field_name, getattr(src, field_name))
      # Only field specified in mask is merged.
      self.assertEqual(msg, dst)

    # Test merge nested fields.
    nested_src = unittest_pb2.NestedTestAllTypes()
    nested_dst = unittest_pb2.NestedTestAllTypes()
    nested_src.child.payload.optional_int32 = 1234
    nested_src.child.child.payload.optional_int32 = 5678
    mask = field_mask_pb2.FieldMask()
    mask.FromJsonString('child.payload')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
    self.assertEqual(0, nested_dst.child.child.payload.optional_int32)

    mask.FromJsonString('child.child.payload')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
    self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)

    nested_dst.Clear()
    mask.FromJsonString('child.child.payload')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(0, nested_dst.child.payload.optional_int32)
    self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)

    nested_dst.Clear()
    mask.FromJsonString('child')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
    self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)

    # Test MergeOptions.
    nested_dst.Clear()
    nested_dst.child.payload.optional_int64 = 4321
    # Message fields will be merged by default.
    mask.FromJsonString('child.payload')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
    self.assertEqual(4321, nested_dst.child.payload.optional_int64)
    # Change the behavior to replace message fields.
    mask.FromJsonString('child.payload')
    mask.MergeMessage(nested_src, nested_dst, True, False)
    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
    self.assertEqual(0, nested_dst.child.payload.optional_int64)

    # By default, fields missing in source are not cleared in destination.
    nested_dst.payload.optional_int32 = 1234
    self.assertTrue(nested_dst.HasField('payload'))
    mask.FromJsonString('payload')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertTrue(nested_dst.HasField('payload'))
    # But they are cleared when replacing message fields.
    nested_dst.Clear()
    nested_dst.payload.optional_int32 = 1234
    mask.FromJsonString('payload')
    mask.MergeMessage(nested_src, nested_dst, True, False)
    self.assertFalse(nested_dst.HasField('payload'))

    nested_src.payload.repeated_int32.append(1234)
    nested_dst.payload.repeated_int32.append(5678)
    # Repeated fields will be appended by default.
    mask.FromJsonString('payload.repeated_int32')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(2, len(nested_dst.payload.repeated_int32))
    self.assertEqual(5678, nested_dst.payload.repeated_int32[0])
    self.assertEqual(1234, nested_dst.payload.repeated_int32[1])
    # Change the behavior to replace repeated fields.
    mask.FromJsonString('payload.repeated_int32')
    mask.MergeMessage(nested_src, nested_dst, False, True)
    self.assertEqual(1, len(nested_dst.payload.repeated_int32))
    self.assertEqual(1234, nested_dst.payload.repeated_int32[0])
| 515 | |
CH Albach | 5477f8c | 2016-01-29 18:10:50 -0800 | [diff] [blame] | 516 | |
class StructTest(unittest.TestCase):
  """Tests dict/list-style access on the Struct well-known type."""

  def testStruct(self):
    """Exercises item access, nested struct/list creation, and round-trips."""
    struct = struct_pb2.Struct()
    struct_class = struct.__class__

    struct['key1'] = 5
    struct['key2'] = 'abc'
    struct['key3'] = True
    struct.get_or_create_struct('key4')['subkey'] = 11.0
    struct_list = struct.get_or_create_list('key5')
    struct_list.extend([6, 'seven', True, False, None])
    struct_list.add_struct()['subkey2'] = 9

    # assertEqual / assertIsInstance / assertIsNone replace the deprecated
    # assertEquals alias and assertTrue(isinstance(...)) forms.
    self.assertIsInstance(struct, well_known_types.Struct)
    self.assertEqual(5, struct['key1'])
    self.assertEqual('abc', struct['key2'])
    self.assertIs(True, struct['key3'])
    self.assertEqual(11, struct['key4']['subkey'])
    inner_struct = struct_class()
    inner_struct['subkey2'] = 9
    self.assertEqual([6, 'seven', True, False, None, inner_struct],
                     list(struct['key5'].items()))

    # Binary round-trip preserves all values.
    serialized = struct.SerializeToString()

    struct2 = struct_pb2.Struct()
    struct2.ParseFromString(serialized)

    self.assertEqual(struct, struct2)

    self.assertIsInstance(struct2, well_known_types.Struct)
    self.assertEqual(5, struct2['key1'])
    self.assertEqual('abc', struct2['key2'])
    self.assertIs(True, struct2['key3'])
    self.assertEqual(11, struct2['key4']['subkey'])
    self.assertEqual([6, 'seven', True, False, None, inner_struct],
                     list(struct2['key5'].items()))

    struct_list = struct2['key5']
    self.assertEqual(6, struct_list[0])
    self.assertEqual('seven', struct_list[1])
    self.assertEqual(True, struct_list[2])
    self.assertEqual(False, struct_list[3])
    self.assertIsNone(struct_list[4])
    self.assertEqual(inner_struct, struct_list[5])

    # List items are mutable in place.
    struct_list[1] = 7
    self.assertEqual(7, struct_list[1])

    struct_list.add_list().extend([1, 'two', True, False, None])
    self.assertEqual([1, 'two', True, False, None],
                     list(struct_list[6].items()))

    # Text-format round-trip preserves all values.
    text_serialized = str(struct)
    struct3 = struct_pb2.Struct()
    text_format.Merge(text_serialized, struct3)
    self.assertEqual(struct, struct3)

    # get_or_create_struct replaces an existing non-struct value.
    struct.get_or_create_struct('key3')['replace'] = 12
    self.assertEqual(12, struct['key3']['replace'])
| 578 | |
| 579 | |
class AnyTest(unittest.TestCase):
  """Tests Pack/Unpack/Is helpers on the Any well-known type."""

  def testAnyMessage(self):
    """Packs a message into Any, checks Is(), and unpacks it back."""
    # Creates and sets message.
    msg = any_test_pb2.TestAny()
    msg_descriptor = msg.DESCRIPTOR
    all_types = unittest_pb2.TestAllTypes()
    all_descriptor = all_types.DESCRIPTOR
    all_types.repeated_string.append(u'\u00fc\ua71f')
    # Packs to Any.
    msg.value.Pack(all_types)
    self.assertEqual(msg.value.type_url,
                     'type.googleapis.com/%s' % all_descriptor.full_name)
    self.assertEqual(msg.value.value,
                     all_types.SerializeToString())
    # Tests Is() method.
    self.assertTrue(msg.value.Is(all_descriptor))
    self.assertFalse(msg.value.Is(msg_descriptor))
    # Unpacks Any.
    unpacked_message = unittest_pb2.TestAllTypes()
    self.assertTrue(msg.value.Unpack(unpacked_message))
    self.assertEqual(all_types, unpacked_message)
    # Unpacks to different type.
    self.assertFalse(msg.value.Unpack(msg))
    # Only Any messages have a Pack method; a plain message must raise.
    # (assertRaises context manager replaces the hand-rolled
    # try/except/else pattern.)
    with self.assertRaises(AttributeError):
      msg.Pack(all_types)

  def testPackWithCustomTypeUrl(self):
    """Pack honors custom type-URL prefixes, with or without trailing '/'."""
    submessage = any_test_pb2.TestAny()
    submessage.int_value = 12345
    msg = any_pb2.Any()
    # Pack with a custom type URL prefix.
    msg.Pack(submessage, 'type.myservice.com')
    self.assertEqual(msg.type_url,
                     'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name)
    # Pack with a custom type URL prefix ending with '/'.
    msg.Pack(submessage, 'type.myservice.com/')
    self.assertEqual(msg.type_url,
                     'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name)
    # Pack with an empty type URL prefix.
    msg.Pack(submessage, '')
    self.assertEqual(msg.type_url,
                     '/%s' % submessage.DESCRIPTOR.full_name)
    # Test unpacking the type.
    unpacked_message = any_test_pb2.TestAny()
    self.assertTrue(msg.Unpack(unpacked_message))
    self.assertEqual(submessage, unpacked_message)
| 633 | |
| 634 | |
# Run all test cases in this module when executed as a script.
if __name__ == '__main__':
  unittest.main()