blob: 97981c882e825b15268802ac1bbed28493335c8d [file] [log] [blame]
"""Unit tests for the PickleBuffer object.

Pickling tests themselves are in pickletester.py.
"""
5
6import gc
7from pickle import PickleBuffer
Antoine Pitrou91f43802019-05-26 17:10:09 +02008import weakref
9import unittest
10
11from test import support
12
13
class B(bytes):
    """A bytes subclass; unlike plain bytes, instances accept attributes.

    The reference-cycle test stores a PickleBuffer on an instance of this
    class, which is not possible with a plain ``bytes`` object.
    """
16
17
18class PickleBufferTest(unittest.TestCase):
19
20 def check_memoryview(self, pb, equiv):
21 with memoryview(pb) as m:
22 with memoryview(equiv) as expected:
23 self.assertEqual(m.nbytes, expected.nbytes)
24 self.assertEqual(m.readonly, expected.readonly)
25 self.assertEqual(m.itemsize, expected.itemsize)
26 self.assertEqual(m.shape, expected.shape)
27 self.assertEqual(m.strides, expected.strides)
28 self.assertEqual(m.c_contiguous, expected.c_contiguous)
29 self.assertEqual(m.f_contiguous, expected.f_contiguous)
30 self.assertEqual(m.format, expected.format)
31 self.assertEqual(m.tobytes(), expected.tobytes())
32
33 def test_constructor_failure(self):
34 with self.assertRaises(TypeError):
35 PickleBuffer()
36 with self.assertRaises(TypeError):
37 PickleBuffer("foo")
38 # Released memoryview fails taking a buffer
39 m = memoryview(b"foo")
40 m.release()
41 with self.assertRaises(ValueError):
42 PickleBuffer(m)
43
44 def test_basics(self):
45 pb = PickleBuffer(b"foo")
46 self.assertEqual(b"foo", bytes(pb))
47 with memoryview(pb) as m:
48 self.assertTrue(m.readonly)
49
50 pb = PickleBuffer(bytearray(b"foo"))
51 self.assertEqual(b"foo", bytes(pb))
52 with memoryview(pb) as m:
53 self.assertFalse(m.readonly)
54 m[0] = 48
55 self.assertEqual(b"0oo", bytes(pb))
56
57 def test_release(self):
58 pb = PickleBuffer(b"foo")
59 pb.release()
60 with self.assertRaises(ValueError) as raises:
61 memoryview(pb)
62 self.assertIn("operation forbidden on released PickleBuffer object",
63 str(raises.exception))
64 # Idempotency
65 pb.release()
66
67 def test_cycle(self):
68 b = B(b"foo")
69 pb = PickleBuffer(b)
70 b.cycle = pb
71 wpb = weakref.ref(pb)
72 del b, pb
73 gc.collect()
74 self.assertIsNone(wpb())
75
76 def test_ndarray_2d(self):
77 # C-contiguous
78 ndarray = support.import_module("_testbuffer").ndarray
79 arr = ndarray(list(range(12)), shape=(4, 3), format='<i')
80 self.assertTrue(arr.c_contiguous)
81 self.assertFalse(arr.f_contiguous)
82 pb = PickleBuffer(arr)
83 self.check_memoryview(pb, arr)
84 # Non-contiguous
85 arr = arr[::2]
86 self.assertFalse(arr.c_contiguous)
87 self.assertFalse(arr.f_contiguous)
88 pb = PickleBuffer(arr)
89 self.check_memoryview(pb, arr)
90 # F-contiguous
91 arr = ndarray(list(range(12)), shape=(3, 4), strides=(4, 12), format='<i')
92 self.assertTrue(arr.f_contiguous)
93 self.assertFalse(arr.c_contiguous)
94 pb = PickleBuffer(arr)
95 self.check_memoryview(pb, arr)
96
97 # Tests for PickleBuffer.raw()
98
99 def check_raw(self, obj, equiv):
100 pb = PickleBuffer(obj)
101 with pb.raw() as m:
102 self.assertIsInstance(m, memoryview)
103 self.check_memoryview(m, equiv)
104
105 def test_raw(self):
106 for obj in (b"foo", bytearray(b"foo")):
107 with self.subTest(obj=obj):
108 self.check_raw(obj, obj)
109
110 def test_raw_ndarray(self):
111 # 1-D, contiguous
112 ndarray = support.import_module("_testbuffer").ndarray
113 arr = ndarray(list(range(3)), shape=(3,), format='<h')
114 equiv = b"\x00\x00\x01\x00\x02\x00"
115 self.check_raw(arr, equiv)
116 # 2-D, C-contiguous
117 arr = ndarray(list(range(6)), shape=(2, 3), format='<h')
118 equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00"
119 self.check_raw(arr, equiv)
120 # 2-D, F-contiguous
121 arr = ndarray(list(range(6)), shape=(2, 3), strides=(2, 4),
122 format='<h')
123 # Note this is different from arr.tobytes()
124 equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00"
125 self.check_raw(arr, equiv)
126 # 0-D
127 arr = ndarray(456, shape=(), format='<i')
128 equiv = b'\xc8\x01\x00\x00'
129 self.check_raw(arr, equiv)
130
131 def check_raw_non_contiguous(self, obj):
132 pb = PickleBuffer(obj)
133 with self.assertRaisesRegex(BufferError, "non-contiguous"):
134 pb.raw()
135
136 def test_raw_non_contiguous(self):
137 # 1-D
138 ndarray = support.import_module("_testbuffer").ndarray
139 arr = ndarray(list(range(6)), shape=(6,), format='<i')[::2]
140 self.check_raw_non_contiguous(arr)
141 # 2-D
142 arr = ndarray(list(range(12)), shape=(4, 3), format='<i')[::2]
143 self.check_raw_non_contiguous(arr)
144
145 def test_raw_released(self):
146 pb = PickleBuffer(b"foo")
147 pb.release()
148 with self.assertRaises(ValueError) as raises:
149 pb.raw()
150
151
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()