blob: c22f0a10a9651e84c7e7a70969fbd723da86b42f [file] [log] [blame]
Serhiy Storchaka1b80e632013-10-13 17:55:07 +03001from test.support import findfile, TESTFN, unlink
2import unittest
3import array
4import io
5import pickle
6import sys
7
def byteswap2(data):
    """Return *data* with the bytes of every ``'h'`` array item reversed.

    *data* is loaded into an ``array.array`` of typecode ``'h'``, the
    items are byte-swapped in place, and the result is returned as a new
    ``bytes`` object.  ``array.array`` raises ``ValueError`` when the
    length of *data* is not a multiple of the item size.
    """
    samples = array.array('h', data)
    samples.byteswap()
    return samples.tobytes()
12
def byteswap3(data):
    """Return *data* with every 3-byte group reversed end-for-end.

    The first and third byte of each group trade places; the middle
    byte stays where it is.  The result is a new ``bytes`` object and
    *data* itself is not modified.
    """
    swapped = bytearray(data)
    # Both right-hand slices are taken from the untouched input, then
    # assigned left to right: thirds into firsts, firsts into thirds.
    swapped[0::3], swapped[2::3] = data[2::3], data[0::3]
    return bytes(swapped)
18
def byteswap4(data):
    """Return *data* with every 4-byte group reversed end-for-end.

    The swap is done with strided slice assignment on a ``bytearray``
    copy, so the group width is always exactly four bytes.  (The
    ``array`` typecode ``'i'`` is only guaranteed to be at least two
    bytes wide, so relying on it ties the result to the platform's
    ``int`` size.)

    Returns a new ``bytes`` object; *data* is not modified.

    Raises ``ValueError`` if ``len(data)`` is not a multiple of 4.
    """
    if len(data) % 4:
        # Same exception type and message array.array uses for a
        # length that does not fit a whole number of items.
        raise ValueError('bytes length not a multiple of item size')
    swapped = bytearray(data)
    for offset in range(4):
        # Byte `offset` of each output group comes from byte
        # `3 - offset` of the matching input group.
        swapped[offset::4] = data[3 - offset::4]
    return bytes(swapped)
23
class UnseekableIO(io.FileIO):
    """``io.FileIO`` subclass whose position can be neither read nor set.

    ``seek()`` and ``tell()`` always raise ``io.UnsupportedOperation``;
    every other method is inherited from ``io.FileIO`` unchanged.
    """

    def seek(self, *args, **kwargs):
        """Raise ``io.UnsupportedOperation``; all arguments are ignored."""
        raise io.UnsupportedOperation()

    def tell(self):
        """Raise ``io.UnsupportedOperation``."""
        raise io.UnsupportedOperation()
