blob: 60c6550cf906171be13dd74925488e820eba32bd [file] [log] [blame]
Serhiy Storchaka1b80e632013-10-13 17:55:07 +03001from test.support import findfile, TESTFN, unlink
2import unittest
3import array
4import io
5import pickle
6import sys
7
def byteswap2(data):
    """Return a copy of *data* with the bytes of every ``'h'`` item swapped.

    The input is loaded into an ``array.array('h')`` (C ``short``), swapped
    in place with ``array.byteswap()`` and serialized back to ``bytes``.
    """
    samples = array.array('h', data)
    samples.byteswap()
    swapped = samples.tobytes()
    return swapped
12
def byteswap3(data):
    """Return a copy of *data* with the byte order of every 3-byte item reversed.

    Within each 3-byte group the first and last bytes trade places and the
    middle byte stays put.  A length that is not a multiple of 3 makes the
    extended-slice assignment raise ValueError.
    """
    out = bytearray(data)
    # The right-hand side is built entirely from the untouched input, so
    # both slices can be assigned in one statement.
    out[0::3], out[2::3] = data[2::3], data[0::3]
    return bytes(out)
18
def byteswap4(data):
    """Return a copy of *data* with the byte order of every 4-byte item reversed.

    The swap is done on the raw bytes with extended-slice assignment (the
    same technique as byteswap3) instead of through ``array.array('i')``:
    the size of the C ``int`` behind the ``'i'`` typecode is platform
    dependent (the ``array`` docs only guarantee a minimum of 2 bytes and
    ILP64 platforms use 8), whereas this helper must always swap 4-byte
    items.

    Raises ValueError if ``len(data)`` is not a multiple of 4, because the
    first slice assignment then has mismatched lengths.
    """
    swapped = bytearray(data)
    swapped[0::4] = data[3::4]
    swapped[1::4] = data[2::4]
    swapped[2::4] = data[1::4]
    swapped[3::4] = data[0::4]
    return bytes(swapped)
