blob: 0f6fc5534a3f9f66df59e59ccbd2389ec38ecee0 [file] [log] [blame]
"""Unit tests for the memoryview type.

XXX We need more tests! Some tests are in test_bytes.
"""
5
6import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00007import test.support
Antoine Pitrou616d2852008-08-19 22:09:34 +00008import sys
Antoine Pitrouc6b09eb2008-09-01 15:10:14 +00009import gc
10import weakref
Christian Heimes7b6fc8e2007-11-08 02:28:11 +000011
Antoine Pitrou616d2852008-08-19 22:09:34 +000012
class CommonMemoryTests:
    """Mixin with the tests shared by direct and sliced memoryviews.

    Concrete test cases combine this mixin with ``unittest.TestCase`` and
    must provide two hooks that are called here but not defined here:

    ``_view(obj)``
        Return the memoryview under test, built from *obj*.
    ``_check_contents(obj, contents)``
        Assert that the part of *obj* exposed by the view equals *contents*.

    ``base_object`` is the bytes object the view is created from.
    """
    #
    # Tests common to direct memoryviews and sliced memoryviews
    #

    base_object = b"abcdef"

    def check_getitem_with_type(self, tp):
        """Check integer indexing on a view over ``tp(self.base_object)``.

        Also verifies that creating and dropping the view leaves the
        reference count of the underlying object unchanged.
        """
        b = tp(self.base_object)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        self.assertEqual(m[0], b"a")
        self.assertTrue(isinstance(m[0], bytes), type(m[0]))
        self.assertEqual(m[5], b"f")
        self.assertEqual(m[-1], b"f")
        self.assertEqual(m[-6], b"a")
        # Bounds checking
        self.assertRaises(IndexError, lambda: m[6])
        self.assertRaises(IndexError, lambda: m[-7])
        self.assertRaises(IndexError, lambda: m[sys.maxsize])
        self.assertRaises(IndexError, lambda: m[-sys.maxsize])
        # Type checking
        self.assertRaises(TypeError, lambda: m[None])
        self.assertRaises(TypeError, lambda: m[0.0])
        self.assertRaises(TypeError, lambda: m["a"])
        # Drop the view so that the refcount comparison below is meaningful.
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_getitem_readonly(self):
        """Indexing works on a view over an immutable bytes object."""
        self.check_getitem_with_type(bytes)

    def test_getitem_writable(self):
        """Indexing works on a view over a mutable bytearray."""
        self.check_getitem_with_type(bytearray)

    def test_setitem_readonly(self):
        """Item assignment on a view over bytes raises TypeError."""
        b = self.base_object
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)

        def setitem(value):
            m[0] = value

        self.assertRaises(TypeError, setitem, b"a")
        self.assertRaises(TypeError, setitem, 65)
        self.assertRaises(TypeError, setitem, memoryview(b"a"))
        # Drop the view so that the refcount comparison below is meaningful.
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_setitem_writable(self):
        """Item and slice assignment through a view over a bytearray."""
        b = bytearray(self.base_object)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        m[0] = b"0"
        self._check_contents(b, b"0bcdef")
        m[1:3] = b"12"
        self._check_contents(b, b"012def")
        m[1:1] = b""
        self._check_contents(b, b"012def")
        m[:] = b"abcdef"
        self._check_contents(b, b"abcdef")

        # Overlapping copies of a view into itself
        m[0:3] = m[2:5]
        self._check_contents(b, b"cdedef")
        m[:] = b"abcdef"
        m[2:5] = m[0:3]
        self._check_contents(b, b"ababcf")

        def setitem(key, value):
            m[key] = value

        # Bounds checking
        self.assertRaises(IndexError, setitem, 6, b"a")
        self.assertRaises(IndexError, setitem, -7, b"a")
        self.assertRaises(IndexError, setitem, sys.maxsize, b"a")
        self.assertRaises(IndexError, setitem, -sys.maxsize, b"a")
        # Wrong index/slice types
        self.assertRaises(TypeError, setitem, 0.0, b"a")
        self.assertRaises(TypeError, setitem, (0,), b"a")
        self.assertRaises(TypeError, setitem, "a", b"a")
        # Trying to resize the memory object
        self.assertRaises(ValueError, setitem, 0, b"")
        self.assertRaises(ValueError, setitem, 0, b"ab")
        self.assertRaises(ValueError, setitem, slice(1, 1), b"a")
        self.assertRaises(ValueError, setitem, slice(0, 2), b"a")

        # Drop the view so that the refcount comparison below is meaningful.
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_len(self):
        """len() of the view is the number of exposed bytes."""
        self.assertEqual(len(self._view(self.base_object)), 6)

    def test_tobytes(self):
        """tobytes() returns the exposed data as a bytes object."""
        m = self._view(self.base_object)
        b = m.tobytes()
        self.assertEqual(b, b"abcdef")
        self.assertTrue(isinstance(b, bytes), type(b))

    def test_tolist(self):
        """tolist() returns the same items as list() over the bytes."""
        m = self._view(self.base_object)
        l = m.tolist()
        self.assertEqual(l, list(b"abcdef"))

    def test_compare(self):
        """Equality works with buffer objects; ordering raises TypeError."""
        # memoryviews can compare for equality with other objects
        # having the buffer interface.
        m = self._view(self.base_object)
        for tp in (bytes, bytearray):
            self.assertTrue(m == tp(b"abcdef"))
            self.assertFalse(m != tp(b"abcdef"))
            self.assertFalse(m == tp(b"abcde"))
            self.assertTrue(m != tp(b"abcde"))
            self.assertFalse(m == tp(b"abcde1"))
            self.assertTrue(m != tp(b"abcde1"))
        self.assertTrue(m == m)
        self.assertTrue(m == m[:])
        self.assertTrue(m[0:6] == m[:])
        self.assertFalse(m[0:5] == m)

        # Comparison with objects which don't support the buffer API
        self.assertFalse(m == "abc")
        self.assertTrue(m != "abc")
        self.assertFalse("abc" == m)
        self.assertTrue("abc" != m)

        # Unordered comparisons
        for c in (m, b"abcdef"):
            self.assertRaises(TypeError, lambda: m < c)
            self.assertRaises(TypeError, lambda: c <= m)
            self.assertRaises(TypeError, lambda: m >= c)
            self.assertRaises(TypeError, lambda: c > m)

    def check_attributes_with_type(self, tp):
        """Check the buffer attributes of a view over ``tp(self.base_object)``.

        Returns the view so that callers can make further assertions on it.
        """
        b = tp(self.base_object)
        m = self._view(b)
        self.assertEqual(m.format, 'B')
        self.assertEqual(m.itemsize, 1)
        self.assertEqual(m.ndim, 1)
        self.assertEqual(m.shape, (6,))
        self.assertEqual(m.size, 6)
        self.assertEqual(m.strides, (1,))
        self.assertEqual(m.suboffsets, None)
        return m

    def test_attributes_readonly(self):
        """A view over bytes reports readonly == True."""
        m = self.check_attributes_with_type(bytes)
        self.assertEqual(m.readonly, True)

    def test_attributes_writable(self):
        """A view over a bytearray reports readonly == False."""
        m = self.check_attributes_with_type(bytearray)
        self.assertEqual(m.readonly, False)

    def test_getbuffer(self):
        """Taking a buffer from the view must not leak references."""
        # Test PyObject_GetBuffer() on a memoryview object.
        b = self.base_object
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        oldviewrefcount = sys.getrefcount(m)
        s = str(m, "utf-8")
        self._check_contents(b, s.encode("utf-8"))
        self.assertEqual(sys.getrefcount(m), oldviewrefcount)
        # Drop the view so that the refcount comparison below is meaningful.
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_gc(self):
        """A reference cycle going through a memoryview is collected."""
        class MyBytes(bytes):
            pass

        class MyObject:
            pass

        # Create a reference cycle through a memoryview object
        b = MyBytes(b'abc')
        m = self._view(b)
        o = MyObject()
        b.m = m
        b.o = o
        wr = weakref.ref(o)
        b = m = o = None
        # The cycle must be broken
        gc.collect()
        self.assertTrue(wr() is None, wr())