30
Serhiy Storchaka1b80e632013-10-13 17:55:07 +030031
class AudioTests:
    """Common fixture and parameter checks for the audio test classes.

    The class calls ``self.assertEqual`` but does not define it, so it
    has to be combined with a class that does (e.g. a
    ``unittest.TestCase``).
    """

    # Value that ``closed`` of a caller-supplied file object is compared
    # against after the audio object wrapping it has been closed.
    close_fd = False

    def setUp(self):
        """Reset the reader (``self.f``) and writer (``self.fout``) slots."""
        self.f = self.fout = None

    def tearDown(self):
        """Close whatever the test left open and remove ``TESTFN``.

        Each step runs even when an earlier one raises, so a failing
        ``close()`` cannot leave the scratch file behind for later
        tests.
        """
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that audio object *f* reports the given parameters.

        The values are checked three ways: through the individual
        ``get*()`` accessors, through ``getparams()`` compared as a
        plain tuple, and through the named attributes of the
        ``getparams()`` result.  The result must also survive a pickle
        round trip under every protocol from 0 up to
        ``pickle.HIGHEST_PROTOCOL``.
        """
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        params = f.getparams()
        self.assertEqual(params,
                (nchannels, sampwidth, framerate, nframes, comptype, compname))
        self.assertEqual(params.nchannels, nchannels)
        self.assertEqual(params.sampwidth, sampwidth)
        self.assertEqual(params.framerate, framerate)
        self.assertEqual(params.nframes, nframes)
        self.assertEqual(params.comptype, comptype)
        self.assertEqual(params.compname, compname)

        # Exercise every protocol, not just the interpreter's default.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            dump = pickle.dumps(params, proto)
            self.assertEqual(pickle.loads(dump), params)
66
67
class AudioWriteTests(AudioTests):
    """Write-side test cases shared by the audio module tests.

    The methods read ``module``, ``nchannels``, ``sampwidth``,
    ``framerate``, ``nframes``, ``comptype``, ``compname`` and
    ``frames`` from ``self``; none of them is defined in this class, so
    the concrete test class has to supply them.
    """

    def _apply_params(self, writer):
        # Push the configured channel/width/rate/compression onto writer.
        writer.setnchannels(self.nchannels)
        writer.setsampwidth(self.sampwidth)
        writer.setframerate(self.framerate)
        writer.setcomptype(self.comptype, self.compname)

    def _check_after_prefix(self, nframes, frames):
        # Reopen TESTFN, verify the 13-byte marker written ahead of the
        # audio data, then validate the audio data that follows it.
        with open(TESTFN, 'rb') as raw:
            self.assertEqual(raw.read(13), b'ababagalamaga')
            self.check_file(raw, nframes, frames)

    def _write_and_close_ignoring_oserror(self, writer):
        # Write all frames and close, tolerating OSError from either step.
        try:
            writer.writeframes(self.frames)
        except OSError:
            pass
        try:
            writer.close()
        except OSError:
            pass

    def create_file(self, testfile):
        """Open *testfile* for writing with the configured parameters.

        *testfile* is handed straight to ``self.module.open``.  The new
        writer is stored in ``self.fout`` and also returned.
        """
        writer = self.module.open(testfile, 'wb')
        self.fout = writer
        self._apply_params(writer)
        return writer

    def check_file(self, testfile, nframes, frames):
        """Assert that *testfile* holds *nframes* frames equal to *frames*.

        The channel count, sample width and frame rate are compared
        with the values configured on ``self``.
        """
        with self.module.open(testfile, 'rb') as reader:
            self.assertEqual(reader.getnchannels(), self.nchannels)
            self.assertEqual(reader.getsampwidth(), self.sampwidth)
            self.assertEqual(reader.getframerate(), self.framerate)
            self.assertEqual(reader.getnframes(), nframes)
            self.assertEqual(reader.readframes(nframes), frames)

    def test_write_params(self):
        """A writer reports the parameters that were set on it."""
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        self.check_params(writer, self.nchannels, self.sampwidth,
                          self.framerate, self.nframes,
                          self.comptype, self.compname)
        writer.close()

    def test_write_context_manager_calls_close(self):
        """Leaving the ``with`` block of a writer calls ``close()``."""
        # Close checks for a minimum header and will raise an error
        # if it is not set, so this proves that close is called.
        with self.assertRaises(self.module.Error):
            with self.module.open(TESTFN, 'wb'):
                pass
        with self.assertRaises(self.module.Error):
            with open(TESTFN, 'wb') as raw, self.module.open(raw):
                pass

    def test_context_manager_with_open_file(self):
        """Context-manager use on top of an already open file object."""
        with open(TESTFN, 'wb') as raw:
            with self.module.open(raw) as writer:
                self._apply_params(writer)
            self.assertEqual(raw.closed, self.close_fd)
        with open(TESTFN, 'rb') as raw:
            with self.module.open(raw) as reader:
                self.assertFalse(reader.getfp().closed)
                header = reader.getparams()
                self.assertEqual(header.nchannels, self.nchannels)
                self.assertEqual(header.sampwidth, self.sampwidth)
                self.assertEqual(header.framerate, self.framerate)
            if not self.close_fd:
                self.assertIsNone(reader.getfp())
            self.assertEqual(raw.closed, self.close_fd)

    def test_context_manager_with_filename(self):
        """Context-manager use when the module opens the file by name."""
        # If the file doesn't get closed, this test won't fail, but it will
        # produce a resource leak warning.
        with self.module.open(TESTFN, 'wb') as writer:
            self._apply_params(writer)
        with self.module.open(TESTFN) as reader:
            self.assertFalse(reader.getfp().closed)
            header = reader.getparams()
            self.assertEqual(header.nchannels, self.nchannels)
            self.assertEqual(header.sampwidth, self.sampwidth)
            self.assertEqual(header.framerate, self.framerate)
        if not self.close_fd:
            self.assertIsNone(reader.getfp())

    def test_write(self):
        """Frames written to a named file can be read back unchanged."""
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        writer.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_incompleted_write(self):
        """Announce one frame more than is written; expect the real count."""
        with open(TESTFN, 'wb') as raw:
            raw.write(b'ababagalamaga')
            writer = self.create_file(raw)
            writer.setnframes(self.nframes + 1)
            writer.writeframes(self.frames)
            writer.close()

        self._check_after_prefix(self.nframes, self.frames)

    def test_multiple_writes(self):
        """Frames delivered in two ``writeframes`` calls are concatenated."""
        with open(TESTFN, 'wb') as raw:
            raw.write(b'ababagalamaga')
            writer = self.create_file(raw)
            writer.setnframes(self.nframes)
            frame_bytes = self.nchannels * self.sampwidth
            # Everything but the last frame first, then the last frame.
            head = self.frames[:-frame_bytes]
            tail = self.frames[-frame_bytes:]
            writer.writeframes(head)
            writer.writeframes(tail)
            writer.close()

        self._check_after_prefix(self.nframes, self.frames)

    def test_overflowed_write(self):
        """Announce one frame fewer than is written; expect the real count."""
        with open(TESTFN, 'wb') as raw:
            raw.write(b'ababagalamaga')
            writer = self.create_file(raw)
            writer.setnframes(self.nframes - 1)
            writer.writeframes(self.frames)
            writer.close()

        self._check_after_prefix(self.nframes, self.frames)

    def test_unseekable_read(self):
        """A file can be read through an object that cannot seek or tell."""
        with self.create_file(TESTFN) as writer:
            writer.setnframes(self.nframes)
            writer.writeframes(self.frames)

        with UnseekableIO(TESTFN, 'rb') as raw:
            self.check_file(raw, self.nframes, self.frames)

    def test_unseekable_write(self):
        """A file can be written through an object that cannot seek or tell."""
        with UnseekableIO(TESTFN, 'wb') as raw:
            with self.create_file(raw) as writer:
                writer.setnframes(self.nframes)
                writer.writeframes(self.frames)

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_unseekable_incompleted_write(self):
        """Too many frames announced on an unseekable file.

        The file is expected to keep the announced count
        (``nframes + 1``) while holding only the frames actually
        written.
        """
        with UnseekableIO(TESTFN, 'wb') as raw:
            raw.write(b'ababagalamaga')
            writer = self.create_file(raw)
            writer.setnframes(self.nframes + 1)
            self._write_and_close_ignoring_oserror(writer)

        self._check_after_prefix(self.nframes + 1, self.frames)

    def test_unseekable_overflowed_write(self):
        """Too few frames announced on an unseekable file.

        The file is expected to keep the announced count
        (``nframes - 1``), and reading it back yields the written data
        minus its last frame.
        """
        with UnseekableIO(TESTFN, 'wb') as raw:
            raw.write(b'ababagalamaga')
            writer = self.create_file(raw)
            writer.setnframes(self.nframes - 1)
            self._write_and_close_ignoring_oserror(writer)

        frame_bytes = self.nchannels * self.sampwidth
        self._check_after_prefix(self.nframes - 1,
                                 self.frames[:-frame_bytes])
239
Serhiy Storchaka1b80e632013-10-13 17:55:07 +0300240
class AudioTestsWithSourceFile(AudioTests):
    """Read-side test cases that run against a reference sound file.

    Besides the attributes used for parameter checks (``module``,
    ``nchannels``, ``sampwidth``, ``framerate``, ``comptype``,
    ``compname``, ``nframes``, ``frames``) the methods read
    ``sndfilename`` and ``sndfilenframes``; none of them is defined in
    this class, so the concrete test class has to supply them.
    """

    @classmethod
    def setUpClass(cls):
        """Look up ``cls.sndfilename`` in ``audiodata`` and store the path."""
        located = findfile(cls.sndfilename, subdir='audiodata')
        cls.sndfilepath = located

    def test_read_params(self):
        """The reference file reports the expected parameters."""
        reader = self.module.open(self.sndfilepath)
        self.f = reader
        # Disabled check, kept for reference:
        #self.assertEqual(f.getfp().name, self.sndfilepath)
        self.check_params(reader, self.nchannels, self.sampwidth,
                          self.framerate, self.sndfilenframes,
                          self.comptype, self.compname)

    def test_close(self):
        """``close()`` on readers and writers wrapping a caller's file."""
        with open(self.sndfilepath, 'rb') as raw:
            reader = self.module.open(raw)
            self.f = reader
            self.assertFalse(raw.closed)
            reader.close()
            self.assertEqual(raw.closed, self.close_fd)
        with open(TESTFN, 'wb') as raw:
            writer = self.module.open(raw, 'wb')
            self.fout = writer
            self.assertFalse(raw.closed)
            # No parameters were set, so the first close must fail ...
            with self.assertRaises(self.module.Error):
                writer.close()
            self.assertEqual(raw.closed, self.close_fd)
            writer.close() # do nothing

    def test_read(self):
        """``readframes``, ``tell``, ``rewind`` and ``setpos`` on a reader."""
        frame_bytes = self.nchannels * self.sampwidth
        first_pair = self.frames[:2 * frame_bytes]
        second_pair = self.frames[2 * frame_bytes: 4 * frame_bytes]
        reader = self.module.open(self.sndfilepath)
        self.f = reader
        self.assertEqual(reader.readframes(0), b'')
        self.assertEqual(reader.tell(), 0)
        self.assertEqual(reader.readframes(2), first_pair)
        reader.rewind()
        start = reader.tell()
        self.assertEqual(start, 0)
        self.assertEqual(reader.readframes(2), first_pair)
        after_two = reader.tell()
        self.assertEqual(after_two, 2)
        self.assertEqual(reader.readframes(2), second_pair)
        reader.setpos(after_two)
        self.assertEqual(reader.readframes(2), second_pair)
        reader.setpos(start)
        self.assertEqual(reader.readframes(2), first_pair)
        # Positions outside the file are rejected.
        with self.assertRaises(self.module.Error):
            reader.setpos(-1)
        with self.assertRaises(self.module.Error):
            reader.setpos(reader.getnframes() + 1)

    def test_copy(self):
        """Copy the reference file in growing batches and compare."""
        source = self.module.open(self.sndfilepath)
        self.f = source
        target = self.module.open(TESTFN, 'wb')
        self.fout = target
        target.setparams(source.getparams())
        batch = 0
        remaining = source.getnframes()
        # Request 1, 2, 3, ... frames per round until the count is used up.
        while remaining > 0:
            batch += 1
            target.writeframes(source.readframes(batch))
            remaining -= batch
        target.close()
        copied = self.module.open(TESTFN, 'rb')
        self.fout = copied
        source.rewind()
        self.assertEqual(source.getparams(), copied.getparams())
        self.assertEqual(source.readframes(source.getnframes()),
                         copied.readframes(copied.getnframes()))

    def test_read_not_from_start(self):
        """Read audio data that starts 13 bytes into the underlying file."""
        with open(TESTFN, 'wb') as raw:
            raw.write(b'ababagalamaga')
            with open(self.sndfilepath, 'rb') as original:
                raw.write(original.read())

        with open(TESTFN, 'rb') as raw:
            self.assertEqual(raw.read(13), b'ababagalamaga')
            with self.module.open(raw, 'rb') as reader:
                self.assertEqual(reader.getnchannels(), self.nchannels)
                self.assertEqual(reader.getsampwidth(), self.sampwidth)
                self.assertEqual(reader.getframerate(), self.framerate)
                self.assertEqual(reader.getnframes(), self.sndfilenframes)
                self.assertEqual(reader.readframes(self.nframes), self.frames)