blob: e378277b6bcd496e6699854a221c95900a3fa838 [file] [log] [blame]
Serhiy Storchaka0fa01732013-10-13 17:47:22 +03001from test.test_support import findfile, TESTFN, unlink
2import unittest
3import array
4import io
5import pickle
6import sys
7import base64
8
def fromhex(s):
    """Decode a string of hexadecimal digits into bytes.

    Test fixtures spell out expected frame data as grouped hex digits
    (e.g. ``'4BEE 0000'``).  All whitespace (spaces, tabs, newlines) is
    ignored, and lowercase digits are accepted as well as uppercase ones.

    Invalid input (a non-hex digit or an odd number of digits) propagates
    the exception raised by ``base64.b16decode``.
    """
    # ''.join(s.split()) drops every whitespace character, not just ' '.
    return base64.b16decode(''.join(s.split()), casefold=True)
11
def byteswap2(data):
    """Return *data* with the two bytes of every 2-byte sample swapped.

    Works purely on bytes with extended-slice assignment, so the result
    does not depend on the platform's C ``short`` size (as ``array('h')``
    does).  It also avoids ``array.fromstring``/``tostring``, which were
    removed in Python 3.9.

    Raises ValueError if ``len(data)`` is not a multiple of 2.
    """
    swapped = bytearray(data)
    # Even positions receive the odd bytes and vice versa; a length
    # mismatch between the slices raises ValueError.
    swapped[::2] = data[1::2]
    swapped[1::2] = data[::2]
    return bytes(swapped)
17
def byteswap3(data):
    """Exchange the first and last byte of every 3-byte sample in *data*.

    The middle byte of each sample stays in place.  Input whose length is
    not a multiple of 3 makes the slice assignment raise ValueError.
    """
    first_bytes = data[::3]
    last_bytes = data[2::3]
    result = bytearray(data)
    result[::3] = last_bytes
    result[2::3] = first_bytes
    return bytes(result)
23
def byteswap4(data):
    """Return *data* with the byte order of every 4-byte sample reversed.

    Works purely on bytes with extended-slice assignment, so the result
    does not depend on the platform's C ``int`` size (as ``array('i')``
    does).  It also avoids ``array.fromstring``/``tostring``, which were
    removed in Python 3.9.

    Raises ValueError if ``len(data)`` is not a multiple of 4.
    """
    swapped = bytearray(data)
    # Byte k of each sample takes byte (3 - k) of the same sample; a
    # trailing partial sample makes the slice lengths differ -> ValueError.
    for offset in range(4):
        swapped[offset::4] = data[3 - offset::4]
    return bytes(swapped)
