# blob: 0dad01722922a9a8f5a82dd1e1bda9e8487c0a42
import array
import io
import pickle
from test.support import findfile, TESTFN, unlink
from unittest import mock


class UnseekableIO(io.FileIO):
    """An ``io.FileIO`` whose positioning methods always fail.

    Both ``tell()`` and ``seek()`` raise ``io.UnsupportedOperation``, so
    the tests can feed the audio modules a raw file that can neither
    report nor change its position.
    """

    def tell(self):
        """Always raise ``io.UnsupportedOperation``."""
        raise io.UnsupportedOperation

    def seek(self, *args, **kwargs):
        """Always raise ``io.UnsupportedOperation``; arguments are ignored."""
        raise io.UnsupportedOperation


class AudioTests:
    """Shared fixture and parameter checks for the audio module tests.

    This is a mixin: it calls ``self.assertEqual`` but does not define it,
    so it has to be combined with a class that does (for example
    ``unittest.TestCase``).
    """

    # Expected value of ``fileobj.closed`` after the audio object that was
    # opened on top of ``fileobj`` has been closed.
    close_fd = False

    def setUp(self):
        """Reset the reader (``self.f``) and writer (``self.fout``) slots."""
        self.f = self.fout = None

    def tearDown(self):
        """Close whatever the test left open and delete the scratch file.

        The steps are chained with ``try``/``finally`` so that an error
        while closing the reader can no longer prevent the writer from
        being closed or ``TESTFN`` from being removed.  The first
        exception raised still propagates.
        """
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that *f* reports exactly the given audio parameters.

        The values are checked three ways: through the individual getter
        methods, through ``f.getparams()`` compared with a plain tuple and
        by attribute name, and finally by round-tripping the params object
        through every pickle protocol.
        """
        # Individual getters.
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        # The aggregate object must compare equal to a tuple and expose
        # each field by name.
        params = f.getparams()
        self.assertEqual(params,
                (nchannels, sampwidth, framerate, nframes, comptype, compname))
        self.assertEqual(params.nchannels, nchannels)
        self.assertEqual(params.sampwidth, sampwidth)
        self.assertEqual(params.framerate, framerate)
        self.assertEqual(params.nframes, nframes)
        self.assertEqual(params.comptype, comptype)
        self.assertEqual(params.compname, compname)

        # The params object must survive pickling with every protocol.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            dump = pickle.dumps(params, proto)
            self.assertEqual(pickle.loads(dump), params)


class AudioMiscTests(AudioTests):
    # Miscellaneous checks.  ``self.module`` is read but never set here, so
    # the concrete test case has to provide it.

    def test_openfp_deprecated(self):
        # openfp() must emit a DeprecationWarning and pass its arguments
        # straight through to the module's open(), which is replaced by a
        # mock so that nothing is actually opened.
        arg = "arg"
        mode = "mode"
        with mock.patch(f"{self.module.__name__}.open") as mock_open, \
             self.assertWarns(DeprecationWarning):
            self.module.openfp(arg, mode=mode)
            mock_open.assert_called_with(arg, mode=mode)


class AudioWriteTests(AudioTests):
    """Write-side tests shared by the audio modules.

    The methods read ``self.module``, ``self.nchannels``,
    ``self.sampwidth``, ``self.framerate``, ``self.nframes``,
    ``self.comptype``, ``self.compname`` and ``self.frames``.  None of
    them is assigned in this class, so the concrete test case has to
    provide them.
    """

    def create_file(self, testfile):
        """Open *testfile* for writing and apply the format parameters.

        *testfile* is handed unchanged to ``self.module.open(..., 'wb')``;
        the tests pass both file names and already-open file objects.  The
        writer is also stored in ``self.fout`` and returned.
        """
        f = self.fout = self.module.open(testfile, 'wb')
        f.setnchannels(self.nchannels)
        f.setsampwidth(self.sampwidth)
        f.setframerate(self.framerate)
        f.setcomptype(self.comptype, self.compname)
        return f

    def check_file(self, testfile, nframes, frames):
        """Read *testfile* back and compare it with the expected content.

        The channel count, sample width and frame rate must match the
        attributes of ``self``; the frame count must equal *nframes* and
        ``readframes(nframes)`` must return *frames*.
        """
        with self.module.open(testfile, 'rb') as f:
            self.assertEqual(f.getnchannels(), self.nchannels)
            self.assertEqual(f.getsampwidth(), self.sampwidth)
            self.assertEqual(f.getframerate(), self.framerate)
            self.assertEqual(f.getnframes(), nframes)
            self.assertEqual(f.readframes(nframes), frames)

    def test_write_params(self):
        # The parameters must be readable from the writer itself, before
        # it has been closed.
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(self.frames)
        self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
                          self.nframes, self.comptype, self.compname)
        f.close()

    def test_write_context_manager_calls_close(self):
        # Close checks for a minimum header and will raise an error
        # if it is not set, so this proves that close is called.
        with self.assertRaises(self.module.Error):
            with self.module.open(TESTFN, 'wb'):
                pass
        # Same check when an already-open file object is passed in.
        with self.assertRaises(self.module.Error):
            with open(TESTFN, 'wb') as testfile:
                with self.module.open(testfile):
                    pass

    def test_context_manager_with_open_file(self):
        # Writing: once the audio object's ``with`` block exits, the
        # caller's file object must be closed exactly when close_fd says so.
        with open(TESTFN, 'wb') as testfile:
            with self.module.open(testfile) as f:
                f.setnchannels(self.nchannels)
                f.setsampwidth(self.sampwidth)
                f.setframerate(self.framerate)
                f.setcomptype(self.comptype, self.compname)
            self.assertEqual(testfile.closed, self.close_fd)
        # Reading: the underlying file is open inside the block, and the
        # same close_fd rule applies afterwards.
        with open(TESTFN, 'rb') as testfile:
            with self.module.open(testfile) as f:
                self.assertFalse(f.getfp().closed)
                params = f.getparams()
                self.assertEqual(params.nchannels, self.nchannels)
                self.assertEqual(params.sampwidth, self.sampwidth)
                self.assertEqual(params.framerate, self.framerate)
            if not self.close_fd:
                self.assertIsNone(f.getfp())
            self.assertEqual(testfile.closed, self.close_fd)

    def test_context_manager_with_filename(self):
        # If the file doesn't get closed, this test won't fail, but it will
        # produce a resource leak warning.
        with self.module.open(TESTFN, 'wb') as f:
            f.setnchannels(self.nchannels)
            f.setsampwidth(self.sampwidth)
            f.setframerate(self.framerate)
            f.setcomptype(self.comptype, self.compname)
        with self.module.open(TESTFN) as f:
            self.assertFalse(f.getfp().closed)
            params = f.getparams()
            self.assertEqual(params.nchannels, self.nchannels)
            self.assertEqual(params.sampwidth, self.sampwidth)
            self.assertEqual(params.framerate, self.framerate)
        if not self.close_fd:
            self.assertIsNone(f.getfp())

    def test_write(self):
        # Plain bytes round trip.
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(self.frames)
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_write_bytearray(self):
        # writeframes() must accept a bytearray.
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(bytearray(self.frames))
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_write_array(self):
        # writeframes() must accept an array.array of 'h' items.
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(array.array('h', self.frames))
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_write_memoryview(self):
        # writeframes() must accept a memoryview.
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(memoryview(self.frames))
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_incompleted_write(self):
        # One frame more than is actually written is announced, and the
        # audio data starts 13 bytes into the file.  After close() the
        # file must report the number of frames really written.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes + 1)
            f.writeframes(self.frames)
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_multiple_writes(self):
        # The frames are written in two calls: everything but the last
        # frame, then the last frame on its own.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes)
            framesize = self.nchannels * self.sampwidth
            f.writeframes(self.frames[:-framesize])
            f.writeframes(self.frames[-framesize:])
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_overflowed_write(self):
        # One frame fewer than is actually written is announced.  After
        # close() the file must report the number of frames really written.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes - 1)
            f.writeframes(self.frames)
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_unseekable_read(self):
        # A normally written file must be readable through a stream that
        # supports neither tell() nor seek().
        with self.create_file(TESTFN) as f:
            f.setnframes(self.nframes)
            f.writeframes(self.frames)

        with UnseekableIO(TESTFN, 'rb') as testfile:
            self.check_file(testfile, self.nframes, self.frames)

    def test_unseekable_write(self):
        # Writing through an unseekable stream must work when the
        # announced frame count is correct.
        with UnseekableIO(TESTFN, 'wb') as testfile:
            with self.create_file(testfile) as f:
                f.setnframes(self.nframes)
                f.writeframes(self.frames)

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_unseekable_incompleted_write(self):
        # Too many frames are announced on an unseekable stream.  An
        # OSError from writeframes() or close() is tolerated on purpose.
        with UnseekableIO(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes + 1)
            try:
                f.writeframes(self.frames)
            except OSError:
                pass
            try:
                f.close()
            except OSError:
                pass

        # The file is expected to keep the announced frame count
        # (presumably because the header cannot be rewritten without
        # seek()), while all the frames written are still readable.
        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes + 1, self.frames)

    def test_unseekable_overflowed_write(self):
        # Too few frames are announced on an unseekable stream.  An
        # OSError from writeframes() or close() is tolerated on purpose.
        with UnseekableIO(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes - 1)
            try:
                f.writeframes(self.frames)
            except OSError:
                pass
            try:
                f.close()
            except OSError:
                pass

        # The file is expected to keep the announced frame count, so only
        # the frames before the last one are read back.
        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            framesize = self.nchannels * self.sampwidth
            self.check_file(testfile, self.nframes - 1, self.frames[:-framesize])


class AudioTestsWithSourceFile(AudioTests):
    """Read-side tests that start from a reference sound file.

    The methods read ``cls.sndfilename``, ``self.module``,
    ``self.nchannels``, ``self.sampwidth``, ``self.framerate``,
    ``self.sndfilenframes``, ``self.nframes``, ``self.comptype``,
    ``self.compname`` and ``self.frames``.  None of them is assigned in
    this class, so the concrete test case has to provide them.
    """

    @classmethod
    def setUpClass(cls):
        # Locate the reference file once per class, in the 'audiodata'
        # subdirectory.
        cls.sndfilepath = findfile(cls.sndfilename, subdir='audiodata')

    def test_read_params(self):
        # The reference file must report the expected parameters.
        f = self.f = self.module.open(self.sndfilepath)
        #self.assertEqual(f.getfp().name, self.sndfilepath)
        self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
                          self.sndfilenframes, self.comptype, self.compname)

    def test_close(self):
        # Reader: closing the audio object must leave the caller's file
        # object closed exactly when close_fd says so.
        with open(self.sndfilepath, 'rb') as testfile:
            f = self.f = self.module.open(testfile)
            self.assertFalse(testfile.closed)
            f.close()
            self.assertEqual(testfile.closed, self.close_fd)
        # Writer: no parameters were set, so the first close() must raise
        # the module's Error; a second close() must then be harmless.
        with open(TESTFN, 'wb') as testfile:
            fout = self.fout = self.module.open(testfile, 'wb')
            self.assertFalse(testfile.closed)
            with self.assertRaises(self.module.Error):
                fout.close()
            self.assertEqual(testfile.closed, self.close_fd)
            fout.close() # do nothing

    def test_read(self):
        # Exercise readframes() together with tell(), rewind() and setpos().
        framesize = self.nchannels * self.sampwidth
        chunk1 = self.frames[:2 * framesize]
        chunk2 = self.frames[2 * framesize: 4 * framesize]
        f = self.f = self.module.open(self.sndfilepath)
        # Reading zero frames returns nothing and does not move.
        self.assertEqual(f.readframes(0), b'')
        self.assertEqual(f.tell(), 0)
        self.assertEqual(f.readframes(2), chunk1)
        f.rewind()
        pos0 = f.tell()
        self.assertEqual(pos0, 0)
        self.assertEqual(f.readframes(2), chunk1)
        pos2 = f.tell()
        self.assertEqual(pos2, 2)
        self.assertEqual(f.readframes(2), chunk2)
        # Positions obtained from tell() must be accepted by setpos().
        f.setpos(pos2)
        self.assertEqual(f.readframes(2), chunk2)
        f.setpos(pos0)
        self.assertEqual(f.readframes(2), chunk1)
        # Positions outside the file must be rejected.
        with self.assertRaises(self.module.Error):
            f.setpos(-1)
        with self.assertRaises(self.module.Error):
            f.setpos(f.getnframes() + 1)

    def test_copy(self):
        # Copy the reference file in chunks of 1, 2, 3, ... frames, then
        # compare parameters and frame data of the copy with the original.
        f = self.f = self.module.open(self.sndfilepath)
        fout = self.fout = self.module.open(TESTFN, 'wb')
        fout.setparams(f.getparams())
        i = 0
        n = f.getnframes()
        while n > 0:
            i += 1
            fout.writeframes(f.readframes(i))
            n -= i
        fout.close()
        fout = self.fout = self.module.open(TESTFN, 'rb')
        f.rewind()
        self.assertEqual(f.getparams(), fout.getparams())
        self.assertEqual(f.readframes(f.getnframes()),
                         fout.readframes(fout.getnframes()))

    def test_read_not_from_start(self):
        # Put 13 unrelated bytes in front of a copy of the reference file...
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            with open(self.sndfilepath, 'rb') as f:
                testfile.write(f.read())

        # ...and check that the module can read it from a file object that
        # is already positioned just past that prefix.
        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            with self.module.open(testfile, 'rb') as f:
                self.assertEqual(f.getnchannels(), self.nchannels)
                self.assertEqual(f.getsampwidth(), self.sampwidth)
                self.assertEqual(f.getframerate(), self.framerate)
                self.assertEqual(f.getnframes(), self.sndfilenframes)
                self.assertEqual(f.readframes(self.nframes), self.frames)