blob: 0e9175d8f4930885795ae7b2e3a0d104fb6dc4ac [file] [log] [blame]
Serhiy Storchaka1b80e632013-10-13 17:55:07 +03001from test.support import findfile, TESTFN, unlink
2import unittest
3import array
4import io
5import pickle
6import sys
7
def byteswap2(data):
    """Return *data* with the two bytes of every 2-byte item swapped.

    Extended slicing is used instead of ``array.array('h')`` so the result
    does not depend on the platform's C ``short`` size.

    Raises ValueError if ``len(data)`` is not a multiple of 2.
    """
    swapped = bytearray(len(data))
    # Each slice assignment raises ValueError on a size mismatch, which
    # happens exactly when the length is not a multiple of the item size.
    swapped[0::2] = data[1::2]
    swapped[1::2] = data[0::2]
    return bytes(swapped)
12
def byteswap3(data):
    """Return *data* with the outer bytes of every 3-byte item exchanged.

    The middle byte of each item is left in place.
    """
    result = bytearray(data)
    # Copy byte 2 of each item into slot 0, then byte 0 into slot 2; the
    # middle byte is already correct from the bytearray copy above.
    for dst, src in ((0, 2), (2, 0)):
        result[dst::3] = data[src::3]
    return bytes(result)
18
def byteswap4(data):
    """Return *data* with the byte order of every 4-byte item reversed.

    Extended slicing is used instead of ``array.array('i')`` because the
    item size of the ``'i'`` typecode follows the platform's C ``int`` and
    is not guaranteed to be 4 bytes.

    Raises ValueError if ``len(data)`` is not a multiple of 4.
    """
    swapped = bytearray(len(data))
    # Output byte i of each item comes from input byte 3 - i; a length
    # that is not a multiple of 4 makes the slice sizes disagree and the
    # assignment raises ValueError.
    for i in range(4):
        swapped[i::4] = data[3 - i::4]
    return bytes(swapped)
23
class UnseekableIO(io.FileIO):
    """A file object that behaves like an unseekable stream (e.g. a pipe).

    Reading and writing work as for io.FileIO.  tell(), seek() and
    truncate() raise io.UnsupportedOperation, and seekable() reports
    False, which matches the io.IOBase contract for non-seekable streams.
    """

    def seekable(self):
        # io.FileIO reports True for a regular file; report False so the
        # double is self-consistent with tell()/seek()/truncate() below.
        return False

    def tell(self):
        raise io.UnsupportedOperation

    def seek(self, *args, **kwargs):
        raise io.UnsupportedOperation

    def truncate(self, *args, **kwargs):
        raise io.UnsupportedOperation