29
30
class AudioTests:
    """Fixture handling and header checks shared by the audio module tests.

    The mixin calls ``self.assertEqual``, so it is presumably combined with
    ``unittest.TestCase`` by the concrete test classes.  Tests keep an open
    reader in ``self.f`` and an open writer in ``self.fout`` so that
    ``tearDown`` can close them; ``tearDown`` also deletes ``TESTFN``.
    """

    # Expected value of a caller-supplied file's ``closed`` flag after the
    # audio object opened on that file has been closed.  False means the
    # module leaves such files open; override where a module behaves
    # otherwise.
    close_fd = False

    def setUp(self):
        # No reader or writer is open until a test creates one.
        self.f = self.fout = None

    def tearDown(self):
        # close() on an audio object can raise (an unconfigured writer
        # raises the module's Error), so each cleanup step runs in its own
        # finally: a failing close() must neither leak the other handle nor
        # leave TESTFN behind for the next test.
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that the audio object *f* reports the given parameters.

        Each getter is checked individually, then ``getparams()`` must equal
        the tuple of all six values and survive a pickle round trip
        unchanged.
        """
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        params = f.getparams()
        self.assertEqual(params,
                (nchannels, sampwidth, framerate, nframes, comptype, compname))

        dump = pickle.dumps(params)
        self.assertEqual(pickle.loads(dump), params)
59
60
class AudioWriteTests(AudioTests):
    """Round-trip tests for writing audio data with ``self.module``.

    Subclasses supply ``module``, the stream description (``nchannels``,
    ``sampwidth``, ``framerate``, ``comptype``, ``compname``) and the sample
    data to write (``nframes``, ``frames``).
    """

    def create_file(self, testfile):
        """Open *testfile* for writing and apply the class's stream settings.

        The writer is also stored in ``self.fout`` so ``tearDown`` closes it.
        """
        writer = self.module.open(testfile, 'wb')
        self.fout = writer
        writer.setnchannels(self.nchannels)
        writer.setsampwidth(self.sampwidth)
        writer.setframerate(self.framerate)
        writer.setcomptype(self.comptype, self.compname)
        return writer

    def check_file(self, testfile, nframes, frames):
        """Reopen *testfile* for reading and verify its header and frames."""
        reader = self.module.open(testfile, 'rb')
        try:
            for attr in ('nchannels', 'sampwidth', 'framerate'):
                self.assertEqual(getattr(reader, 'get' + attr)(),
                                 getattr(self, attr))
            self.assertEqual(reader.getnframes(), nframes)
            self.assertEqual(reader.readframes(nframes), frames)
        finally:
            reader.close()

    def _write_after_prefix(self, declared_nframes, chunks):
        """Write *chunks* into TESTFN after a junk prefix, then verify it.

        The writer works on an already-open file positioned after the prefix
        and is told to expect *declared_nframes* frames.  Afterwards the
        prefix must be intact, and the data following it must hold
        ``self.frames`` with ``self.nframes`` recorded in the header.
        """
        prefix = b'ababagalamaga'
        with open(TESTFN, 'wb') as testfile:
            testfile.write(prefix)
            writer = self.create_file(testfile)
            writer.setnframes(declared_nframes)
            for chunk in chunks:
                writer.writeframes(chunk)
            writer.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(len(prefix)), prefix)
            self.check_file(testfile, self.nframes, self.frames)

    def test_write_params(self):
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        self.check_params(writer, self.nchannels, self.sampwidth,
                          self.framerate, self.nframes, self.comptype,
                          self.compname)
        writer.close()

    def test_write(self):
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        writer.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_incompleted_write(self):
        # The header announces one frame more than is actually written.
        self._write_after_prefix(self.nframes + 1, [self.frames])

    def test_multiple_writes(self):
        # The data arrives in two writeframes() calls, split before the
        # last frame.
        framesize = self.nchannels * self.sampwidth
        self._write_after_prefix(self.nframes,
                                 [self.frames[:-framesize],
                                  self.frames[-framesize:]])

    def test_overflowed_write(self):
        # The header announces one frame fewer than is actually written.
        self._write_after_prefix(self.nframes - 1, [self.frames])
135
136
class AudioTestsWithSourceFile(AudioTests):
    """Reading tests against a reference sound file from ``audiodata``.

    Subclasses supply ``module``, ``sndfilename``, the expected header
    values (``nchannels``, ``sampwidth``, ``framerate``, ``sndfilenframes``,
    ``comptype``, ``compname``) and the file's first ``nframes`` frames as
    ``frames``.
    """

    @classmethod
    def setUpClass(cls):
        # Resolve the fixture path once for the whole class.
        cls.sndfilepath = findfile(cls.sndfilename, subdir='audiodata')

    def test_read_params(self):
        reader = self.f = self.module.open(self.sndfilepath)
        self.check_params(reader, self.nchannels, self.sampwidth,
                          self.framerate, self.sndfilenframes,
                          self.comptype, self.compname)

    def test_close(self):
        # Closing a reader leaves a caller-supplied file in the state given
        # by ``close_fd``.
        with open(self.sndfilepath, 'rb') as testfile:
            reader = self.f = self.module.open(testfile)
            self.assertFalse(testfile.closed)
            reader.close()
            self.assertEqual(testfile.closed, self.close_fd)
        # The same holds for a writer whose close() raises the module's
        # Error (nothing was configured or written); a second close() must
        # then do nothing.
        with open(TESTFN, 'wb') as testfile:
            writer = self.fout = self.module.open(testfile, 'wb')
            self.assertFalse(testfile.closed)
            with self.assertRaises(self.module.Error):
                writer.close()
            self.assertEqual(testfile.closed, self.close_fd)
            writer.close()

    def test_read(self):
        framesize = self.nchannels * self.sampwidth
        first_pair = self.frames[:2 * framesize]
        second_pair = self.frames[2 * framesize: 4 * framesize]
        reader = self.f = self.module.open(self.sndfilepath)
        self.assertEqual(reader.readframes(0), b'')
        self.assertEqual(reader.tell(), 0)
        self.assertEqual(reader.readframes(2), first_pair)
        reader.rewind()
        start = reader.tell()
        self.assertEqual(start, 0)
        self.assertEqual(reader.readframes(2), first_pair)
        after_pair = reader.tell()
        self.assertEqual(after_pair, 2)
        self.assertEqual(reader.readframes(2), second_pair)
        # setpos() must accept positions previously returned by tell().
        reader.setpos(after_pair)
        self.assertEqual(reader.readframes(2), second_pair)
        reader.setpos(start)
        self.assertEqual(reader.readframes(2), first_pair)
        # Positions before the start or beyond the end are rejected.
        with self.assertRaises(self.module.Error):
            reader.setpos(-1)
        with self.assertRaises(self.module.Error):
            reader.setpos(reader.getnframes() + 1)

    def test_copy(self):
        src = self.f = self.module.open(self.sndfilepath)
        dst = self.fout = self.module.open(TESTFN, 'wb')
        dst.setparams(src.getparams())
        remaining = src.getnframes()
        chunk_len = 0
        # Copy in chunks of growing size (1, 2, 3, ... frames) to exercise
        # many different writeframes() lengths.
        while remaining > 0:
            chunk_len += 1
            dst.writeframes(src.readframes(chunk_len))
            remaining -= chunk_len
        dst.close()
        dst = self.fout = self.module.open(TESTFN, 'rb')
        src.rewind()
        self.assertEqual(src.getparams(), dst.getparams())
        self.assertEqual(src.readframes(src.getnframes()),
                         dst.readframes(dst.getnframes()))

    def test_read_not_from_start(self):
        prefix = b'ababagalamaga'
        with open(TESTFN, 'wb') as testfile:
            testfile.write(prefix)
            with open(self.sndfilepath, 'rb') as source:
                testfile.write(source.read())

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(len(prefix)), prefix)
            # The reader starts from the file's current position, just past
            # the prefix.
            reader = self.module.open(testfile, 'rb')
            try:
                self.assertEqual(reader.getnchannels(), self.nchannels)
                self.assertEqual(reader.getsampwidth(), self.sampwidth)
                self.assertEqual(reader.getframerate(), self.framerate)
                self.assertEqual(reader.getnframes(), self.sndfilenframes)
                self.assertEqual(reader.readframes(self.nframes), self.frames)
            finally:
                reader.close()
220 f.close()