blob: 231916970c7fdbb2e67f5fd4561ff77244792747 [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Victor Stinner1cae9ec2014-07-14 22:26:34 +02006import logging
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07007import os
Yury Selivanovff827f02014-02-18 18:02:19 -05008import re
Victor Stinner4271dfd2017-11-28 15:19:56 +01009import selectors
Yury Selivanov88a5bf02014-02-18 12:15:06 -050010import socket
11import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070012import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050013import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070014import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020015import time
Victor Stinnerc73701d2014-06-18 01:36:32 +020016import unittest
Yury Selivanov5b8d4f92016-10-05 17:48:59 -040017import weakref
18
Victor Stinner24ba2032014-02-26 10:25:02 +010019from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050020
21from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010022from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050023
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070024try:
25 import ssl
26except ImportError: # pragma: no cover
27 ssl = None
28
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070029from . import base_events
30from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010031from . import futures
Victor Stinnere6a53792014-03-06 01:00:36 +010032from . import tasks
Victor Stinner1cae9ec2014-07-14 22:26:34 +020033from .log import logger
Victor Stinnerb9030672017-06-30 11:12:33 +020034from test import support
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070035
36
def dummy_ssl_context():
    """Return a fresh SSL context for tests, or None without ssl support.

    The module-level ``ssl`` name is None when the ``ssl`` module could not
    be imported; in that case there is nothing to build and None is returned.
    """
    if ssl is None:
        return None
    return ssl.SSLContext(ssl.PROTOCOL_TLS)
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070042
43
def run_briefly(loop):
    """Run *loop* until a freshly created no-op task has completed.

    The coroutine object is always closed afterwards, even when
    ``run_until_complete()`` raises.
    """
    async def once():
        pass

    gen = once()
    t = loop.create_task(gen)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    t._log_destroy_pending = False
    try:
        loop.run_until_complete(t)
    finally:
        gen.close()
56
57
def run_until(loop, pred, timeout=30):
    """Run *loop* in 1 ms slices until ``pred()`` returns a true value.

    *timeout* is the maximum number of seconds to wait; pass None to wait
    without a limit.  Raises ``futures.TimeoutError`` when the deadline
    passes while ``pred()`` is still false.
    """
    # Only compute a deadline when a timeout was requested; adding None to
    # a clock reading would raise TypeError before the loop even starts.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020066
67
def run_once(loop):
    """Legacy API to run once through the event loop.

    This is the recommended pattern for test code.  It will poll the
    selector once and run all callbacks scheduled in response to I/O
    events.
    """
    # Schedule the stop first so run_forever() returns after the callbacks
    # that are already queued have been processed.
    loop.call_soon(loop.stop)
    loop.run_forever()
77
78
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that produces no diagnostic output."""

    def get_stderr(self):
        """Return a throwaway in-memory stream instead of a real stderr."""
        return io.StringIO()

    def log_message(self, format, *args):
        """Discard the log message."""
        pass
86
87
class SilentWSGIServer(WSGIServer):
    """WSGI server that times out accepted sockets and ignores errors."""

    # Timeout, in seconds, applied to every accepted connection.
    request_timeout = 2

    def get_request(self):
        """Accept a connection and apply ``request_timeout`` to its socket."""
        request, client_addr = super().get_request()
        request.settimeout(self.request_timeout)
        return request, client_addr

    def handle_error(self, request, client_address):
        """Ignore request-handling errors instead of reporting them."""
        pass
99
100
class SSLWSGIServerMixin:
    """Mixin that handles each accepted request over a TLS-wrapped socket."""

    def finish_request(self, request, client_address):
        """Wrap *request* in TLS and run the request handler on it.

        The certificate and key are loaded from ``ssl_cert.pem`` and
        ``ssl_key.pem`` in the test directory.  An OSError raised by the
        handler is ignored; the TLS socket is closed in every case.
        """
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        context = ssl.SSLContext()
        context.load_cert_chain(certfile, keyfile)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Close the TLS socket even when the handler raised, so the
            # wrapped connection is not left open.
            try:
                ssock.close()
            except OSError:
                pass
124
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700125
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent WSGI server whose requests are handled over TLS."""
128
129
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that runs a test WSGI server in a background thread.

    Yields the running server object, whose ``address`` attribute holds the
    bound server address.  When the generator is resumed or closed, the
    server is shut down, its socket closed and the thread joined.

    *server_ssl_cls* is used when *use_ssl* is true, *server_cls* otherwise.
    """

    def app(environ, start_response):
        # Minimal WSGI application: always answers with a fixed text body.
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
153
154
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server listening on a UNIX domain socket."""

        def server_bind(self):
            """Bind the UNIX socket and set placeholder name and port."""
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server listening on a UNIX domain socket."""

        # Timeout, in seconds, applied to every accepted connection.
        request_timeout = 2

        def server_bind(self):
            """Bind the socket, then populate the WSGI environment."""
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            """Accept a connection and return it with a fake client address."""
            request, client_addr = super().get_request()
            request.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return request, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UNIX-socket WSGI server that ignores request-handling errors."""

        def handle_error(self, request, client_address):
            """Ignore the error instead of reporting it."""
            pass


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX-socket WSGI server handling requests over TLS."""


    def gen_unix_socket_path():
        """Return a unique filesystem path usable for a UNIX socket.

        The name comes from a temporary file that is closed, and therefore
        removed, before this function returns.
        """
        with tempfile.NamedTemporaryFile() as file:
            return file.name


    @contextlib.contextmanager
    def unix_socket_path():
        """Context manager yielding a socket path and unlinking it on exit."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            # Best effort: the path may never have been created.
            try:
                os.unlink(path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Context manager running a test WSGI server on a UNIX socket.

        Yields the server object; TLS is used when *use_ssl* is true.
        """
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
218
219
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running a test WSGI server on a TCP socket.

    Yields the server object bound to ``(host, port)``; TLS is used when
    *use_ssl* is true.
    """
    yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
225
226
def make_test_protocol(base):
    """Return an instance of *base* with every public attribute mocked.

    Each name reported by ``dir(base)`` that is not a dunder name is
    replaced by a ``MockCallback`` returning None on a new ``TestProtocol``
    class derived from *base* and its direct bases.
    """
    dct = {
        name: MockCallback(return_value=None)
        for name in dir(base)
        # skip magic names
        if not (name.startswith('__') and name.endswith('__'))
    }
    return type('TestProtocol', (base,) + base.__bases__, dct)()
235
236
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations and never reports I/O."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Record *fileobj* and return its key (the fd field is always 0)."""
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key; KeyError if not registered."""
        return self.keys.pop(fileobj)

    def select(self, timeout=None):
        """Return no ready events, whatever *timeout* is.

        *timeout* defaults to None to match ``BaseSelector.select()``.
        """
        return []

    def get_map(self):
        """Return the mapping of file objects to selector keys."""
        return self.keys
255
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700256
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests with a manually controlled clock.

    The loop keeps its own time, starting at 0.  Whenever call_at() has
    scheduled handlers, the next loop iteration, after running the ready
    handlers, resumes the generator passed to __init__ once per scheduled
    time and advances the clock by the value the generator yields.

    Generator should be like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    Value returned by yield is absolute time of next scheduled handler.
    Value passed to yield is time advance to move loop's time forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            # The default generator is never required to finish.
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Advance to the first yield so that send() can be used later.
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the current test time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and check that the time generator is exhausted.

        The check only applies when a generator was passed to __init__;
        AssertionError is raised if it has not finished.
        """
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def _remove_reader(self, fd):
        # Count every removal attempt, including for unknown descriptors.
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader with this callback and arguments."""
        assert fd in self.readers, 'fd {} is not registered'.format(fd)
        handle = self.readers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def _remove_writer(self, fd):
        # Count every removal attempt, including for unknown descriptors.
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer with this callback and arguments."""
        assert fd in self.writers, 'fd {} is not registered'.format(fd)
        handle = self.writers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* is already used by a transport.

        *fd* may be an int or an object with a fileno() method; anything
        else raises ValueError.
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-descriptor reader/writer removal counters."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report every time scheduled through call_at() to the generator
        # and advance the clock by what it yields back.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        """Schedule *callback* at *when* and record the time for _run_once."""
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass
Victor Stinnera1254972014-02-11 11:34:30 +0100412
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500413
def MockCallback(**kwargs):
    """Return a Mock that can only be called.

    The mock is created with ``spec=['__call__']``; *kwargs* are forwarded
    to ``mock.Mock``.
    """
    return mock.Mock(spec=['__call__'], **kwargs)
Yury Selivanovff827f02014-02-18 18:02:19 -0500416
417
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
       mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern is searched for (not anchored) with ``re.S``.  Comparing
    with something that is not a str yields "not equal" instead of raising.
    """
    def __eq__(self, other):
        if not isinstance(other, str):
            # Let Python fall back to its default comparison rather than
            # letting re.search() raise TypeError on a non-string.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which would otherwise compare the
        # raw text and disagree with the fuzzy __eq__ above.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
Victor Stinner307bccc2014-06-12 18:39:26 +0200429
430
def get_function_source(func):
    """Return the source location of *func* as reported by asyncio.

    Raises ValueError when ``events._get_function_source()`` returns None.
    """
    source = events._get_function_source(func)
    if source is None:
        raise ValueError("unable to get the source of %r" % (func,))
    return source
Victor Stinnerc73701d2014-06-18 01:36:32 +0200436
437
class TestCase(unittest.TestCase):
    """Base class for asyncio tests; handles loop and thread cleanup."""

    @staticmethod
    def close_loop(loop):
        """Shut down the default executor of *loop*, if any, then close it."""
        executor = loop._default_executor
        if executor is not None:
            executor.shutdown(wait=True)
        loop.close()

    def set_event_loop(self, loop, *, cleanup=True):
        """Clear the current event loop and optionally close *loop* later.

        When *cleanup* is true, *loop* is closed by a test cleanup.
        """
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def unpatch_get_running_loop(self):
        """Restore the events._get_running_loop saved by setUp()."""
        events._get_running_loop = self._get_running_loop

    def setUp(self):
        # Make asyncio report that no loop is running for the duration of
        # the test; the original function is restored in tearDown().
        self._get_running_loop = events._get_running_loop
        events._get_running_loop = lambda: None
        self._thread_cleanup = support.threading_setup()

    def tearDown(self):
        self.unpatch_get_running_loop()

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

        # Run the cleanups (which close the loops) before checking for
        # leftover threads and child processes.
        self.doCleanups()
        support.threading_cleanup(*self._thread_cleanup)
        support.reap_children()
478
Victor Stinner1cae9ec2014-07-14 22:26:34 +0200479
@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    For example, it can be used to ignore warnings in debug mode.
    The previous logger level is restored on exit.
    """
    old_level = logger.level
    try:
        # A level above CRITICAL filters out every standard log record.
        logger.setLevel(logging.CRITICAL+1)
        yield
    finally:
        logger.setLevel(old_level)
Victor Stinnerb2614752014-08-25 23:20:52 +0200492
Yury Selivanovd5c2a622015-12-16 19:31:17 -0500493
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The mock is specced on ``socket.socket``, carries the given *proto*,
    *type* and *family*, and reports a timeout of 0.0 from gettimeout().
    """
    sock = mock.MagicMock(socket.socket)
    sock.proto = proto
    sock.type = type
    sock.family = family
    sock.gettimeout.return_value = 0.0
    return sock