blob: 9c4ce6cd119b862efd8d4567212a8744743e7314 [file] [log] [blame]
Serhiy Storchaka0fa01732013-10-13 17:47:22 +03001from test.test_support import findfile, TESTFN, unlink
2import unittest
3import array
4import io
5import pickle
6import sys
7import base64
8
def fromhex(s):
    """Decode a string of hexadecimal digits into a byte string.

    Any whitespace between the digits (spaces, newlines, tabs) is ignored,
    and the digits may be upper- or lowercase, so
    ``fromhex('00ff 7F')`` returns ``b'\\x00\\xff\\x7f'``.
    Input that is not valid base16 is rejected by ``base64.b16decode``.
    """
    # str.split() with no argument splits on any run of whitespace, so
    # joining the pieces removes every kind of whitespace, not just ' '.
    return base64.b16decode(''.join(s.split()), casefold=True)
11
def byteswap2(data):
    """Return *data* with the two bytes of every 2-byte sample swapped.

    The swap is done on the raw bytes, so the result does not depend on the
    platform size of a C ``short`` (which ``array('h')`` would) and no
    ``array.tostring()`` call is needed.  A ``ValueError`` is raised when
    ``len(data)`` is not a multiple of 2, because the extended-slice
    assignments below then have mismatched sizes.
    """
    swapped = bytearray(data)
    swapped[::2] = data[1::2]
    swapped[1::2] = data[::2]
    return bytes(swapped)
16
def byteswap3(data):
    """Reverse the byte order inside each consecutive 3-byte group of *data*.

    The middle byte of every triple stays in place; the first and last
    bytes trade places.
    """
    buf = bytearray(data)
    lows = data[2::3]
    highs = data[::3]
    buf[0::3] = lows
    buf[2::3] = highs
    return bytes(buf)
22
def byteswap4(data):
    """Return *data* with the byte order of every 4-byte sample reversed.

    The permutation is done directly on the bytes instead of through
    ``array('i')``, whose item size is the platform C ``int`` and is not
    guaranteed to be 4.  It also avoids ``array.tostring()``.  A ``ValueError``
    is raised when ``len(data)`` is not a multiple of 4, because the
    extended-slice assignments below then have mismatched sizes.
    """
    swapped = bytearray(data)
    swapped[::4] = data[3::4]
    swapped[1::4] = data[2::4]
    swapped[2::4] = data[1::4]
    swapped[3::4] = data[::4]
    return bytes(swapped)