23
24
class AudioTests:
    """Mixin shared by the audio-module test cases.

    Closes the reader (``self.f``) and writer (``self.fout``) objects a test
    stores, removes ``TESTFN`` afterwards, and provides check_params() for
    verifying the format parameters an audio object reports.  It relies on
    ``assertEqual`` being supplied by the class it is mixed into (e.g.
    ``unittest.TestCase``).
    """

    # Expected ``closed`` state of a caller-supplied file object after the
    # audio object wrapping it has been closed; subclasses override this
    # when the module under test closes such files itself.
    close_fd = False

    def setUp(self):
        # Tests assign the audio objects they open to these attributes so
        # that tearDown can close them even if an assertion fails midway.
        self.f = self.fout = None

    def tearDown(self):
        # Every cleanup step runs even when an earlier one raises (e.g. a
        # writer whose close() fails because its header was never set), so
        # neither the writer nor TESTFN leaks into the next test.  The first
        # exception still propagates.
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that audio object *f* reports the given parameters.

        Checks the individual getters, the value returned by getparams()
        both as a plain tuple and through its named fields, and that the
        params object survives a pickle round trip unchanged.
        """
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        params = f.getparams()
        self.assertEqual(params,
                (nchannels, sampwidth, framerate, nframes, comptype, compname))
        self.assertEqual(params.nchannels, nchannels)
        self.assertEqual(params.sampwidth, sampwidth)
        self.assertEqual(params.framerate, framerate)
        self.assertEqual(params.nframes, nframes)
        self.assertEqual(params.comptype, comptype)
        self.assertEqual(params.compname, compname)

        # Round-trips data produced just above, not external input.
        dump = pickle.dumps(params)
        self.assertEqual(pickle.loads(dump), params)
59
60
class AudioWriteTests(AudioTests):
    """Write-side tests shared by the audio-module test suites.

    Subclasses supply ``module`` (the audio module under test) and the
    attributes read below: ``nchannels``, ``sampwidth``, ``framerate``,
    ``comptype``, ``compname``, ``nframes`` and ``frames``.
    """

    def _apply_format(self, writer):
        # Configure *writer* with the format declared by the concrete class.
        writer.setnchannels(self.nchannels)
        writer.setsampwidth(self.sampwidth)
        writer.setframerate(self.framerate)
        writer.setcomptype(self.comptype, self.compname)

    def _assert_format(self, params):
        # Compare the header fields of a getparams() result with the format
        # declared by the concrete class.
        self.assertEqual(params.nchannels, self.nchannels)
        self.assertEqual(params.sampwidth, self.sampwidth)
        self.assertEqual(params.framerate, self.framerate)

    def create_file(self, testfile):
        """Open *testfile* for writing and apply the class's format.

        The writer is also kept in ``self.fout`` so that tearDown closes it.
        """
        self.fout = writer = self.module.open(testfile, 'wb')
        self._apply_format(writer)
        return writer

    def check_file(self, testfile, nframes, frames):
        """Reopen *testfile* for reading and verify its format and frames."""
        with self.module.open(testfile, 'rb') as reader:
            self.assertEqual(reader.getnchannels(), self.nchannels)
            self.assertEqual(reader.getsampwidth(), self.sampwidth)
            self.assertEqual(reader.getframerate(), self.framerate)
            self.assertEqual(reader.getnframes(), nframes)
            self.assertEqual(reader.readframes(nframes), frames)

    def test_write_params(self):
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        self.check_params(writer, self.nchannels, self.sampwidth,
                          self.framerate, self.nframes,
                          self.comptype, self.compname)
        writer.close()

    def test_write_context_manager_calls_close(self):
        # close() raises the module's Error when the minimum header was never
        # set, so getting that error on leaving the with block proves that
        # the context manager called close().
        with self.assertRaises(self.module.Error):
            with self.module.open(TESTFN, 'wb'):
                pass
        with self.assertRaises(self.module.Error):
            with open(TESTFN, 'wb') as testfile, self.module.open(testfile):
                pass

    def test_context_manager_with_open_file(self):
        with open(TESTFN, 'wb') as testfile:
            with self.module.open(testfile) as writer:
                self._apply_format(writer)
            self.assertEqual(testfile.closed, self.close_fd)
        with open(TESTFN, 'rb') as testfile:
            with self.module.open(testfile) as reader:
                self.assertFalse(reader.getfp().closed)
                self._assert_format(reader.getparams())
            if not self.close_fd:
                self.assertIsNone(reader.getfp())
            self.assertEqual(testfile.closed, self.close_fd)

    def test_context_manager_with_filename(self):
        # A file left open would not fail this test, but it would emit a
        # resource-leak warning.
        with self.module.open(TESTFN, 'wb') as writer:
            self._apply_format(writer)
        with self.module.open(TESTFN) as reader:
            self.assertFalse(reader.getfp().closed)
            self._assert_format(reader.getparams())
        if not self.close_fd:
            self.assertIsNone(reader.getfp())

    def test_write(self):
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        writer.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def _write_after_prefix(self, declared_nframes, *chunks):
        # Write a 13-byte junk prefix to TESTFN, then an audio stream that
        # declares *declared_nframes* frames and receives *chunks* in order.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            writer = self.create_file(testfile)
            writer.setnframes(declared_nframes)
            for chunk in chunks:
                writer.writeframes(chunk)
            writer.close()

    def _check_after_prefix(self):
        # The junk prefix must be intact and the audio stream after it must
        # hold exactly self.nframes frames equal to self.frames.
        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_incompleted_write(self):
        # Declare one frame more than is actually written.
        self._write_after_prefix(self.nframes + 1, self.frames)
        self._check_after_prefix()

    def test_multiple_writes(self):
        # Split the data so that the last frame arrives in a second call.
        framesize = self.nchannels * self.sampwidth
        self._write_after_prefix(self.nframes,
                                 self.frames[:-framesize],
                                 self.frames[-framesize:])
        self._check_after_prefix()

    def test_overflowed_write(self):
        # Declare one frame fewer than is actually written.
        self._write_after_prefix(self.nframes - 1, self.frames)
        self._check_after_prefix()
179
180
class AudioTestsWithSourceFile(AudioTests):
    """Read-side tests run against a reference sound file.

    Subclasses supply ``module``, ``sndfilename`` (looked up in the
    ``audiodata`` test-data subdirectory), ``sndfilenframes`` and the
    expected format attributes and ``frames`` read below.
    """

    @classmethod
    def setUpClass(cls):
        cls.sndfilepath = findfile(cls.sndfilename, subdir='audiodata')

    def test_read_params(self):
        f = self.f = self.module.open(self.sndfilepath)
        #self.assertEqual(f.getfp().name, self.sndfilepath)
        self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
                          self.sndfilenframes, self.comptype, self.compname)

    def test_close(self):
        # The raw files are managed with ``with`` so they are closed even
        # when the module leaves caller-supplied files open
        # (close_fd False) or an assertion fails.
        with open(self.sndfilepath, 'rb') as testfile:
            f = self.f = self.module.open(testfile)
            self.assertFalse(testfile.closed)
            f.close()
            self.assertEqual(testfile.closed, self.close_fd)
        # Closing a writer whose header parameters were never set raises the
        # module's Error; the caller's file must still end up in the
        # close_fd state, and a second close() must not raise.
        with open(TESTFN, 'wb') as testfile:
            fout = self.module.open(testfile, 'wb')
            self.assertFalse(testfile.closed)
            with self.assertRaises(self.module.Error):
                fout.close()
            self.assertEqual(testfile.closed, self.close_fd)
            fout.close()  # do nothing

    def test_read(self):
        framesize = self.nchannels * self.sampwidth
        # chunk1 holds frames 0-1 and chunk2 frames 2-3 of the expected data.
        chunk1 = self.frames[:2 * framesize]
        chunk2 = self.frames[2 * framesize: 4 * framesize]
        f = self.f = self.module.open(self.sndfilepath)
        self.assertEqual(f.readframes(0), b'')
        self.assertEqual(f.tell(), 0)
        self.assertEqual(f.readframes(2), chunk1)
        f.rewind()
        pos0 = f.tell()
        self.assertEqual(pos0, 0)
        self.assertEqual(f.readframes(2), chunk1)
        pos2 = f.tell()
        self.assertEqual(pos2, 2)
        self.assertEqual(f.readframes(2), chunk2)
        # Seeking back to previously reported positions replays the data.
        f.setpos(pos2)
        self.assertEqual(f.readframes(2), chunk2)
        f.setpos(pos0)
        self.assertEqual(f.readframes(2), chunk1)
        # Positions outside [0, nframes] are rejected.
        with self.assertRaises(self.module.Error):
            f.setpos(-1)
        with self.assertRaises(self.module.Error):
            f.setpos(f.getnframes() + 1)

    def test_copy(self):
        f = self.f = self.module.open(self.sndfilepath)
        fout = self.fout = self.module.open(TESTFN, 'wb')
        fout.setparams(f.getparams())
        # Copy in chunks of growing size (1, 2, 3, ... frames) so that many
        # differently sized reads and writes are exercised.
        i = 0
        n = f.getnframes()
        while n > 0:
            i += 1
            fout.writeframes(f.readframes(i))
            n -= i
        fout.close()
        fout = self.fout = self.module.open(TESTFN, 'rb')
        f.rewind()
        self.assertEqual(f.getparams(), fout.getparams())
        self.assertEqual(f.readframes(f.getnframes()),
                         fout.readframes(fout.getnframes()))

    def test_read_not_from_start(self):
        # Prepend 13 junk bytes so the audio data does not start at offset 0
        # of the underlying file.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            with open(self.sndfilepath, 'rb') as f:
                testfile.write(f.read())

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            with self.module.open(testfile, 'rb') as f:
                self.assertEqual(f.getnchannels(), self.nchannels)
                self.assertEqual(f.getsampwidth(), self.sampwidth)
                self.assertEqual(f.getframerate(), self.framerate)
                self.assertEqual(f.getnframes(), self.sndfilenframes)
                self.assertEqual(f.readframes(self.nframes), self.frames)