Antoine Pitrou616d2852008-08-19 22:09:34 +0000192
class MemoryviewTest(unittest.TestCase, CommonMemoryTests):
    """Run the common tests against a memoryview created directly."""

    def _view(self, obj):
        """Return a memoryview covering the whole of *obj*."""
        return memoryview(obj)

    def _check_contents(self, obj, contents):
        """Assert that *obj* as a whole equals *contents*."""
        self.assertEqual(obj, contents)

    def test_constructor(self):
        """memoryview() takes exactly one argument, optionally as ``object=``."""
        ob = b'test'
        self.assertTrue(memoryview(ob))
        self.assertTrue(memoryview(object=ob))
        # Missing, surplus, or wrongly named arguments are rejected.
        self.assertRaises(TypeError, memoryview)
        self.assertRaises(TypeError, memoryview, ob, ob)
        self.assertRaises(TypeError, memoryview, argument=ob)
        self.assertRaises(TypeError, memoryview, ob, argument=True)
209
Antoine Pitrou616d2852008-08-19 22:09:34 +0000210
class MemorySliceTest(unittest.TestCase, CommonMemoryTests):
    """Run the common tests against a slice of a memoryview."""

    # Padded on both sides; the view under test exposes only b"abcdef".
    base_object = b"XabcdefY"

    def _view(self, obj):
        """Return the [1:7] slice of a memoryview over *obj*."""
        m = memoryview(obj)
        return m[1:7]

    def _check_contents(self, obj, contents):
        """Assert that the window ``obj[1:7]`` equals *contents*."""
        self.assertEqual(obj[1:7], contents)

    def test_refs(self):
        """Taking and discarding a slice leaves the view's refcount unchanged."""
        m = memoryview(b"ab")
        oldrefcount = sys.getrefcount(m)
        m[1:2]
        self.assertEqual(sys.getrefcount(m), oldrefcount)
226
227
class MemorySliceSliceTest(unittest.TestCase, CommonMemoryTests):
    """Run the common tests against a slice of a slice of a memoryview."""

    # Padded on both sides; the view under test exposes only b"abcdef".
    base_object = b"XabcdefY"

    def _view(self, obj):
        """Return ``memoryview(obj)[:7][1:]``, i.e. a slice taken of a slice."""
        m = memoryview(obj)
        return m[:7][1:]

    def _check_contents(self, obj, contents):
        """Assert that the window ``obj[1:7]`` equals *contents*."""
        self.assertEqual(obj[1:7], contents)
237
238
def test_main():
    """Run all memoryview test cases through test.support.run_unittest()."""
    cases = (MemoryviewTest, MemorySliceTest, MemorySliceSliceTest)
    test.support.run_unittest(*cases)
Christian Heimes7b6fc8e2007-11-08 02:28:11 +0000242
243
# Run the suite only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_main()