30
Serhiy Storchaka1b80e632013-10-13 17:55:07 +030031
class AudioTests:
    """Shared fixture and helpers for the audio-module test mixins.

    check_params() relies on ``self.assertEqual``, so this class must be
    mixed into a unittest.TestCase.  Whatever a test stores in ``self.f``
    (a reader) or ``self.fout`` (a writer) is closed in tearDown, and
    TESTFN is removed afterwards.
    """

    # Expected value of ``file.closed`` after the audio object that wraps
    # an already-open file object has been closed.
    close_fd = False

    def setUp(self):
        self.f = self.fout = None

    def tearDown(self):
        # Nested try/finally: a failing close() on one object must not
        # leak the other one or leave TESTFN behind for later tests.
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that audio object *f* reports exactly the given parameters.

        Checks the individual getters, getparams() both as a plain tuple
        and through its named fields, and that the params object survives
        a pickle round trip.
        """
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        params = f.getparams()
        self.assertEqual(params,
                (nchannels, sampwidth, framerate, nframes, comptype, compname))
        self.assertEqual(params.nchannels, nchannels)
        self.assertEqual(params.sampwidth, sampwidth)
        self.assertEqual(params.framerate, framerate)
        self.assertEqual(params.nframes, nframes)
        self.assertEqual(params.comptype, comptype)
        self.assertEqual(params.compname, compname)

        dump = pickle.dumps(params)
        self.assertEqual(pickle.loads(dump), params)
66
67
class AudioWriteTests(AudioTests):
    """Write-side tests shared by the audio-module test cases.

    Concrete subclasses provide ``module`` (the audio module under test),
    the stream format (``nchannels``, ``sampwidth``, ``framerate``,
    ``comptype``, ``compname``), and the expected payload as ``nframes``
    frames of raw bytes in ``frames``.
    """

    # Junk written ahead of the audio data, so that the modules are checked
    # against streams that do not begin at offset 0.
    _PREFIX = b'ababagalamaga'

    def _configure(self, writer):
        """Apply the subclass's format parameters to *writer*."""
        writer.setnchannels(self.nchannels)
        writer.setsampwidth(self.sampwidth)
        writer.setframerate(self.framerate)
        writer.setcomptype(self.comptype, self.compname)

    def _check_format(self, params):
        """Assert that *params* carries the channel count, width and rate."""
        self.assertEqual(params.nchannels, self.nchannels)
        self.assertEqual(params.sampwidth, self.sampwidth)
        self.assertEqual(params.framerate, self.framerate)

    def _frame_size(self):
        """Return the number of bytes in one frame."""
        return self.nchannels * self.sampwidth

    def create_file(self, testfile):
        """Open *testfile* for writing, configure it, and remember it in fout."""
        writer = self.fout = self.module.open(testfile, 'wb')
        self._configure(writer)
        return writer

    def check_file(self, testfile, nframes, frames):
        """Read *testfile* back and compare its format and first frames."""
        with self.module.open(testfile, 'rb') as reader:
            self.assertEqual(reader.getnchannels(), self.nchannels)
            self.assertEqual(reader.getsampwidth(), self.sampwidth)
            self.assertEqual(reader.getframerate(), self.framerate)
            self.assertEqual(reader.getnframes(), nframes)
            self.assertEqual(reader.readframes(nframes), frames)

    def _write_and_check(self, payload):
        """Write *payload* as the whole stream to TESTFN and verify it."""
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(payload)
        writer.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def _write_after_prefix(self, declared_nframes, *chunks):
        """Write _PREFIX then an audio stream declaring *declared_nframes*."""
        with open(TESTFN, 'wb') as testfile:
            testfile.write(self._PREFIX)
            writer = self.create_file(testfile)
            writer.setnframes(declared_nframes)
            for chunk in chunks:
                writer.writeframes(chunk)
            writer.close()

    def _write_unseekable_after_prefix(self, declared_nframes):
        """Like _write_after_prefix, but through an UnseekableIO stream."""
        with UnseekableIO(TESTFN, 'wb') as testfile:
            testfile.write(self._PREFIX)
            writer = self.create_file(testfile)
            writer.setnframes(declared_nframes)
            # Either call may raise OSError on an unseekable stream
            # (presumably when the header cannot be fixed up afterwards);
            # only the bytes that reach the file are checked.
            try:
                writer.writeframes(self.frames)
            except OSError:
                pass
            try:
                writer.close()
            except OSError:
                pass

    def _check_after_prefix(self, nframes, frames):
        """Verify that TESTFN holds _PREFIX followed by the expected audio."""
        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(len(self._PREFIX)), self._PREFIX)
            self.check_file(testfile, nframes, frames)

    def test_write_params(self):
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        self.check_params(writer, self.nchannels, self.sampwidth,
                          self.framerate, self.nframes, self.comptype,
                          self.compname)
        writer.close()

    def test_write_context_manager_calls_close(self):
        # close() fails when the minimum header parameters were never set,
        # so seeing the module's Error proves that close() was called.
        with self.assertRaises(self.module.Error):
            with self.module.open(TESTFN, 'wb'):
                pass
        with self.assertRaises(self.module.Error):
            with open(TESTFN, 'wb') as testfile:
                with self.module.open(testfile):
                    pass

    def test_context_manager_with_open_file(self):
        with open(TESTFN, 'wb') as testfile:
            with self.module.open(testfile) as writer:
                self._configure(writer)
            self.assertEqual(testfile.closed, self.close_fd)
        with open(TESTFN, 'rb') as testfile:
            with self.module.open(testfile) as reader:
                self.assertFalse(reader.getfp().closed)
                self._check_format(reader.getparams())
            if not self.close_fd:
                self.assertIsNone(reader.getfp())
            self.assertEqual(testfile.closed, self.close_fd)

    def test_context_manager_with_filename(self):
        # A leaked file would not fail this test, but it would produce a
        # resource-leak warning.
        with self.module.open(TESTFN, 'wb') as writer:
            self._configure(writer)
        with self.module.open(TESTFN) as reader:
            self.assertFalse(reader.getfp().closed)
            self._check_format(reader.getparams())
        if not self.close_fd:
            self.assertIsNone(reader.getfp())

    def test_write(self):
        self._write_and_check(self.frames)

    def test_write_bytearray(self):
        self._write_and_check(bytearray(self.frames))

    def test_write_array(self):
        self._write_and_check(array.array('h', self.frames))

    def test_write_memoryview(self):
        self._write_and_check(memoryview(self.frames))

    def test_incompleted_write(self):
        # Declare one frame more than is actually written.
        self._write_after_prefix(self.nframes + 1, self.frames)
        self._check_after_prefix(self.nframes, self.frames)

    def test_multiple_writes(self):
        split = -self._frame_size()
        self._write_after_prefix(self.nframes,
                                 self.frames[:split], self.frames[split:])
        self._check_after_prefix(self.nframes, self.frames)

    def test_overflowed_write(self):
        # Declare one frame fewer than is actually written.
        self._write_after_prefix(self.nframes - 1, self.frames)
        self._check_after_prefix(self.nframes, self.frames)

    def test_unseekable_read(self):
        with self.create_file(TESTFN) as writer:
            writer.setnframes(self.nframes)
            writer.writeframes(self.frames)

        with UnseekableIO(TESTFN, 'rb') as testfile:
            self.check_file(testfile, self.nframes, self.frames)

    def test_unseekable_write(self):
        with UnseekableIO(TESTFN, 'wb') as testfile:
            with self.create_file(testfile) as writer:
                writer.setnframes(self.nframes)
                writer.writeframes(self.frames)

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_unseekable_incompleted_write(self):
        self._write_unseekable_after_prefix(self.nframes + 1)
        self._check_after_prefix(self.nframes + 1, self.frames)

    def test_unseekable_overflowed_write(self):
        self._write_unseekable_after_prefix(self.nframes - 1)
        self._check_after_prefix(self.nframes - 1,
                                 self.frames[:-self._frame_size()])
