blob: 9306d9bb23fca6248e008ebd8651ffce177c5b58 [file] [log] [blame]
Serhiy Storchaka30d68c62014-01-26 23:48:20 +02001import imghdr
2import io
3import sys
4import unittest
5from test.test_support import findfile, TESTFN, unlink, run_unittest
6
# Sample images looked up in the 'imghdrdata' test directory, each paired
# with the format name that imghdr.what() is expected to report for it.
TEST_FILES = (
    ('python.png',  'png'),
    ('python.gif',  'gif'),
    ('python.bmp',  'bmp'),
    ('python.ppm',  'ppm'),
    ('python.pgm',  'pgm'),
    ('python.pbm',  'pbm'),
    ('python.jpg',  'jpeg'),
    ('python.ras',  'rast'),
    ('python.sgi',  'rgb'),
    ('python.tiff', 'tiff'),
    ('python.xbm',  'xbm'),
)
20
class UnseekableIO(io.FileIO):
    """A FileIO whose position can be neither queried nor changed.

    Both seek() and tell() raise io.UnsupportedOperation, which lets tests
    check how code behaves on a stream it cannot rewind.
    """

    def seek(self, *args, **kwargs):
        # Reject every form of repositioning, whatever the arguments.
        raise io.UnsupportedOperation

    def tell(self):
        # Refuse to report the current offset as well.
        raise io.UnsupportedOperation
27
class TestImghdr(unittest.TestCase):
    """Tests for imghdr.what() on file names, streams and raw header bytes."""

    @classmethod
    def setUpClass(cls):
        # Keep the path and the raw bytes of one known PNG sample; several
        # tests below reuse them.
        cls.testfile = findfile('python.png', subdir='imghdrdata')
        with open(cls.testfile, 'rb') as stream:
            cls.testdata = stream.read()

    def tearDown(self):
        # Several tests write TESTFN; remove it after every test.
        unlink(TESTFN)

    # Every sample must be detected whether it is given as a byte-string
    # path, a unicode path, an open binary stream, or raw bytes.
    def test_data(self):
        for filename, expected in TEST_FILES:
            filename = findfile(filename, subdir='imghdrdata')
            self.assertEqual(imghdr.what(filename), expected)
            # Decode the path so unicode file names are exercised too.
            ufilename = filename.decode(sys.getfilesystemencoding())
            self.assertEqual(imghdr.what(ufilename), expected)
            with open(filename, 'rb') as stream:
                self.assertEqual(imghdr.what(stream), expected)
            with open(filename, 'rb') as stream:
                data = stream.read()
            self.assertEqual(imghdr.what(None, data), expected)

    # A check function appended to imghdr.tests is consulted by what().
    # The cleanup pops it so later tests see the original list.
    def test_register_test(self):
        def test_jumbo(h, file):
            if h.startswith(b'eggs'):
                return 'ham'
        imghdr.tests.append(test_jumbo)
        self.addCleanup(imghdr.tests.pop)
        self.assertEqual(imghdr.what(None, b'eggs'), 'ham')

    # what() on a stream must start reading at the current position and
    # leave the stream positioned where it found it.
    def test_file_pos(self):
        with open(TESTFN, 'wb') as stream:
            stream.write(b'ababagalamaga')
            pos = stream.tell()
            stream.write(self.testdata)
        with open(TESTFN, 'rb') as stream:
            stream.seek(pos)
            self.assertEqual(imghdr.what(stream), 'png')
            self.assertEqual(stream.tell(), pos)

    # Wrong argument counts or types must raise rather than return a guess.
    def test_bad_args(self):
        with self.assertRaises(TypeError):
            imghdr.what()
        # None is neither a path nor a stream.
        with self.assertRaises(AttributeError):
            imghdr.what(None)
        with self.assertRaises(TypeError):
            imghdr.what(self.testfile, 1)
        # A bare file descriptor is not accepted as the file argument.
        with open(self.testfile, 'rb') as f:
            with self.assertRaises(AttributeError):
                imghdr.what(f.fileno())

    # Headers that match no recognised format must yield None.
    def test_invalid_headers(self):
        for header in (b'\211PN\r\n',
                       b'\001\331',
                       b'\x59\xA6',
                       b'cutecat',
                       b'000000JFI',
                       b'GIF80'):
            self.assertIsNone(imghdr.what(None, header))

    # A nonexistent path propagates the IOError from opening it.
    def test_missing_file(self):
        with self.assertRaises(IOError):
            imghdr.what('missing')

    # Closed streams, both real files and in-memory buffers, raise
    # ValueError.
    def test_closed_file(self):
        stream = open(self.testfile, 'rb')
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)
        stream = io.BytesIO(self.testdata)
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)

    # A stream that refuses tell()/seek() makes what() fail with
    # io.UnsupportedOperation.
    def test_unseekable(self):
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
        with UnseekableIO(TESTFN, 'rb') as stream:
            with self.assertRaises(io.UnsupportedOperation):
                imghdr.what(stream)

    # Reading from a write-only stream raises IOError.
    def test_output_stream(self):
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
            stream.seek(0)
            with self.assertRaises(IOError):
                imghdr.what(stream)
115
def test_main():
    """Run every TestImghdr case via test_support.run_unittest."""
    case_classes = (TestImghdr,)
    run_unittest(*case_classes)
118
# Allow running this test module directly as a script.
if '__main__' == __name__:
    test_main()