blob: 0ad4343f520b1b0909e28fcbacd1e53ad1ce7c93 [file] [log] [blame]
Serhiy Storchaka1ac00952014-01-26 23:48:38 +02001import imghdr
2import io
3import os
4import unittest
5import warnings
6from test.support import findfile, TESTFN, unlink
7
# (sample file name, type name imghdr.what() is expected to report) pairs.
# The tests look each file up with findfile(..., subdir='imghdrdata').
TEST_FILES = (
    ('python.png', 'png'),
    ('python.gif', 'gif'),
    ('python.bmp', 'bmp'),
    ('python.ppm', 'ppm'),
    ('python.pgm', 'pgm'),
    ('python.pbm', 'pbm'),
    ('python.jpg', 'jpeg'),
    ('python.ras', 'rast'),
    ('python.sgi', 'rgb'),
    ('python.tiff', 'tiff'),
    ('python.xbm', 'xbm'),
)
21
class UnseekableIO(io.FileIO):
    """Raw file object whose position can be neither queried nor changed.

    Only tell() and seek() are overridden; both unconditionally raise
    io.UnsupportedOperation. Everything else is inherited from io.FileIO.
    """

    def seek(self, *args, **kwargs):
        # Accept any call signature so every seek attempt is rejected alike.
        raise io.UnsupportedOperation()

    def tell(self):
        raise io.UnsupportedOperation()
28
class TestImghdr(unittest.TestCase):
    """Tests for imghdr.what() and the imghdr.tests hook list.

    Individual tests are described with comments rather than docstrings so
    that unittest's verbose output keeps showing the test method names.
    """

    @classmethod
    def setUpClass(cls):
        # Resolve the PNG sample once and keep both its path and its bytes;
        # several tests reuse them.
        cls.testfile = findfile('python.png', subdir='imghdrdata')
        with open(cls.testfile, 'rb') as stream:
            cls.testdata = stream.read()

    def tearDown(self):
        # Remove the scratch file that some tests create.
        unlink(TESTFN)

    def test_data(self):
        # Each sample must be recognised when given as a path, as an open
        # binary stream, as bytes and as a bytearray.
        for name, expected in TEST_FILES:
            # subTest names the failing sample and lets the remaining
            # samples run even when one of them fails.
            with self.subTest(filename=name):
                path = findfile(name, subdir='imghdrdata')
                self.assertEqual(imghdr.what(path), expected)
                with open(path, 'rb') as stream:
                    self.assertEqual(imghdr.what(stream), expected)
                with open(path, 'rb') as stream:
                    data = stream.read()
                self.assertEqual(imghdr.what(None, data), expected)
                self.assertEqual(imghdr.what(None, bytearray(data)), expected)

    def test_register_test(self):
        # A function appended to imghdr.tests must be consulted by what().
        def test_jumbo(h, file):
            if h.startswith(b'eggs'):
                return 'ham'
        imghdr.tests.append(test_jumbo)
        # Remove exactly this hook rather than whatever happens to be last
        # in the list when cleanup runs.
        self.addCleanup(imghdr.tests.remove, test_jumbo)
        self.assertEqual(imghdr.what(None, b'eggs'), 'ham')

    def test_file_pos(self):
        # what() on a stream positioned in the middle of a file must detect
        # the image starting there and leave the position unchanged.
        with open(TESTFN, 'wb') as stream:
            stream.write(b'ababagalamaga')
            pos = stream.tell()
            stream.write(self.testdata)
        with open(TESTFN, 'rb') as stream:
            stream.seek(pos)
            self.assertEqual(imghdr.what(stream), 'png')
            self.assertEqual(stream.tell(), pos)

    def test_bad_args(self):
        # Exception types expected for missing or wrongly typed arguments.
        with self.assertRaises(TypeError):
            imghdr.what()
        with self.assertRaises(AttributeError):
            imghdr.what(None)
        with self.assertRaises(TypeError):
            imghdr.what(self.testfile, 1)
        with self.assertRaises(AttributeError):
            imghdr.what(os.fsencode(self.testfile))
        with open(self.testfile, 'rb') as f:
            with self.assertRaises(AttributeError):
                imghdr.what(f.fileno())

    def test_invalid_headers(self):
        # None of these byte strings may be reported as an image type.
        for header in (b'\211PN\r\n',
                       b'\001\331',
                       b'\x59\xA6',
                       b'cutecat',
                       b'000000JFI',
                       b'GIF80'):
            with self.subTest(header=header):
                self.assertIsNone(imghdr.what(None, header))

    def test_string_data(self):
        # Text (str) input must be rejected with TypeError, both as a
        # StringIO stream and as the data argument.
        with warnings.catch_warnings():
            # BytesWarning is silenced while str data is fed to what().
            warnings.simplefilter("ignore", BytesWarning)
            for name, _ in TEST_FILES:
                with self.subTest(filename=name):
                    path = findfile(name, subdir='imghdrdata')
                    with open(path, 'rb') as stream:
                        data = stream.read().decode('latin1')
                    with self.assertRaises(TypeError):
                        imghdr.what(io.StringIO(data))
                    with self.assertRaises(TypeError):
                        imghdr.what(None, data)

    def test_missing_file(self):
        with self.assertRaises(FileNotFoundError):
            imghdr.what('missing')

    def test_closed_file(self):
        # Already-closed streams, file-backed and in-memory, must raise
        # ValueError.
        stream = open(self.testfile, 'rb')
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)
        stream = io.BytesIO(self.testdata)
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)

    def test_unseekable(self):
        # A stream that refuses tell()/seek() must surface
        # io.UnsupportedOperation from what().
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
        with UnseekableIO(TESTFN, 'rb') as stream:
            with self.assertRaises(io.UnsupportedOperation):
                imghdr.what(stream)

    def test_output_stream(self):
        # A stream opened for writing only must make what() raise OSError.
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
            stream.seek(0)
            with self.assertRaises(OSError):
                imghdr.what(stream)
129
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()