263
Serhiy Storchaka1b80e632013-10-13 17:55:07 +0300264
class AudioTestsWithSourceFile(AudioTests):
    """Read-side tests run against a reference sound file.

    Subclasses name the file in ``sndfilename``, which is looked up in the
    ``audiodata`` test-data directory. They describe its expected format
    and frame count (``sndfilenframes``), and give the expected leading
    frames as ``nframes`` frames of raw bytes in ``frames``.
    """

    @classmethod
    def setUpClass(cls):
        cls.sndfilepath = findfile(cls.sndfilename, subdir='audiodata')

    def test_read_params(self):
        reader = self.f = self.module.open(self.sndfilepath)
        self.check_params(reader, self.nchannels, self.sampwidth,
                          self.framerate, self.sndfilenframes,
                          self.comptype, self.compname)

    def test_close(self):
        with open(self.sndfilepath, 'rb') as testfile:
            reader = self.f = self.module.open(testfile)
            self.assertFalse(testfile.closed)
            reader.close()
            self.assertEqual(testfile.closed, self.close_fd)
        with open(TESTFN, 'wb') as testfile:
            writer = self.fout = self.module.open(testfile, 'wb')
            self.assertFalse(testfile.closed)
            # The writer never received its header parameters, so close()
            # is expected to fail with the module's Error.
            with self.assertRaises(self.module.Error):
                writer.close()
            self.assertEqual(testfile.closed, self.close_fd)
            writer.close()  # a second close() does nothing

    def test_read(self):
        frame_bytes = self.nchannels * self.sampwidth
        first_pair = self.frames[:2 * frame_bytes]
        second_pair = self.frames[2 * frame_bytes: 4 * frame_bytes]
        reader = self.f = self.module.open(self.sndfilepath)
        self.assertEqual(reader.readframes(0), b'')
        self.assertEqual(reader.tell(), 0)
        self.assertEqual(reader.readframes(2), first_pair)
        reader.rewind()
        start = reader.tell()
        self.assertEqual(start, 0)
        self.assertEqual(reader.readframes(2), first_pair)
        after_two = reader.tell()
        self.assertEqual(after_two, 2)
        self.assertEqual(reader.readframes(2), second_pair)
        reader.setpos(after_two)
        self.assertEqual(reader.readframes(2), second_pair)
        reader.setpos(start)
        self.assertEqual(reader.readframes(2), first_pair)
        # Positions before the start or past the end must be rejected.
        with self.assertRaises(self.module.Error):
            reader.setpos(-1)
        with self.assertRaises(self.module.Error):
            reader.setpos(reader.getnframes() + 1)

    def test_copy(self):
        source = self.f = self.module.open(self.sndfilepath)
        target = self.fout = self.module.open(TESTFN, 'wb')
        target.setparams(source.getparams())
        # Copy in chunks of 1, 2, 3, ... frames so that writes of many
        # different sizes are exercised.
        chunk = 0
        remaining = source.getnframes()
        while remaining > 0:
            chunk += 1
            target.writeframes(source.readframes(chunk))
            remaining -= chunk
        target.close()
        result = self.fout = self.module.open(TESTFN, 'rb')
        source.rewind()
        self.assertEqual(source.getparams(), result.getparams())
        self.assertEqual(source.readframes(source.getnframes()),
                         result.readframes(result.getnframes()))

    def test_read_not_from_start(self):
        prefix = b'ababagalamaga'
        with open(TESTFN, 'wb') as testfile:
            testfile.write(prefix)
            with open(self.sndfilepath, 'rb') as original:
                testfile.write(original.read())

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(len(prefix)), prefix)
            with self.module.open(testfile, 'rb') as reader:
                self.assertEqual(reader.getnchannels(), self.nchannels)
                self.assertEqual(reader.getsampwidth(), self.sampwidth)
                self.assertEqual(reader.getframerate(), self.framerate)
                self.assertEqual(reader.getnframes(), self.sndfilenframes)
                self.assertEqual(reader.readframes(self.nframes), self.frames)