blob: 59e9928796dc9e132731f9762fd1dd16f55804bc [file] [log] [blame]
Serhiy Storchaka1b80e632013-10-13 17:55:07 +03001from test.support import findfile, TESTFN, unlink
2import unittest
3import array
4import io
5import pickle
6import sys
7
def byteswap2(data):
    """Return a copy of *data* with the two bytes of every 16-bit item swapped.

    *data* is reinterpreted as an array of C shorts ('h'), byte-swapped in
    place, and converted back to ``bytes``.
    """
    words = array.array('h', data)
    words.byteswap()
    return bytes(words)
12
def byteswap3(data):
    """Return a copy of *data* with the byte order of every 3-byte item reversed.

    The first and third byte of each group of three are exchanged; the middle
    byte stays put.  ``len(data)`` must be a multiple of 3, otherwise the
    extended-slice assignments raise ValueError.
    """
    out = bytearray(data)
    # Both right-hand slices are taken from the untouched input, so the
    # two assignments cannot interfere with each other.
    out[0::3], out[2::3] = data[2::3], data[0::3]
    return bytes(out)
18
def byteswap4(data):
    """Return a copy of *data* with the byte order of every 4-byte item reversed.

    Uses array typecode 'i' (C int).  NOTE: this assumes a 4-byte C int; the
    array module itself only guarantees a minimum of 2 bytes for 'i'.
    """
    ints = array.array('i', data)
    ints.byteswap()
    return bytes(ints)
