[3.8] bpo-38136: Updates await_count and call_count to be different things (GH-16192) (GH-16431)
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index 30f4663..5ea5624 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -1070,14 +1070,20 @@
# can't use self in-case a function / method we are mocking uses self
# in the signature
self._mock_check_sig(*args, **kwargs)
+ self._increment_mock_call(*args, **kwargs)
return self._mock_call(*args, **kwargs)
def _mock_call(self, /, *args, **kwargs):
+ return self._execute_mock_call(*args, **kwargs)
+
+ def _increment_mock_call(self, /, *args, **kwargs):
self.called = True
self.call_count += 1
# handle call_args
+ # needs to be set here so assertions on call arguments pass before
+ # execution in the case of awaited calls
_call = _Call((args, kwargs), two=True)
self.call_args = _call
self.call_args_list.append(_call)
@@ -1117,6 +1123,10 @@
# follow the parental chain:
_new_parent = _new_parent._mock_new_parent
+ def _execute_mock_call(self, /, *args, **kwargs):
+ # separate from _increment_mock_call so that awaited functions are
+ # executed separately from their call
+
effect = self.side_effect
if effect is not None:
if _is_exception(effect):
diff --git a/Lib/unittest/test/testmock/testasync.py b/Lib/unittest/test/testmock/testasync.py
index af53210..86b0d0e 100644
--- a/Lib/unittest/test/testmock/testasync.py
+++ b/Lib/unittest/test/testmock/testasync.py
@@ -3,8 +3,8 @@
import re
import unittest
-from unittest.mock import (call, AsyncMock, patch, MagicMock, create_autospec,
- _AwaitEvent)
+from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock,
+ create_autospec, _AwaitEvent, sentinel, _CallList)
def tearDownModule():
@@ -591,11 +591,173 @@
def setUp(self):
self.mock = AsyncMock()
- async def _runnable_test(self, *args):
- if not args:
- await self.mock()
- else:
- await self.mock(*args)
+ async def _runnable_test(self, *args, **kwargs):
+ await self.mock(*args, **kwargs)
+
+ async def _await_coroutine(self, coroutine):
+ return await coroutine
+
+ def test_assert_called_but_not_awaited(self):
+ mock = AsyncMock(AsyncClass)
+ with self.assertWarns(RuntimeWarning):
+ # Will raise a warning because never awaited
+ mock.async_method()
+ self.assertTrue(asyncio.iscoroutinefunction(mock.async_method))
+ mock.async_method.assert_called()
+ mock.async_method.assert_called_once()
+ mock.async_method.assert_called_once_with()
+ with self.assertRaises(AssertionError):
+ mock.assert_awaited()
+ with self.assertRaises(AssertionError):
+ mock.async_method.assert_awaited()
+
+ def test_assert_called_then_awaited(self):
+ mock = AsyncMock(AsyncClass)
+ mock_coroutine = mock.async_method()
+ mock.async_method.assert_called()
+ mock.async_method.assert_called_once()
+ mock.async_method.assert_called_once_with()
+ with self.assertRaises(AssertionError):
+ mock.async_method.assert_awaited()
+
+ asyncio.run(self._await_coroutine(mock_coroutine))
+ # Assert we haven't re-called the function
+ mock.async_method.assert_called_once()
+ mock.async_method.assert_awaited()
+ mock.async_method.assert_awaited_once()
+ mock.async_method.assert_awaited_once_with()
+
+ def test_assert_called_and_awaited_at_same_time(self):
+ with self.assertRaises(AssertionError):
+ self.mock.assert_awaited()
+
+ with self.assertRaises(AssertionError):
+ self.mock.assert_called()
+
+ asyncio.run(self._runnable_test())
+ self.mock.assert_called_once()
+ self.mock.assert_awaited_once()
+
+ def test_assert_called_twice_and_awaited_once(self):
+ mock = AsyncMock(AsyncClass)
+ coroutine = mock.async_method()
+ with self.assertWarns(RuntimeWarning):
+ # The first call will be awaited so no warning there
+ # But this call will never get awaited, so it will warn here
+ mock.async_method()
+ with self.assertRaises(AssertionError):
+ mock.async_method.assert_awaited()
+ mock.async_method.assert_called()
+ asyncio.run(self._await_coroutine(coroutine))
+ mock.async_method.assert_awaited()
+ mock.async_method.assert_awaited_once()
+
+ def test_assert_called_once_and_awaited_twice(self):
+ mock = AsyncMock(AsyncClass)
+ coroutine = mock.async_method()
+ mock.async_method.assert_called_once()
+ asyncio.run(self._await_coroutine(coroutine))
+ with self.assertRaises(RuntimeError):
+ # Cannot reuse already awaited coroutine
+ asyncio.run(self._await_coroutine(coroutine))
+ mock.async_method.assert_awaited()
+
+ def test_assert_awaited_but_not_called(self):
+ with self.assertRaises(AssertionError):
+ self.mock.assert_awaited()
+ with self.assertRaises(AssertionError):
+ self.mock.assert_called()
+ with self.assertRaises(TypeError):
+ # An AsyncMock itself cannot be awaited; only the coroutine it returns can
+ asyncio.run(self._await_coroutine(self.mock))
+
+ with self.assertRaises(AssertionError):
+ self.mock.assert_awaited()
+ with self.assertRaises(AssertionError):
+ self.mock.assert_called()
+
+ def test_assert_has_calls_not_awaits(self):
+ kalls = [call('foo')]
+ with self.assertWarns(RuntimeWarning):
+ # Will raise a warning because never awaited
+ self.mock('foo')
+ self.mock.assert_has_calls(kalls)
+ with self.assertRaises(AssertionError):
+ self.mock.assert_has_awaits(kalls)
+
+ def test_assert_has_mock_calls_on_async_mock_no_spec(self):
+ with self.assertWarns(RuntimeWarning):
+ # Will raise a warning because never awaited
+ self.mock()
+ kalls_empty = [('', (), {})]
+ self.assertEqual(self.mock.mock_calls, kalls_empty)
+
+ with self.assertWarns(RuntimeWarning):
+ # Will raise a warning because never awaited
+ self.mock('foo')
+ self.mock('baz')
+ mock_kalls = ([call(), call('foo'), call('baz')])
+ self.assertEqual(self.mock.mock_calls, mock_kalls)
+
+ def test_assert_has_mock_calls_on_async_mock_with_spec(self):
+ a_class_mock = AsyncMock(AsyncClass)
+ with self.assertWarns(RuntimeWarning):
+ # Will raise a warning because never awaited
+ a_class_mock.async_method()
+ kalls_empty = [('', (), {})]
+ self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
+ self.assertEqual(a_class_mock.mock_calls, [call.async_method()])
+
+ with self.assertWarns(RuntimeWarning):
+ # Will raise a warning because never awaited
+ a_class_mock.async_method(1, 2, 3, a=4, b=5)
+ method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
+ mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)]
+ self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
+ self.assertEqual(a_class_mock.mock_calls, mock_kalls)
+
+ def test_async_method_calls_recorded(self):
+ with self.assertWarns(RuntimeWarning):
+ # Will raise warnings because never awaited
+ self.mock.something(3, fish=None)
+ self.mock.something_else.something(6, cake=sentinel.Cake)
+
+ self.assertEqual(self.mock.method_calls, [
+ ("something", (3,), {'fish': None}),
+ ("something_else.something", (6,), {'cake': sentinel.Cake})
+ ],
+ "method calls not recorded correctly")
+ self.assertEqual(self.mock.something_else.method_calls,
+ [("something", (6,), {'cake': sentinel.Cake})],
+ "method calls not recorded correctly")
+
+ def test_async_arg_lists(self):
+ def assert_attrs(mock):
+ names = ('call_args_list', 'method_calls', 'mock_calls')
+ for name in names:
+ attr = getattr(mock, name)
+ self.assertIsInstance(attr, _CallList)
+ self.assertIsInstance(attr, list)
+ self.assertEqual(attr, [])
+
+ assert_attrs(self.mock)
+ with self.assertWarns(RuntimeWarning):
+ # Will raise warnings because never awaited
+ self.mock()
+ self.mock(1, 2)
+ self.mock(a=3)
+
+ self.mock.reset_mock()
+ assert_attrs(self.mock)
+
+ a_mock = AsyncMock(AsyncClass)
+ with self.assertWarns(RuntimeWarning):
+ # Will raise warnings because never awaited
+ a_mock.async_method()
+ a_mock.async_method(1, a=3)
+
+ a_mock.reset_mock()
+ assert_attrs(a_mock)
def test_assert_awaited(self):
with self.assertRaises(AssertionError):
@@ -641,20 +803,20 @@
def test_assert_any_wait(self):
with self.assertRaises(AssertionError):
- self.mock.assert_any_await('NormalFoo')
+ self.mock.assert_any_await('foo')
+
+ asyncio.run(self._runnable_test('baz'))
+ with self.assertRaises(AssertionError):
+ self.mock.assert_any_await('foo')
asyncio.run(self._runnable_test('foo'))
- with self.assertRaises(AssertionError):
- self.mock.assert_any_await('NormalFoo')
-
- asyncio.run(self._runnable_test('NormalFoo'))
- self.mock.assert_any_await('NormalFoo')
+ self.mock.assert_any_await('foo')
asyncio.run(self._runnable_test('SomethingElse'))
- self.mock.assert_any_await('NormalFoo')
+ self.mock.assert_any_await('foo')
def test_assert_has_awaits_no_order(self):
- calls = [call('NormalFoo'), call('baz')]
+ calls = [call('foo'), call('baz')]
with self.assertRaises(AssertionError) as cm:
self.mock.assert_has_awaits(calls)
@@ -664,7 +826,7 @@
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls)
- asyncio.run(self._runnable_test('NormalFoo'))
+ asyncio.run(self._runnable_test('foo'))
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls)
@@ -675,7 +837,7 @@
self.mock.assert_has_awaits(calls)
def test_assert_has_awaits_ordered(self):
- calls = [call('NormalFoo'), call('baz')]
+ calls = [call('foo'), call('baz')]
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True)
@@ -683,11 +845,11 @@
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True)
- asyncio.run(self._runnable_test('foo'))
+ asyncio.run(self._runnable_test('bamf'))
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True)
- asyncio.run(self._runnable_test('NormalFoo'))
+ asyncio.run(self._runnable_test('foo'))
self.mock.assert_has_awaits(calls, any_order=True)
asyncio.run(self._runnable_test('qux'))
diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py
index 817c548..6dc2725 100644
--- a/Lib/unittest/test/testmock/testmock.py
+++ b/Lib/unittest/test/testmock/testmock.py
@@ -842,6 +842,7 @@
def test_setting_call(self):
mock = Mock()
def __call__(self, a):
+ self._increment_mock_call(a)
return self._mock_call(a)
type(mock).__call__ = __call__
@@ -2043,7 +2044,7 @@
)
mocks = [
- Mock, MagicMock, NonCallableMock, NonCallableMagicMock
+ Mock, MagicMock, NonCallableMock, NonCallableMagicMock, AsyncMock
]
for mock in mocks: