Antoine Pitrou | 91f4380 | 2019-05-26 17:10:09 +0200 | [diff] [blame] | 1 | """Unit tests for the PickleBuffer object. |
| 2 | |
| 3 | Pickling tests themselves are in pickletester.py. |
| 4 | """ |
| 5 | |
| 6 | import gc |
| 7 | from pickle import PickleBuffer |
Antoine Pitrou | 91f4380 | 2019-05-26 17:10:09 +0200 | [diff] [blame] | 8 | import weakref |
| 9 | import unittest |
| 10 | |
| 11 | from test import support |
| 12 | |
| 13 | |
class B(bytes):
    """Minimal ``bytes`` subclass used as a buffer exporter in tests.

    Unlike plain ``bytes`` objects, instances of a Python-level subclass
    have a ``__dict__``, so a test can attach arbitrary attributes to them
    (``test_cycle`` stores a PickleBuffer on one to build a reference cycle).
    """
| 16 | |
| 17 | |
| 18 | class PickleBufferTest(unittest.TestCase): |
| 19 | |
| 20 | def check_memoryview(self, pb, equiv): |
| 21 | with memoryview(pb) as m: |
| 22 | with memoryview(equiv) as expected: |
| 23 | self.assertEqual(m.nbytes, expected.nbytes) |
| 24 | self.assertEqual(m.readonly, expected.readonly) |
| 25 | self.assertEqual(m.itemsize, expected.itemsize) |
| 26 | self.assertEqual(m.shape, expected.shape) |
| 27 | self.assertEqual(m.strides, expected.strides) |
| 28 | self.assertEqual(m.c_contiguous, expected.c_contiguous) |
| 29 | self.assertEqual(m.f_contiguous, expected.f_contiguous) |
| 30 | self.assertEqual(m.format, expected.format) |
| 31 | self.assertEqual(m.tobytes(), expected.tobytes()) |
| 32 | |
| 33 | def test_constructor_failure(self): |
| 34 | with self.assertRaises(TypeError): |
| 35 | PickleBuffer() |
| 36 | with self.assertRaises(TypeError): |
| 37 | PickleBuffer("foo") |
| 38 | # Released memoryview fails taking a buffer |
| 39 | m = memoryview(b"foo") |
| 40 | m.release() |
| 41 | with self.assertRaises(ValueError): |
| 42 | PickleBuffer(m) |
| 43 | |
| 44 | def test_basics(self): |
| 45 | pb = PickleBuffer(b"foo") |
| 46 | self.assertEqual(b"foo", bytes(pb)) |
| 47 | with memoryview(pb) as m: |
| 48 | self.assertTrue(m.readonly) |
| 49 | |
| 50 | pb = PickleBuffer(bytearray(b"foo")) |
| 51 | self.assertEqual(b"foo", bytes(pb)) |
| 52 | with memoryview(pb) as m: |
| 53 | self.assertFalse(m.readonly) |
| 54 | m[0] = 48 |
| 55 | self.assertEqual(b"0oo", bytes(pb)) |
| 56 | |
| 57 | def test_release(self): |
| 58 | pb = PickleBuffer(b"foo") |
| 59 | pb.release() |
| 60 | with self.assertRaises(ValueError) as raises: |
| 61 | memoryview(pb) |
| 62 | self.assertIn("operation forbidden on released PickleBuffer object", |
| 63 | str(raises.exception)) |
| 64 | # Idempotency |
| 65 | pb.release() |
| 66 | |
| 67 | def test_cycle(self): |
| 68 | b = B(b"foo") |
| 69 | pb = PickleBuffer(b) |
| 70 | b.cycle = pb |
| 71 | wpb = weakref.ref(pb) |
| 72 | del b, pb |
| 73 | gc.collect() |
| 74 | self.assertIsNone(wpb()) |
| 75 | |
| 76 | def test_ndarray_2d(self): |
| 77 | # C-contiguous |
| 78 | ndarray = support.import_module("_testbuffer").ndarray |
| 79 | arr = ndarray(list(range(12)), shape=(4, 3), format='<i') |
| 80 | self.assertTrue(arr.c_contiguous) |
| 81 | self.assertFalse(arr.f_contiguous) |
| 82 | pb = PickleBuffer(arr) |
| 83 | self.check_memoryview(pb, arr) |
| 84 | # Non-contiguous |
| 85 | arr = arr[::2] |
| 86 | self.assertFalse(arr.c_contiguous) |
| 87 | self.assertFalse(arr.f_contiguous) |
| 88 | pb = PickleBuffer(arr) |
| 89 | self.check_memoryview(pb, arr) |
| 90 | # F-contiguous |
| 91 | arr = ndarray(list(range(12)), shape=(3, 4), strides=(4, 12), format='<i') |
| 92 | self.assertTrue(arr.f_contiguous) |
| 93 | self.assertFalse(arr.c_contiguous) |
| 94 | pb = PickleBuffer(arr) |
| 95 | self.check_memoryview(pb, arr) |
| 96 | |
| 97 | # Tests for PickleBuffer.raw() |
| 98 | |
| 99 | def check_raw(self, obj, equiv): |
| 100 | pb = PickleBuffer(obj) |
| 101 | with pb.raw() as m: |
| 102 | self.assertIsInstance(m, memoryview) |
| 103 | self.check_memoryview(m, equiv) |
| 104 | |
| 105 | def test_raw(self): |
| 106 | for obj in (b"foo", bytearray(b"foo")): |
| 107 | with self.subTest(obj=obj): |
| 108 | self.check_raw(obj, obj) |
| 109 | |
| 110 | def test_raw_ndarray(self): |
| 111 | # 1-D, contiguous |
| 112 | ndarray = support.import_module("_testbuffer").ndarray |
| 113 | arr = ndarray(list(range(3)), shape=(3,), format='<h') |
| 114 | equiv = b"\x00\x00\x01\x00\x02\x00" |
| 115 | self.check_raw(arr, equiv) |
| 116 | # 2-D, C-contiguous |
| 117 | arr = ndarray(list(range(6)), shape=(2, 3), format='<h') |
| 118 | equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00" |
| 119 | self.check_raw(arr, equiv) |
| 120 | # 2-D, F-contiguous |
| 121 | arr = ndarray(list(range(6)), shape=(2, 3), strides=(2, 4), |
| 122 | format='<h') |
| 123 | # Note this is different from arr.tobytes() |
| 124 | equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00" |
| 125 | self.check_raw(arr, equiv) |
| 126 | # 0-D |
| 127 | arr = ndarray(456, shape=(), format='<i') |
| 128 | equiv = b'\xc8\x01\x00\x00' |
| 129 | self.check_raw(arr, equiv) |
| 130 | |
| 131 | def check_raw_non_contiguous(self, obj): |
| 132 | pb = PickleBuffer(obj) |
| 133 | with self.assertRaisesRegex(BufferError, "non-contiguous"): |
| 134 | pb.raw() |
| 135 | |
| 136 | def test_raw_non_contiguous(self): |
| 137 | # 1-D |
| 138 | ndarray = support.import_module("_testbuffer").ndarray |
| 139 | arr = ndarray(list(range(6)), shape=(6,), format='<i')[::2] |
| 140 | self.check_raw_non_contiguous(arr) |
| 141 | # 2-D |
| 142 | arr = ndarray(list(range(12)), shape=(4, 3), format='<i')[::2] |
| 143 | self.check_raw_non_contiguous(arr) |
| 144 | |
| 145 | def test_raw_released(self): |
| 146 | pb = PickleBuffer(b"foo") |
| 147 | pb.release() |
| 148 | with self.assertRaises(ValueError) as raises: |
| 149 | pb.raw() |
| 150 | |
| 151 | |
if __name__ == "__main__":
    # Executed directly (not imported by a test runner): run this module's
    # test cases via unittest's command-line entry point.
    unittest.main()