blob: f797b2f0bed4812a68c2907d63b1a7fe6cab7686 [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Victor Stinner1cae9ec2014-07-14 22:26:34 +02006import logging
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07007import os
Yury Selivanovff827f02014-02-18 18:02:19 -05008import re
Victor Stinner4271dfd2017-11-28 15:19:56 +01009import selectors
Yury Selivanov88a5bf02014-02-18 12:15:06 -050010import socket
11import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070012import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050013import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070014import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020015import time
Victor Stinnerc73701d2014-06-18 01:36:32 +020016import unittest
Yury Selivanov5b8d4f92016-10-05 17:48:59 -040017import weakref
18
Victor Stinner24ba2032014-02-26 10:25:02 +010019from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050020
21from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010022from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050023
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070024try:
25 import ssl
26except ImportError: # pragma: no cover
27 ssl = None
28
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070029from . import base_events
30from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010031from . import futures
Victor Stinnere6a53792014-03-06 01:00:36 +010032from . import tasks
Victor Stinnerf951d282014-06-29 00:46:45 +020033from .coroutines import coroutine
Victor Stinner1cae9ec2014-07-14 22:26:34 +020034from .log import logger
Victor Stinnerb9030672017-06-30 11:12:33 +020035from test import support
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070036
37
38if sys.platform == 'win32': # pragma: no cover
39 from .windows_utils import socketpair
40else:
41 from socket import socketpair # pragma: no cover
42
43
def dummy_ssl_context():
    """Return a new, otherwise unconfigured ``ssl.SSLContext(PROTOCOL_TLS)``.

    When the ``ssl`` module could not be imported (the module-level ``ssl``
    name is then None), None is returned instead.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_TLS)
    return None
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070049
50
def run_briefly(loop):
    """Run *loop* until one trivial (no-op) coroutine task has finished.

    The coroutine object is always closed afterwards, even if running the
    loop raised.
    """
    @coroutine
    def once():
        pass

    coro_obj = once()
    noop_task = loop.create_task(coro_obj)
    # The task may legitimately still be pending after run_until_complete()
    # (the loop was stopped, or the task raised a BaseException); suppress
    # the "destroyed but pending" warning in that case.
    noop_task._log_destroy_pending = False
    try:
        loop.run_until_complete(noop_task)
    finally:
        coro_obj.close()
64
65
def run_until(loop, pred, timeout=30):
    """Step *loop* repeatedly until ``pred()`` returns a true value.

    Each step runs the loop until a 1 ms ``tasks.sleep`` completes, and
    ``pred`` is polled before every step.

    Args:
        loop: the event loop to drive.
        pred: zero-argument callable checked before each step.
        timeout: maximum number of seconds to wait, or None to wait
            indefinitely.

    Raises:
        futures.TimeoutError: if *timeout* seconds elapse before ``pred()``
            becomes true.
    """
    # Only compute a deadline when a timeout was given: ``now + None``
    # would raise TypeError.  A monotonic clock is immune to wall-clock
    # adjustments made while the test is running.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and time.monotonic() >= deadline:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020074
75
def run_once(loop):
    """Drive *loop* through a single iteration (legacy test helper).

    A call to ``loop.stop`` is scheduled before entering
    ``loop.run_forever()``, so the loop polls its selector once, runs the
    callbacks scheduled in response to I/O events, and then returns.
    """
    stop_callback = loop.stop
    loop.call_soon(stop_callback)
    loop.run_forever()
85
86
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that produces no console output."""

    def log_message(self, format, *args):
        # Drop request log lines entirely.
        return None

    def get_stderr(self):
        # A fresh in-memory buffer: whatever is written to it is discarded.
        return io.StringIO()
94
95
class SilentWSGIServer(WSGIServer):
    """WSGI server that times out idle connections and hides handler errors."""

    # Timeout (seconds) applied to every accepted connection, so blocking
    # operations on it fail instead of waiting indefinitely.
    request_timeout = 2

    def handle_error(self, request, client_address):
        # Deliberately ignore errors raised while handling a request.
        return None

    def get_request(self):
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer
107
108
class SSLWSGIServerMixin:
    """Server mixin that wraps each accepted connection in server-side TLS.

    ``finish_request`` loads the test key/certificate pair, performs the TLS
    handshake on the raw connection and passes the resulting TLS socket to
    ``self.RequestHandlerClass``.
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # Name the server-side protocol explicitly: SSLContext() without a
        # protocol argument is deprecated.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(certfile, keyfile)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            # ``with`` closes the TLS socket on every path; previously it
            # was only closed when the handler returned normally.
            with ssock:
                self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
132
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700133
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent TCP WSGI test server whose connections are wrapped in TLS."""
136
137
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator serving a fixed WSGI app on *address* from a worker thread.

    Yields the running server once; the server also has an ``address``
    attribute mirroring ``server_address``.  When the generator is resumed
    or closed, the server is shut down, the worker thread joined and the
    listening socket closed.  Intended to be driven with ``yield from``
    inside a ``contextlib.contextmanager`` function.

    Args:
        address: address passed to the server class constructor.
        use_ssl: use *server_ssl_cls* instead of *server_cls*.
        server_cls: server class for plain connections.
        server_ssl_cls: server class for TLS connections.
    """

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    try:
        httpd.set_app(app)
        httpd.address = httpd.server_address
        server_thread = threading.Thread(
            target=lambda: httpd.serve_forever(poll_interval=0.05))
        server_thread.start()
        try:
            yield httpd
        finally:
            # shutdown() waits for serve_forever() to stop, so it must only
            # be called once the serving thread has actually been started.
            httpd.shutdown()
            server_thread.join()
    finally:
        # Release the listening socket even if setup failed before the
        # serving thread could be started.
        httpd.server_close()
161
162
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server bound to a UNIX domain socket path."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # Placeholder host/port values; a UNIX socket address has
            # neither (presumably needed by HTTPServer/WSGIServer code).
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server bound to a UNIX domain socket path."""

        # Timeout (seconds) applied to every accepted connection.
        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _unix_peer = super().get_request()
            conn.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return conn, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UNIX-socket WSGI server that ignores request-handling errors."""

        def handle_error(self, request, client_address):
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX-socket WSGI server whose connections use TLS."""


    def gen_unix_socket_path():
        """Return a unique temporary path that does not currently exist.

        A NamedTemporaryFile is created only to reserve a unique name; it
        is deleted again when closed, before the name is returned.
        NOTE(review): the name could in principle be taken by someone else
        between deletion and use.
        """
        tmp = tempfile.NamedTemporaryFile()
        with tmp:
            socket_name = tmp.name
        return socket_name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary socket path, unlinking it on exit if present."""
        sock_path = gen_unix_socket_path()
        try:
            yield sock_path
        finally:
            try:
                os.unlink(sock_path)
            except OSError:
                # Nothing was ever bound there, or it is already gone.
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Serve the test WSGI app on a temporary UNIX socket path."""
        with unix_socket_path() as sock_path:
            yield from _run_test_server(address=sock_path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
226
227
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Serve the test WSGI app over TCP on (*host*, *port*).

    With the default ``port=0`` the OS picks a free port; the bound address
    is exposed as ``address`` on the yielded server.
    """
    tcp_address = (host, port)
    yield from _run_test_server(address=tcp_address, use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
233
234
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose methods are all mocks.

    Every name reported by ``dir(base)``, except dunder names, is overridden
    with ``MockCallback(return_value=None)``.  The generated class derives
    from *base* followed by *base*'s own direct bases.
    """
    def is_dunder(attr):
        return attr.startswith('__') and attr.endswith('__')

    overrides = {attr: MockCallback(return_value=None)
                 for attr in dir(base) if not is_dunder(attr)}
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, overrides)()
243
244
class TestSelector(selectors.BaseSelector):
    """In-memory selector that never reports any ready events.

    Registrations are stored in ``self.keys``, a dict keyed by file object.
    Every key is created with fd 0.
    """

    def __init__(self):
        self.keys = {}

    def register(self, fileobj, events, data=None):
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        # dict.pop raises KeyError for an unregistered file object.
        return self.keys.pop(fileobj)

    def select(self, timeout=None):
        # ``timeout`` defaults to None to match BaseSelector.select(); it is
        # ignored because this selector never has anything ready.
        return []

    def get_map(self):
        return self.keys
263
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700264
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests with a manually controlled clock.

    The loop's time starts at 0 and only moves when the generator passed to
    ``__init__`` says so.  Each ``call_at()`` records its deadline; after
    every loop iteration those deadlines are sent into the generator one by
    one, and each value the generator yields back is added to the loop's
    time.

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of the next
    scheduled handler; the value yielded is how far to move the loop's
    time forward.

    When a generator is supplied, ``close()`` sends it one final value and
    raises AssertionError if the generator has not finished by then.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # No time script: use a trivial generator and skip the
            # "generator finished" check in close().
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Prime the generator so it is paused at its first ``yield`` and
        # ready to receive deadlines via send().
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the loop's simulated time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance*; a falsy value is a no-op."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and verify that the time generator has finished.

        Raises:
            AssertionError: a generator was supplied and it still yields
                after receiving a final ``0``.
        """
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _add_reader(self, fd, callback, *args):
        # Record the handle instead of registering fd with a real selector.
        self.readers[fd] = events.Handle(callback, args, self)

    def _remove_reader(self, fd):
        # Count every removal attempt, whether or not fd was registered.
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader registered with *callback*(*args)."""
        assert fd in self.readers, 'fd {} is not registered'.format(fd)
        handle = self.readers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def _add_writer(self, fd, callback, *args):
        # Record the handle instead of registering fd with a real selector.
        self.writers[fd] = events.Handle(callback, args, self)

    def _remove_writer(self, fd):
        # Count every removal attempt, whether or not fd was registered.
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer registered with *callback*(*args)."""
        assert fd in self.writers, 'fd {} is not registered'.format(fd)
        handle = self.writers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def _ensure_fd_no_transport(self, fd):
        """Raise if *fd* (an int or an object with ``fileno()``) is owned by
        a transport registered in ``self._transports``.

        Raises:
            ValueError: *fd* is not an int and has no usable ``fileno()``.
            RuntimeError: the file descriptor is used by a transport.
        """
        if not isinstance(fd, int):
            try:
                fd = int(fd.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError("Invalid file object: "
                                 "{!r}".format(fd)) from None
        try:
            transport = self._transports[fd]
        except KeyError:
            pass
        else:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback, refusing fds owned by a transport."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback, refusing fds owned by a transport."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback, refusing fds owned by a transport."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback, refusing fds owned by a transport."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counters of remove_reader/remove_writer calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed every deadline recorded by call_at() since the previous
        # iteration into the time generator and apply its replies.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        # Remember the deadline for the time generator, then schedule as usual.
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # Nothing to do: TestSelector.select() always returns an empty list.
        return

    def _write_to_self(self):
        # Intentionally a no-op in the test loop.
        pass
419 pass
Victor Stinnera1254972014-02-11 11:34:30 +0100420
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500421
def MockCallback(**kwargs):
    """Return a ``mock.Mock`` whose spec only allows it to be called.

    Attributes outside the spec (other than Mock's own API such as
    ``call_count``) raise AttributeError.  Keyword arguments, e.g.
    ``return_value``, are forwarded to ``mock.Mock``.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
Yury Selivanovff827f02014-02-18 18:02:19 -0500424
425
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    ``self == other`` is true when ``re.search(self, other, re.S)`` finds a
    match, and ``!=`` is its exact negation.  Comparisons with non-str
    objects are delegated back to Python, so ``==`` ends up False instead
    of raising.  Instances are unhashable, since no hash can agree with a
    fuzzy equality.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            # re.search() would raise TypeError on a non-str subject.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which would compare the raw pattern
        # text and contradict the fuzzy __eq__ above.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # Defining __eq__ already disables hashing; state it explicitly.
    __hash__ = None
Victor Stinner307bccc2014-06-12 18:39:26 +0200437
438
def get_function_source(func):
    """Return what ``events._get_function_source(func)`` reports for *func*.

    Presumably a (filename, line number) pair; confirm against
    ``events._get_function_source``.

    Raises:
        ValueError: if no source information is available (None returned).
    """
    location = events._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
Victor Stinnerc73701d2014-06-18 01:36:32 +0200444
445
class TestCase(unittest.TestCase):
    """Base class for the asyncio tests.

    ``setUp`` replaces ``events._get_running_loop`` with a stub returning
    None and takes a ``support.threading_setup()`` snapshot.  ``tearDown``
    restores the stub, sets the event loop to None, checks that no
    exception is still being handled, runs cleanups, and calls
    ``support.threading_cleanup`` and ``support.reap_children``.
    """

    def setUp(self):
        # Save the real lookup so unpatch_get_running_loop() can restore it.
        self._get_running_loop = events._get_running_loop
        events._get_running_loop = lambda: None
        self._thread_cleanup = support.threading_setup()

    def tearDown(self):
        self.unpatch_get_running_loop()

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

        self.doCleanups()
        support.threading_cleanup(*self._thread_cleanup)
        support.reap_children()

    def unpatch_get_running_loop(self):
        """Put back the ``events._get_running_loop`` saved by setUp()."""
        events._get_running_loop = self._get_running_loop

    @staticmethod
    def close_loop(loop):
        """Shut down *loop*'s default executor (waiting), then close *loop*."""
        pool = loop._default_executor
        if pool is not None:
            pool.shutdown(wait=True)
        loop.close()

    def set_event_loop(self, loop, *, cleanup=True):
        """Register *loop* for this test; the global loop is set to None.

        The global event loop is deliberately *not* set to *loop*, so code
        under test must receive the loop explicitly.  With *cleanup* true,
        the loop is closed via close_loop() during cleanup.
        """
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(self.close_loop, loop)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop
486
Victor Stinner1cae9ec2014-07-14 22:26:34 +0200487
@contextlib.contextmanager
def disable_logger():
    """Silence the asyncio ``logger`` for the duration of the ``with`` block.

    Useful, for instance, to ignore warnings emitted in debug mode.  The
    previous level is restored on exit, even if the block raises.
    """
    saved_level = logger.level
    try:
        # One above CRITICAL, the highest standard level: no record passes.
        logger.setLevel(logging.CRITICAL + 1)
        yield
    finally:
        logger.setLevel(saved_level)
Victor Stinnerb2614752014-08-25 23:20:52 +0200500
Yury Selivanovd5c2a622015-12-16 19:31:17 -0500501
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Return a ``MagicMock`` specced on ``socket.socket`` that looks
    non-blocking.

    ``gettimeout()`` returns 0.0, and the ``family``, ``type`` and ``proto``
    attributes are set from the arguments.
    """
    fake_sock = mock.MagicMock(socket.socket)
    # A timeout of 0.0 is how a non-blocking socket reports itself.
    fake_sock.gettimeout.return_value = 0.0
    fake_sock.family = family
    fake_sock.type = type
    fake_sock.proto = proto
    return fake_sock