27
28
class AudioTests:
    """Common fixture and assertions shared by the audio test mixins.

    The helpers call ``self.assertEqual``, so this class is meant to be
    mixed into a ``unittest.TestCase`` subclass.  Tests keep the audio
    reader they open in ``self.f`` and the writer in ``self.fout`` so that
    ``tearDown`` can close them.
    """

    # Expected ``closed`` state of a caller-supplied file object after the
    # audio object opened on it has been closed; subclasses may override.
    close_fd = False

    def setUp(self):
        self.f = self.fout = None

    def tearDown(self):
        """Close any open reader/writer and remove TESTFN.

        Each cleanup step runs even if an earlier one raises, so an
        exception from closing the reader or the writer does not skip the
        remaining cleanup or leave TESTFN behind for later tests.
        """
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that the audio object *f* reports the given parameters.

        Checks each individual getter and the ``getparams()`` tuple.  It also
        checks that the params object survives a pickle round trip under
        every pickle protocol this interpreter supports.
        """
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        params = f.getparams()
        self.assertEqual(params,
                (nchannels, sampwidth, framerate, nframes, comptype, compname))

        # Exercise every protocol, not only the default one.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            dump = pickle.dumps(params, proto)
            self.assertEqual(pickle.loads(dump), params)
57
58
class AudioWriteTests(AudioTests):
    """Writer-side tests shared by the audio modules.

    Concrete subclasses provide ``module`` plus the stream description
    (``nchannels``, ``sampwidth``, ``framerate``, ``comptype``,
    ``compname``).  They also provide ``nframes`` frames of sample data in
    ``frames``.
    """

    def create_file(self, testfile):
        """Open a writer on *testfile* and configure it from the class attributes.

        The writer is also stored in ``self.fout`` so ``tearDown`` closes it.
        """
        writer = self.module.open(testfile, 'wb')
        self.fout = writer
        writer.setnchannels(self.nchannels)
        writer.setsampwidth(self.sampwidth)
        writer.setframerate(self.framerate)
        writer.setcomptype(self.comptype, self.compname)
        return writer

    def check_file(self, testfile, nframes, frames):
        """Reopen *testfile* for reading and compare it with the expected data."""
        reader = self.module.open(testfile, 'rb')
        try:
            self.assertEqual(reader.getnchannels(), self.nchannels)
            self.assertEqual(reader.getsampwidth(), self.sampwidth)
            self.assertEqual(reader.getframerate(), self.framerate)
            self.assertEqual(reader.getnframes(), nframes)
            self.assertEqual(reader.readframes(nframes), frames)
        finally:
            reader.close()

    def _write_after_junk(self, declared_nframes, chunks):
        """Write 13 junk bytes to TESTFN, then an audio stream.

        The writer declares *declared_nframes* frames and then receives
        each element of *chunks* through ``writeframes`` in order.
        """
        with open(TESTFN, 'wb') as raw:
            raw.write(b'ababagalamaga')
            writer = self.create_file(raw)
            writer.setnframes(declared_nframes)
            for chunk in chunks:
                writer.writeframes(chunk)
            writer.close()

    def _check_after_junk(self):
        """Verify the 13 junk bytes, then the audio stream that follows them."""
        with open(TESTFN, 'rb') as raw:
            self.assertEqual(raw.read(13), b'ababagalamaga')
            self.check_file(raw, self.nframes, self.frames)

    def test_write_params(self):
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        self.check_params(writer, self.nchannels, self.sampwidth,
                          self.framerate, self.nframes,
                          self.comptype, self.compname)
        writer.close()

    def test_write(self):
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        writer.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_incompleted_write(self):
        # Declare one frame more than is actually written.
        self._write_after_junk(self.nframes + 1, [self.frames])
        self._check_after_junk()

    def test_multiple_writes(self):
        # Write all frames except the last one, then the last one separately.
        frame_size = self.nchannels * self.sampwidth
        head = self.frames[:-frame_size]
        tail = self.frames[-frame_size:]
        self._write_after_junk(self.nframes, [head, tail])
        self._check_after_junk()

    def test_overflowed_write(self):
        # Declare one frame fewer than is actually written.
        self._write_after_junk(self.nframes - 1, [self.frames])
        self._check_after_junk()
133
134
class AudioTestsWithSourceFile(AudioTests):
    """Reader-side tests run against a reference sound file.

    Concrete subclasses set ``module`` and ``sndfilename``, the name of a
    file found under the ``audiodata`` test-data subdirectory.  They also set
    the expected stream parameters and the file's total frame count
    ``sndfilenframes``.  ``frames`` holds the first ``nframes`` frames of that
    file.
    """

    @classmethod
    def setUpClass(cls):
        # Resolve the reference file once for the whole class.
        cls.sndfilepath = findfile(cls.sndfilename, subdir='audiodata')

    def test_read_params(self):
        reader = self.module.open(self.sndfilepath)
        self.f = reader
        self.check_params(reader, self.nchannels, self.sampwidth,
                          self.framerate, self.sndfilenframes,
                          self.comptype, self.compname)

    def test_close(self):
        # Closing an audio object opened on a caller-supplied file object
        # must leave that file's ``closed`` flag equal to ``close_fd``.
        with open(self.sndfilepath, 'rb') as src:
            reader = self.module.open(src)
            self.f = reader
            self.assertFalse(src.closed)
            reader.close()
            self.assertEqual(src.closed, self.close_fd)
        with open(TESTFN, 'wb') as dst:
            writer = self.module.open(dst, 'wb')
            self.fout = writer
            self.assertFalse(dst.closed)
            # Closing a writer whose parameters were never set must fail.
            with self.assertRaises(self.module.Error):
                writer.close()
            self.assertEqual(dst.closed, self.close_fd)
            writer.close()  # a second close does nothing

    def test_read(self):
        frame_size = self.nchannels * self.sampwidth
        first = self.frames[:2 * frame_size]
        second = self.frames[2 * frame_size: 4 * frame_size]
        reader = self.module.open(self.sndfilepath)
        self.f = reader
        self.assertEqual(reader.readframes(0), b'')
        self.assertEqual(reader.tell(), 0)
        self.assertEqual(reader.readframes(2), first)
        reader.rewind()
        start = reader.tell()
        self.assertEqual(start, 0)
        self.assertEqual(reader.readframes(2), first)
        after_two = reader.tell()
        self.assertEqual(after_two, 2)
        self.assertEqual(reader.readframes(2), second)
        # Seeking back to previously saved positions replays the same frames.
        reader.setpos(after_two)
        self.assertEqual(reader.readframes(2), second)
        reader.setpos(start)
        self.assertEqual(reader.readframes(2), first)
        # A negative position and one past the end are both rejected.
        with self.assertRaises(self.module.Error):
            reader.setpos(-1)
        with self.assertRaises(self.module.Error):
            reader.setpos(reader.getnframes() + 1)

    def test_copy(self):
        reader = self.module.open(self.sndfilepath)
        self.f = reader
        writer = self.module.open(TESTFN, 'wb')
        self.fout = writer
        writer.setparams(reader.getparams())
        # Copy in growing chunks of 1, 2, 3, ... frames until the
        # reader's frame count is exhausted.
        chunk = 0
        remaining = reader.getnframes()
        while remaining > 0:
            chunk += 1
            writer.writeframes(reader.readframes(chunk))
            remaining -= chunk
        writer.close()
        copy = self.module.open(TESTFN, 'rb')
        self.fout = copy
        reader.rewind()
        self.assertEqual(reader.getparams(), copy.getparams())
        self.assertEqual(reader.readframes(reader.getnframes()),
                         copy.readframes(copy.getnframes()))

    def test_read_not_from_start(self):
        # Store the sound file behind 13 junk bytes, then open the reader on
        # a file object that is already positioned past the junk.
        with open(TESTFN, 'wb') as dst:
            dst.write(b'ababagalamaga')
            with open(self.sndfilepath, 'rb') as src:
                dst.write(src.read())

        with open(TESTFN, 'rb') as stream:
            self.assertEqual(stream.read(13), b'ababagalamaga')
            reader = self.module.open(stream, 'rb')
            try:
                self.assertEqual(reader.getnchannels(), self.nchannels)
                self.assertEqual(reader.getsampwidth(), self.sampwidth)
                self.assertEqual(reader.getframerate(), self.framerate)
                self.assertEqual(reader.getnframes(), self.sndfilenframes)
                self.assertEqual(reader.readframes(self.nframes), self.frames)
            finally:
                reader.close()