blob: 32d3b0bf630849b8a003779a18d4cec2a59ca57b [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Victor Stinner1cae9ec2014-07-14 22:26:34 +02006import logging
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07007import os
Yury Selivanovff827f02014-02-18 18:02:19 -05008import re
Victor Stinner4271dfd2017-11-28 15:19:56 +01009import selectors
Yury Selivanov88a5bf02014-02-18 12:15:06 -050010import socket
11import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070012import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050013import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070014import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020015import time
Victor Stinnerc73701d2014-06-18 01:36:32 +020016import unittest
Yury Selivanov5b8d4f92016-10-05 17:48:59 -040017import weakref
18
Victor Stinner24ba2032014-02-26 10:25:02 +010019from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050020
21from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010022from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050023
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070024try:
25 import ssl
26except ImportError: # pragma: no cover
27 ssl = None
28
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070029from . import base_events
30from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010031from . import futures
Victor Stinnere6a53792014-03-06 01:00:36 +010032from . import tasks
Victor Stinnerf951d282014-06-29 00:46:45 +020033from .coroutines import coroutine
Victor Stinner1cae9ec2014-07-14 22:26:34 +020034from .log import logger
Victor Stinnerb9030672017-06-30 11:12:33 +020035from test import support
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070036
37
def dummy_ssl_context():
    """Return a bare ``ssl.SSLContext(ssl.PROTOCOL_TLS)``.

    Returns None when the ``ssl`` module could not be imported.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_TLS)
    return None
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070043
44
def run_briefly(loop):
    """Run a no-op task to completion on *loop*, letting it iterate briefly.

    The underlying coroutine object is always closed afterwards, even if
    running the loop raises.
    """
    @coroutine
    def _noop():
        pass

    coro = _noop()
    task = loop.create_task(coro)
    # Suppress the "task was destroyed but it is pending" warning: the task
    # may legitimately be unfinished after run_until_complete() if the loop
    # was stopped or a BaseException escaped from it.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
58
59
def run_until(loop, pred, timeout=30):
    """Run *loop* in small steps until ``pred()`` returns a true value.

    Each step runs the loop until a 1 ms ``tasks.sleep`` completes.

    Args:
        loop: the event loop to drive.
        pred: zero-argument callable, checked before every step.
        timeout: maximum number of seconds to wait, or None to wait
            indefinitely.

    Raises:
        futures.TimeoutError: if *pred* is still false once *timeout*
            seconds have elapsed.
    """
    # Only compute a deadline when a timeout was given, so timeout=None
    # really means "no limit" instead of failing on ``now + None``.
    # A monotonic clock keeps wall-clock adjustments from distorting the wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020068
69
def run_once(loop):
    """Legacy API to run once through the event loop.

    Schedules ``loop.stop`` with ``call_soon()`` and then calls
    ``run_forever()``.  The loop therefore makes a single pass: it polls the
    selector once and runs all callbacks scheduled in response to I/O
    events.  This is the recommended pattern for test code.
    """
    stop_after_iteration = loop.stop
    loop.call_soon(stop_after_iteration)
    loop.run_forever()
79
80
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards its diagnostic output."""

    def get_stderr(self):
        # Error output goes to a fresh in-memory buffer that nobody reads.
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        # Log lines are dropped.
        return None
88
89
class SilentWSGIServer(WSGIServer):
    """WSGIServer with a per-connection timeout and a no-op error handler."""

    # Seconds passed to settimeout() on every accepted connection.
    request_timeout = 2

    def get_request(self):
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        # Deliberately does nothing, replacing the base class's handling.
        return None
101
102
class SSLWSGIServerMixin:
    """Mixin that wraps each accepted connection in server-side TLS.

    It is listed before a WSGIServer subclass in the bases (as in
    SSLWSGIServer).  The test key and certificate are loaded each time
    finish_request() runs.
    """

    def finish_request(self, request, client_address):
        """Wrap *request* in TLS and hand it to the request handler.

        OSError raised while handling or closing the TLS socket is
        swallowed (the peer may already have closed it).  The TLS socket is
        always closed, including when the handler fails.
        """
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        context = ssl.SSLContext()
        context.load_cert_chain(certfile, keyfile)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            finally:
                # Close on every path; previously a handler OSError skipped
                # close() and leaked the TLS socket.
                ssock.close()
        except OSError:
            # maybe socket has been closed by peer
            pass
126
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700127
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """TCP test server: SilentWSGIServer plus TLS from SSLWSGIServerMixin."""
130
131
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that runs a WSGI test server while it is suspended.

    Builds ``server_ssl_cls`` (if *use_ssl*) or ``server_cls`` bound to
    *address* and sets ``httpd.address`` to its ``server_address``.  It then
    serves in a background thread and yields the server.  When the
    generator is resumed or closed, the server is shut down, closed, and
    the thread joined.  Every request is answered with a plain-text
    "Test message".
    """

    def app(environ, start_response):
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread.
    cls = server_ssl_cls if use_ssl else server_cls
    httpd = cls(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address

    def serve():
        httpd.serve_forever(poll_interval=0.05)

    server_thread = threading.Thread(target=serve)
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
155
156
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer that listens on a UNIX-domain socket path."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # A socket path has no host or port; fill in placeholders for
            # code that expects these attributes.
            self.server_name, self.server_port = '127.0.0.1', 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGIServer over a UNIX-domain socket with per-request timeouts."""

        # Seconds passed to settimeout() on every accepted connection.
        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer_path = super().get_request()
            conn.settimeout(self.request_timeout)
            # Stdlib code expects get_request() to return a socket and a
            # (host, port) tuple, but a UNIX socket's peer address is a
            # path.  Return fake data that is just enough for the tests.
            fake_peer = ('127.0.0.1', '')
            return conn, fake_peer


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer whose error handler does nothing."""

        def handle_error(self, request, client_address):
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer plus TLS from SSLWSGIServerMixin."""


    def gen_unix_socket_path():
        """Return a fresh temporary path name for a UNIX socket."""
        # The temporary file is deleted when the with-block exits, so only
        # its (now unused) name is returned.
        with tempfile.NamedTemporaryFile() as tmp:
            return tmp.name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary socket path, unlinking it afterwards if present."""
        sock_path = gen_unix_socket_path()
        try:
            yield sock_path
        finally:
            with contextlib.suppress(OSError):
                os.unlink(sock_path)


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Context manager running a WSGI test server on a UNIX socket."""
        with unix_socket_path() as sock_path:
            yield from _run_test_server(address=sock_path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
220
221
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running a WSGI test server bound to (host, port).

    The yielded server exposes its bound address as ``address``.  Pass
    use_ssl=True to serve over TLS.
    """
    bind_to = (host, port)
    yield from _run_test_server(address=bind_to, use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
227
228
def make_test_protocol(base):
    """Instantiate a generated subclass of *base* whose attributes are mocks.

    Every name listed by ``dir(base)`` that is not a dunder is replaced by a
    MockCallback returning None.  The generated class, named
    'TestProtocol', has *base* followed by base's direct bases as its
    bases.
    """
    namespace = {
        attr: MockCallback(return_value=None)
        for attr in dir(base)
        # skip magic names
        if not (attr.startswith('__') and attr.endswith('__'))
    }
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, namespace)()
237
238
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations and never reports I/O."""

    def __init__(self):
        # fileobj -> SelectorKey
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Record *fileobj*; the key's fd field is always 0."""
        self.keys[fileobj] = key = selectors.SelectorKey(fileobj, 0, events,
                                                         data)
        return key

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key (KeyError if unknown)."""
        key = self.keys[fileobj]
        del self.keys[fileobj]
        return key

    def select(self, timeout):
        """Report no ready file objects, whatever the timeout."""
        return []

    def get_map(self):
        return self.keys
257
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700258
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests whose clock is driven by a generator.

    The loop never looks at real time.  ``time()`` returns an internal
    counter that only moves through ``advance_time()``.  After each loop
    iteration, every deadline registered through ``call_at()`` since the
    previous iteration is sent to the generator passed to ``__init__``.
    The value the generator yields back is added to the clock.

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value a ``yield`` expression receives is the absolute time of a
    scheduled handle.  The value yielded is how far to move the loop's time
    forward.  When a generator is supplied, it must be finished by the time
    ``close()`` is called; otherwise ``close()`` raises AssertionError.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # The default generator yields only once, and that yield is
            # consumed by the priming next() below.  This loop is therefore
            # meant for tests that schedule no timers, and close() skips the
            # "generator finished" check.
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Prime the generator so it is paused at its first yield, ready to
        # receive deadlines through send().
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        # Deadlines passed to call_at() since the last _run_once().
        self._timers = []
        self._selector = TestSelector()

        # fd -> events.Handle, filled by _add_reader()/_add_writer().
        self.readers = {}
        self.writers = {}
        self.reset_counters()

        # Looked up by fd in _ensure_fd_no_transport(); weak values, so an
        # entry disappears once its transport is garbage-collected.
        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the simulated current time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance* (falsy values are ignored)."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop, then check that a supplied time generator finished.

        Raises:
            AssertionError: if a user-supplied generator is still running.
        """
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def _remove_reader(self, fd):
        """Forget the reader on *fd*; return whether one was registered.

        Every call is counted in ``remove_reader_count[fd]``.
        """
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def _assert_fd_handle(self, handles, fd, callback, args):
        """Assert that ``handles[fd]`` wraps *callback* with *args*."""
        assert fd in handles, 'fd {} is not registered'.format(fd)
        handle = handles[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader registered for callback(*args)."""
        self._assert_fd_handle(self.readers, fd, callback, args)

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def _remove_writer(self, fd):
        """Forget the writer on *fd*; return whether one was registered.

        Every call is counted in ``remove_writer_count[fd]``.
        """
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer registered for callback(*args)."""
        self._assert_fd_handle(self.writers, fd, callback, args)

    def _ensure_fd_no_transport(self, fd):
        """Raise RuntimeError if *fd* is in use by a known transport.

        *fd* may be an int or an object with a ``fileno()`` method; anything
        else raises ValueError.
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counts of _remove_reader()/_remove_writer() calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report each pending deadline to the time generator and advance
        # the clock by whatever it answers.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        # Remember the deadline so _run_once() can report it to the
        # time generator.
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # Nothing to dispatch: TestSelector.select() always returns [].
        return

    def _write_to_self(self):
        # Intentionally a no-op in the test loop.
        pass
Victor Stinnera1254972014-02-11 11:34:30 +0100414
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500415
def MockCallback(**kwargs):
    """Return a ``mock.Mock`` whose spec only allows ``__call__``.

    Keyword arguments (e.g. ``return_value``) are forwarded to Mock.
    """
    allowed = ['__call__']
    return mock.Mock(spec=allowed, **kwargs)
Yury Selivanovff827f02014-02-18 18:02:19 -0500418
419
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    Equality holds when ``re.search`` finds the pattern anywhere in the
    other string.  ``re.S`` is used, so ``.`` also matches newlines.
    ``!=`` is the exact negation of ``==``.  Comparing with a non-str
    returns NotImplemented, so Python falls back to its default result
    instead of raising TypeError.  Instances are unhashable, since a regex
    equality cannot agree with a hash.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str provides its own __ne__, which compares the raw pattern text
        # and would contradict the fuzzy __eq__ above.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    # Defining __eq__ already disables hashing; make that explicit.
    __hash__ = None
Victor Stinner307bccc2014-06-12 18:39:26 +0200431
432
def get_function_source(func):
    """Return ``events._get_function_source(func)``.

    Raises:
        ValueError: if no source information is available (the helper
            returned None).
    """
    found = events._get_function_source(func)
    if found is not None:
        return found
    raise ValueError("unable to get the source of %r" % (func,))
Victor Stinnerc73701d2014-06-18 01:36:32 +0200438
439
class TestCase(unittest.TestCase):
    """Base class for asyncio tests that isolates event-loop global state."""

    @staticmethod
    def close_loop(loop):
        """Shut down *loop*'s default executor (waiting), then close *loop*."""
        pool = loop._default_executor
        if pool is not None:
            pool.shutdown(wait=True)
        loop.close()

    def set_event_loop(self, loop, *, cleanup=True):
        """Clear the global event loop; optionally close *loop* at cleanup."""
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if not cleanup:
            return
        self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def unpatch_get_running_loop(self):
        """Restore the events._get_running_loop saved by setUp()."""
        events._get_running_loop = self._get_running_loop

    def setUp(self):
        # Save the real lookup, then make "no running loop" the answer
        # for the duration of the test.
        self._get_running_loop = events._get_running_loop
        events._get_running_loop = lambda: None
        self._thread_cleanup = support.threading_setup()

    def tearDown(self):
        self.unpatch_get_running_loop()

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

        self.doCleanups()
        support.threading_cleanup(*self._thread_cleanup)
        support.reap_children()
480
Victor Stinner1cae9ec2014-07-14 22:26:34 +0200481
@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    Raises the logger's level above CRITICAL for the duration of the
    ``with`` block (for example, to ignore warnings in debug mode).  The
    previous level is restored afterwards, even if the block raises.
    """
    previous = logger.level
    silenced = logging.CRITICAL + 1
    try:
        logger.setLevel(silenced)
        yield
    finally:
        logger.setLevel(previous)
Victor Stinnerb2614752014-08-25 23:20:52 +0200494
Yury Selivanovd5c2a622015-12-16 19:31:17 -0500495
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    Returns a ``MagicMock`` specced on ``socket.socket``.  *proto*, *type*
    and *family* are set as plain attributes, and ``gettimeout()`` is
    configured to return 0.0.
    """
    fake = mock.MagicMock(socket.socket)
    fake.configure_mock(proto=proto, type=type, family=family)
    fake.gettimeout.return_value = 0.0
    return fake