23
24
class AudioTests:
    """Fixture handling and parameter checks shared by the audio test classes.

    This is a mixin: it calls ``self.assertEqual`` but does not inherit from
    a TestCase, so it is presumably combined with unittest.TestCase by the
    concrete test classes.  ``self.f`` and ``self.fout`` hold the reader and
    writer objects a test opens, so that tearDown() can release them.
    """

    # Expected ``closed`` state of a caller-supplied file object after the
    # audio object wrapping it has been closed.
    close_fd = False

    def setUp(self):
        # No reader/writer is open at the start of a test.
        self.f = self.fout = None

    def tearDown(self):
        """Close any open reader/writer and remove TESTFN.

        Each step runs even if an earlier one raises: closing a writer can
        itself raise ``self.module.Error`` (see test_close), and that must
        neither leak the other handle nor leave TESTFN behind for later
        tests.  The first exception still propagates.
        """
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that audio object *f* reports exactly the given parameters.

        Every individual getter is checked, then ``f.getparams()`` must equal
        the same six values as a plain tuple and survive a pickle round trip.
        """
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        params = f.getparams()
        self.assertEqual(params,
                         (nchannels, sampwidth, framerate, nframes,
                          comptype, compname))

        dump = pickle.dumps(params)
        self.assertEqual(pickle.loads(dump), params)
53
54
class AudioWriteTests(AudioTests):
    """Writer round-trip tests shared by the concrete audio-module test cases.

    The concrete subclass supplies ``module`` (the audio module under test)
    and the attributes ``nchannels``, ``sampwidth``, ``framerate``,
    ``comptype``, ``compname``, ``nframes`` and ``frames`` read below.
    """

    def create_file(self, testfile):
        """Open *testfile* for writing and set everything but the frame count.

        *testfile* is passed straight to ``self.module.open`` (a filename or
        an open binary file object).  The writer is also stored as
        ``self.fout`` so tearDown() closes it if the test stops early.
        """
        self.fout = writer = self.module.open(testfile, 'wb')
        writer.setnchannels(self.nchannels)
        writer.setsampwidth(self.sampwidth)
        writer.setframerate(self.framerate)
        writer.setcomptype(self.comptype, self.compname)
        return writer

    def check_file(self, testfile, nframes, frames):
        """Reopen *testfile* and assert it holds *nframes* frames equal to *frames*.

        Channel count, sample width and frame rate must also match the
        subclass attributes.  The reader is always closed afterwards.
        """
        reader = self.module.open(testfile, 'rb')
        try:
            self.assertEqual(reader.getnchannels(), self.nchannels)
            self.assertEqual(reader.getsampwidth(), self.sampwidth)
            self.assertEqual(reader.getframerate(), self.framerate)
            self.assertEqual(reader.getnframes(), nframes)
            self.assertEqual(reader.readframes(nframes), frames)
        finally:
            reader.close()

    def _write_after_junk(self, nframes, chunks):
        # Write 13 junk bytes to TESTFN, then an audio stream that declares
        # *nframes* frames and receives each item of *chunks* in turn.  The
        # writer is closed while the underlying file object is still open.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            writer = self.create_file(testfile)
            writer.setnframes(nframes)
            for chunk in chunks:
                writer.writeframes(chunk)
            writer.close()

    def _check_after_junk(self):
        # The junk prefix must be intact, and the stream after it must read
        # back as exactly self.nframes frames of self.frames.
        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_write_params(self):
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        # Parameters are queried on the writer itself, before closing it.
        self.check_params(writer, self.nchannels, self.sampwidth,
                          self.framerate, self.nframes,
                          self.comptype, self.compname)
        writer.close()

    def test_write(self):
        writer = self.create_file(TESTFN)
        writer.setnframes(self.nframes)
        writer.writeframes(self.frames)
        writer.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_incompleted_write(self):
        # Declare one frame more than is actually written.
        self._write_after_junk(self.nframes + 1, [self.frames])
        self._check_after_junk()

    def test_multiple_writes(self):
        # Write everything except the last frame, then the last frame alone.
        framesize = self.nchannels * self.sampwidth
        head, tail = self.frames[:-framesize], self.frames[-framesize:]
        self._write_after_junk(self.nframes, [head, tail])
        self._check_after_junk()

    def test_overflowed_write(self):
        # Declare one frame fewer than is actually written.
        self._write_after_junk(self.nframes - 1, [self.frames])
        self._check_after_junk()
129
130
class AudioTestsWithSourceFile(AudioTests):
    """Reader tests run against a sample file from the test data directory.

    The concrete subclass supplies ``module`` and ``sndfilename`` (resolved
    with findfile() in the ``audiodata`` subdirectory), plus the expected
    ``nchannels``, ``sampwidth``, ``framerate``, ``sndfilenframes``,
    ``comptype``, ``compname``, ``nframes`` and ``frames``.
    """

    @classmethod
    def setUpClass(cls):
        # Resolve the sample file's path once for all tests in the class.
        cls.sndfilepath = findfile(cls.sndfilename, subdir='audiodata')

    def test_read_params(self):
        reader = self.f = self.module.open(self.sndfilepath)
        # NOTE(review): an assertion that reader.getfp().name equals
        # self.sndfilepath was kept here commented out; reason not recorded.
        self.check_params(reader, self.nchannels, self.sampwidth,
                          self.framerate, self.sndfilenframes,
                          self.comptype, self.compname)

    def test_close(self):
        # Reader wrapping a caller-supplied file object.
        with open(self.sndfilepath, 'rb') as raw:
            reader = self.f = self.module.open(raw)
            self.assertFalse(raw.closed)
            reader.close()
            self.assertEqual(raw.closed, self.close_fd)
        # Writer on which nothing was configured: close() must raise
        # self.module.Error, the file object must end up in the same state
        # as for the reader, and a second close() must do nothing.
        with open(TESTFN, 'wb') as raw:
            writer = self.fout = self.module.open(raw, 'wb')
            self.assertFalse(raw.closed)
            with self.assertRaises(self.module.Error):
                writer.close()
            self.assertEqual(raw.closed, self.close_fd)
            writer.close()  # do nothing

    def test_read(self):
        framesize = self.nchannels * self.sampwidth
        first_two = self.frames[:2 * framesize]
        next_two = self.frames[2 * framesize:4 * framesize]
        reader = self.f = self.module.open(self.sndfilepath)
        self.assertEqual(reader.readframes(0), b'')
        self.assertEqual(reader.tell(), 0)
        self.assertEqual(reader.readframes(2), first_two)
        reader.rewind()
        start = reader.tell()
        self.assertEqual(start, 0)
        self.assertEqual(reader.readframes(2), first_two)
        after_two = reader.tell()
        self.assertEqual(after_two, 2)
        self.assertEqual(reader.readframes(2), next_two)
        # Positions obtained from tell() are accepted by setpos().
        reader.setpos(after_two)
        self.assertEqual(reader.readframes(2), next_two)
        reader.setpos(start)
        self.assertEqual(reader.readframes(2), first_two)
        # Positions outside [0, getnframes()] are rejected.
        with self.assertRaises(self.module.Error):
            reader.setpos(-1)
        with self.assertRaises(self.module.Error):
            reader.setpos(reader.getnframes() + 1)

    def test_copy(self):
        reader = self.f = self.module.open(self.sndfilepath)
        writer = self.fout = self.module.open(TESTFN, 'wb')
        writer.setparams(reader.getparams())
        # Copy the frames in chunks of 1, 2, 3, ... frames.
        chunk = 0
        remaining = reader.getnframes()
        while remaining > 0:
            chunk += 1
            writer.writeframes(reader.readframes(chunk))
            remaining -= chunk
        writer.close()
        clone = self.fout = self.module.open(TESTFN, 'rb')
        reader.rewind()
        self.assertEqual(reader.getparams(), clone.getparams())
        self.assertEqual(reader.readframes(reader.getnframes()),
                         clone.readframes(clone.getnframes()))

    def test_read_not_from_start(self):
        # Build TESTFN as 13 junk bytes followed by the sample file.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            with open(self.sndfilepath, 'rb') as source:
                testfile.write(source.read())

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            reader = self.module.open(testfile, 'rb')
            try:
                self.assertEqual(reader.getnchannels(), self.nchannels)
                self.assertEqual(reader.getsampwidth(), self.sampwidth)
                self.assertEqual(reader.getframerate(), self.framerate)
                self.assertEqual(reader.getnframes(), self.sndfilenframes)
                self.assertEqual(reader.readframes(self.nframes), self.frames)
            finally:
                reader.close()
214 f.close()