blob: f54e051e94568c43559e80503284a6217d1aaf38 [file] [log] [blame]
Michael Foord345266a2012-03-14 12:24:34 -07001import unittest
2from warnings import catch_warnings
3
4from unittest.test.testmock.support import is_instance
5from unittest.mock import MagicMock, Mock, patch, sentinel, mock_open, call
6
7
8
# Module-level targets that the tests patch by dotted name.  Each one is a
# distinct sentinel, so a test can tell both that patching replaced it and
# that the original object was put back afterwards.
something, something_else = sentinel.Something, sentinel.SomethingElse
11
12
13
class WithTest(unittest.TestCase):
    """Tests for ``patch``, ``patch.object``, ``patch.dict`` and mocks used
    as context managers in ``with`` statements."""

    def test_with_statement(self):
        with patch('%s.something' % __name__, sentinel.Something2):
            self.assertEqual(something, sentinel.Something2, "unpatched")
        self.assertEqual(something, sentinel.Something)


    def test_with_statement_exception(self):
        # A dedicated exception type is used so that an AssertionError raised
        # by the "unpatched" check inside the block propagates as a failure
        # instead of being mistaken for the expected exception and swallowed
        # (which a bare ``except Exception: pass`` would do).
        class SampleException(Exception):
            pass

        with self.assertRaises(SampleException):
            with patch('%s.something' % __name__, sentinel.Something2):
                self.assertEqual(something, sentinel.Something2, "unpatched")
                raise SampleException('pow')
        # The patch must be undone even though the block exited by raising.
        self.assertEqual(something, sentinel.Something)


    def test_with_statement_as(self):
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            # is_instance checks the real type rather than trusting the
            # mock's (configurable) __class__.
            self.assertTrue(is_instance(mock_something, MagicMock),
                            "patching wrong type")
        self.assertEqual(something, sentinel.Something)


    def test_patch_object_with_statement(self):
        class Foo(object):
            something = 'foo'
        original = Foo.something
        with patch.object(Foo, 'something'):
            self.assertNotEqual(Foo.something, original, "unpatched")
        self.assertEqual(Foo.something, original)


    def test_with_statement_nested(self):
        with catch_warnings(record=True):
            with patch('%s.something' % __name__) as mock_something, \
                 patch('%s.something_else' % __name__) as mock_something_else:
                self.assertEqual(something, mock_something, "unpatched")
                self.assertEqual(something_else, mock_something_else,
                                 "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)


    def test_with_statement_specified(self):
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")
            self.assertEqual(mock_something, sentinel.Patched, "wrong patch")
        self.assertEqual(something, sentinel.Something)


    def testContextManagerMocking(self):
        # A plain Mock only supports ``with`` once __enter__/__exit__ are
        # configured explicitly; __exit__ returning False lets exceptions
        # propagate.
        mock = Mock()
        mock.__enter__ = Mock()
        mock.__exit__ = Mock()
        mock.__exit__.return_value = False

        with mock as m:
            self.assertEqual(m, mock.__enter__.return_value)
        mock.__enter__.assert_called_with()
        mock.__exit__.assert_called_with(None, None, None)


    def test_context_manager_with_magic_mock(self):
        # MagicMock supports ``with`` out of the box and does not suppress
        # the TypeError raised inside the block.
        mock = MagicMock()

        with self.assertRaises(TypeError):
            with mock:
                'foo' + 3
        mock.__enter__.assert_called_with()
        self.assertTrue(mock.__exit__.called)


    def test_with_statement_same_attribute(self):
        # Patching an already-patched attribute must restore the outer
        # patch's object, not the original, when the inner patch exits.
        with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something' % __name__) as mock_again:
                self.assertEqual(something, mock_again, "unpatched")

            self.assertEqual(something, mock_something,
                             "restored with wrong instance")

        self.assertEqual(something, sentinel.Something, "not restored")


    def test_with_statement_imbricated(self):
        with patch('%s.something' % __name__) as mock_something:
            self.assertEqual(something, mock_something, "unpatched")

            with patch('%s.something_else' % __name__) as mock_something_else:
                self.assertEqual(something_else, mock_something_else,
                                 "unpatched")

        self.assertEqual(something, sentinel.Something)
        self.assertEqual(something_else, sentinel.SomethingElse)


    def test_dict_context_manager(self):
        foo = {}
        with patch.dict(foo, {'a': 'b'}):
            self.assertEqual(foo, {'a': 'b'})
        self.assertEqual(foo, {})

        # The dictionary must also be restored when the block raises.
        with self.assertRaises(NameError):
            with patch.dict(foo, {'a': 'b'}):
                self.assertEqual(foo, {'a': 'b'})
                raise NameError('Konrad')

        self.assertEqual(foo, {})
128
129
130
131class TestMockOpen(unittest.TestCase):
132
133 def test_mock_open(self):
134 mock = mock_open()
135 with patch('%s.open' % __name__, mock, create=True) as patched:
136 self.assertIs(patched, mock)
137 open('foo')
138
139 mock.assert_called_once_with('foo')
140
141
142 def test_mock_open_context_manager(self):
143 mock = mock_open()
144 handle = mock.return_value
145 with patch('%s.open' % __name__, mock, create=True):
146 with open('foo') as f:
147 f.read()
148
149 expected_calls = [call('foo'), call().__enter__(), call().read(),
150 call().__exit__(None, None, None)]
151 self.assertEqual(mock.mock_calls, expected_calls)
152 self.assertIs(f, handle)
153
154
155 def test_explicit_mock(self):
156 mock = MagicMock()
157 mock_open(mock)
158
159 with patch('%s.open' % __name__, mock, create=True) as patched:
160 self.assertIs(patched, mock)
161 open('foo')
162
163 mock.assert_called_once_with('foo')
164
165
166 def test_read_data(self):
167 mock = mock_open(read_data='foo')
168 with patch('%s.open' % __name__, mock, create=True):
169 h = open('bar')
170 result = h.read()
171
172 self.assertEqual(result, 'foo')
173
174
Michael Foord04cbe0c2013-03-19 17:22:51 -0700175 def test_readline_data(self):
176 # Check that readline will return all the lines from the fake file
177 mock = mock_open(read_data='foo\nbar\nbaz\n')
178 with patch('%s.open' % __name__, mock, create=True):
179 h = open('bar')
180 line1 = h.readline()
181 line2 = h.readline()
182 line3 = h.readline()
183 self.assertEqual(line1, 'foo\n')
184 self.assertEqual(line2, 'bar\n')
185 self.assertEqual(line3, 'baz\n')
186
187 # Check that we properly emulate a file that doesn't end in a newline
188 mock = mock_open(read_data='foo')
189 with patch('%s.open' % __name__, mock, create=True):
190 h = open('bar')
191 result = h.readline()
192 self.assertEqual(result, 'foo')
193
194
195 def test_readlines_data(self):
196 # Test that emulating a file that ends in a newline character works
197 mock = mock_open(read_data='foo\nbar\nbaz\n')
198 with patch('%s.open' % __name__, mock, create=True):
199 h = open('bar')
200 result = h.readlines()
201 self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n'])
202
203 # Test that files without a final newline will also be correctly
204 # emulated
205 mock = mock_open(read_data='foo\nbar\nbaz')
206 with patch('%s.open' % __name__, mock, create=True):
207 h = open('bar')
208 result = h.readlines()
209
210 self.assertEqual(result, ['foo\n', 'bar\n', 'baz'])
211
212
213 def test_mock_open_read_with_argument(self):
214 # At one point calling read with an argument was broken
215 # for mocks returned by mock_open
216 some_data = 'foo\nbar\nbaz'
217 mock = mock_open(read_data=some_data)
218 self.assertEqual(mock().read(10), some_data)
219
220
221 def test_interleaved_reads(self):
222 # Test that calling read, readline, and readlines pulls data
223 # sequentially from the data we preload with
224 mock = mock_open(read_data='foo\nbar\nbaz\n')
225 with patch('%s.open' % __name__, mock, create=True):
226 h = open('bar')
227 line1 = h.readline()
228 rest = h.readlines()
229 self.assertEqual(line1, 'foo\n')
230 self.assertEqual(rest, ['bar\n', 'baz\n'])
231
232 mock = mock_open(read_data='foo\nbar\nbaz\n')
233 with patch('%s.open' % __name__, mock, create=True):
234 h = open('bar')
235 line1 = h.readline()
236 rest = h.read()
237 self.assertEqual(line1, 'foo\n')
238 self.assertEqual(rest, 'bar\nbaz\n')
239
240
241 def test_overriding_return_values(self):
242 mock = mock_open(read_data='foo')
243 handle = mock()
244
245 handle.read.return_value = 'bar'
246 handle.readline.return_value = 'bar'
247 handle.readlines.return_value = ['bar']
248
249 self.assertEqual(handle.read(), 'bar')
250 self.assertEqual(handle.readline(), 'bar')
251 self.assertEqual(handle.readlines(), ['bar'])
252
253 # call repeatedly to check that a StopIteration is not propagated
254 self.assertEqual(handle.readline(), 'bar')
255 self.assertEqual(handle.readline(), 'bar')
256
257
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()