blob: 7e72157fd022169a5577a94dfff22082b49bd6f6 [file] [log] [blame]
Antoine Pitrou91f43802019-05-26 17:10:09 +02001"""Unit tests for the PickleBuffer object.
2
3Pickling tests themselves are in pickletester.py.
4"""
5
6import gc
7from pickle import PickleBuffer
8import sys
9import weakref
10import unittest
11
12from test import support
13
14
class B(bytes):
    """Plain ``bytes`` subclass.

    Unlike ``bytes`` itself, instances of a subclass carry a ``__dict__``,
    which lets tests hang arbitrary attributes (e.g. a back-reference to a
    PickleBuffer) off the wrapped object to build reference cycles.
    """
17
18
class PickleBufferTest(unittest.TestCase):
    """Unit tests for ``pickle.PickleBuffer``.

    Covers construction, buffer export, release semantics, garbage
    collection of cycles, and the ``raw()`` method.
    """

    def check_memoryview(self, pb, equiv):
        """Assert that ``memoryview(pb)`` is structurally identical to
        ``memoryview(equiv)``.

        Compares total size, mutability, item size, shape, strides,
        contiguity flags, format string and byte contents of the two views.
        """
        with memoryview(pb) as m:
            with memoryview(equiv) as expected:
                self.assertEqual(m.nbytes, expected.nbytes)
                self.assertEqual(m.readonly, expected.readonly)
                self.assertEqual(m.itemsize, expected.itemsize)
                self.assertEqual(m.shape, expected.shape)
                self.assertEqual(m.strides, expected.strides)
                self.assertEqual(m.c_contiguous, expected.c_contiguous)
                self.assertEqual(m.f_contiguous, expected.f_contiguous)
                self.assertEqual(m.format, expected.format)
                self.assertEqual(m.tobytes(), expected.tobytes())

    def test_constructor_failure(self):
        """The constructor rejects a missing argument, a non-buffer object
        and an already-released buffer."""
        with self.assertRaises(TypeError):
            PickleBuffer()
        with self.assertRaises(TypeError):
            PickleBuffer("foo")
        # Released memoryview fails taking a buffer
        m = memoryview(b"foo")
        m.release()
        with self.assertRaises(ValueError):
            PickleBuffer(m)

    def test_basics(self):
        """The exported buffer mirrors the wrapped object's contents and
        read-only flag, and writes through it are visible via the
        PickleBuffer."""
        pb = PickleBuffer(b"foo")
        self.assertEqual(b"foo", bytes(pb))
        with memoryview(pb) as m:
            self.assertTrue(m.readonly)

        pb = PickleBuffer(bytearray(b"foo"))
        self.assertEqual(b"foo", bytes(pb))
        with memoryview(pb) as m:
            self.assertFalse(m.readonly)
            m[0] = 48  # ord("0")
        self.assertEqual(b"0oo", bytes(pb))

    def test_release(self):
        """After release(), exporting a buffer raises ValueError, and
        release() may be called again safely."""
        pb = PickleBuffer(b"foo")
        pb.release()
        with self.assertRaises(ValueError) as raises:
            memoryview(pb)
        self.assertIn("operation forbidden on released PickleBuffer object",
                      str(raises.exception))
        # Idempotency
        pb.release()

    def test_cycle(self):
        """A cycle PickleBuffer -> B instance -> PickleBuffer (via the
        subclass instance's __dict__) must be reclaimable by the cyclic GC."""
        b = B(b"foo")
        pb = PickleBuffer(b)
        b.cycle = pb
        wpb = weakref.ref(pb)
        del b, pb
        # support.gc_collect() runs several collection passes, so the cycle
        # is reliably reclaimed rather than depending on a single
        # gc.collect() call.
        support.gc_collect()
        self.assertIsNone(wpb())

    def test_ndarray_2d(self):
        """Two-dimensional buffers are re-exported with their layout intact,
        whether C-contiguous, non-contiguous or F-contiguous."""
        # C-contiguous
        ndarray = support.import_module("_testbuffer").ndarray
        arr = ndarray(list(range(12)), shape=(4, 3), format='<i')
        self.assertTrue(arr.c_contiguous)
        self.assertFalse(arr.f_contiguous)
        pb = PickleBuffer(arr)
        self.check_memoryview(pb, arr)
        # Non-contiguous: every other row
        arr = arr[::2]
        self.assertFalse(arr.c_contiguous)
        self.assertFalse(arr.f_contiguous)
        pb = PickleBuffer(arr)
        self.check_memoryview(pb, arr)
        # F-contiguous: strides (itemsize, itemsize * rows) for 4-byte items
        arr = ndarray(list(range(12)), shape=(3, 4), strides=(4, 12), format='<i')
        self.assertTrue(arr.f_contiguous)
        self.assertFalse(arr.c_contiguous)
        pb = PickleBuffer(arr)
        self.check_memoryview(pb, arr)

    # Tests for PickleBuffer.raw()

    def check_raw(self, obj, equiv):
        """Assert that PickleBuffer(obj).raw() returns a memoryview that
        matches ``memoryview(equiv)`` (see check_memoryview)."""
        pb = PickleBuffer(obj)
        with pb.raw() as m:
            self.assertIsInstance(m, memoryview)
            self.check_memoryview(m, equiv)

    def test_raw(self):
        """raw() on plain bytes-like objects is equivalent to the object."""
        for obj in (b"foo", bytearray(b"foo")):
            with self.subTest(obj=obj):
                self.check_raw(obj, obj)

    def test_raw_ndarray(self):
        """raw() flattens contiguous buffers of any dimensionality into the
        underlying memory bytes."""
        # 1-D, contiguous
        ndarray = support.import_module("_testbuffer").ndarray
        arr = ndarray(list(range(3)), shape=(3,), format='<h')
        equiv = b"\x00\x00\x01\x00\x02\x00"
        self.check_raw(arr, equiv)
        # 2-D, C-contiguous
        arr = ndarray(list(range(6)), shape=(2, 3), format='<h')
        equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00"
        self.check_raw(arr, equiv)
        # 2-D, F-contiguous
        arr = ndarray(list(range(6)), shape=(2, 3), strides=(2, 4),
                      format='<h')
        # raw() exposes the memory as stored; note this is different from
        # arr.tobytes()
        equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00"
        self.check_raw(arr, equiv)
        # 0-D: 456 == 0x01c8, little-endian 4-byte int
        arr = ndarray(456, shape=(), format='<i')
        equiv = b'\xc8\x01\x00\x00'
        self.check_raw(arr, equiv)

    def check_raw_non_contiguous(self, obj):
        """Assert that raw() refuses the non-contiguous *obj* with a
        BufferError mentioning non-contiguity."""
        pb = PickleBuffer(obj)
        with self.assertRaisesRegex(BufferError, "non-contiguous"):
            pb.raw()

    def test_raw_non_contiguous(self):
        """raw() rejects strided (non-contiguous) 1-D and 2-D buffers."""
        # 1-D
        ndarray = support.import_module("_testbuffer").ndarray
        arr = ndarray(list(range(6)), shape=(6,), format='<i')[::2]
        self.check_raw_non_contiguous(arr)
        # 2-D
        arr = ndarray(list(range(12)), shape=(4, 3), format='<i')[::2]
        self.check_raw_non_contiguous(arr)

    def test_raw_released(self):
        """raw() on a released PickleBuffer raises ValueError with the same
        message as exporting its buffer (see test_release)."""
        pb = PickleBuffer(b"foo")
        pb.release()
        with self.assertRaises(ValueError) as raises:
            pb.raw()
        # Previously the captured exception was never inspected; check the
        # message so an unrelated ValueError cannot make this test pass.
        self.assertIn("operation forbidden on released PickleBuffer object",
                      str(raises.exception))
151
152
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()