blob: d273b08e8fac597558ca03fcf6f879687960ffbe [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Victor Stinner1cae9ec2014-07-14 22:26:34 +02006import logging
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07007import os
Yury Selivanovff827f02014-02-18 18:02:19 -05008import re
Yury Selivanov88a5bf02014-02-18 12:15:06 -05009import socket
10import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070011import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050012import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070013import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020014import time
Victor Stinnerc73701d2014-06-18 01:36:32 +020015import unittest
Yury Selivanov5b8d4f92016-10-05 17:48:59 -040016import weakref
17
Victor Stinner24ba2032014-02-26 10:25:02 +010018from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050019
20from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010021from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050022
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070023try:
24 import ssl
25except ImportError: # pragma: no cover
26 ssl = None
27
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070028from . import base_events
29from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010030from . import futures
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070031from . import selectors
Victor Stinnere6a53792014-03-06 01:00:36 +010032from . import tasks
Victor Stinnerf951d282014-06-29 00:46:45 +020033from .coroutines import coroutine
Victor Stinner1cae9ec2014-07-14 22:26:34 +020034from .log import logger
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070035
36
37if sys.platform == 'win32': # pragma: no cover
38 from .windows_utils import socketpair
39else:
40 from socket import socketpair # pragma: no cover
41
42
def dummy_ssl_context():
    """Return a fresh SSLContext, or None if the ssl module is unavailable.

    The context is created with ``ssl.PROTOCOL_SSLv23`` and no further
    configuration (no certificates are loaded).
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
48
49
def run_briefly(loop):
    """Let *loop* make a small amount of progress.

    A task wrapping a no-op coroutine is created on *loop* and run until it
    completes.  Presumably this also gives the loop a chance to run callbacks
    that were already scheduled.  The coroutine object is always closed
    afterwards, even if running the loop raised.
    """
    @coroutine
    def once():
        pass

    coro_obj = once()
    task = loop.create_task(coro_obj)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro_obj.close()
63
64
def run_until(loop, pred, timeout=30):
    """Drive *loop* in short steps until ``pred()`` returns a true value.

    Before each step *pred* is polled.  Each step runs a 0.001 second
    ``tasks.sleep`` on *loop*.

    :param loop: the event loop to drive.
    :param pred: zero-argument callable; the function returns as soon as it
        returns a true value.
    :param timeout: maximum number of seconds to wait, or None to wait
        without limit.
    :raises futures.TimeoutError: if *pred* is still false once *timeout*
        seconds have elapsed.
    """
    # Only compute a deadline when a timeout was requested: adding None to
    # a float raises TypeError.  A monotonic clock is immune to wall-clock
    # adjustments, so the deadline cannot jump while we wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020073
74
def run_once(loop):
    """Run a single pass through *loop* and return.

    ``loop.stop`` is scheduled with ``call_soon()`` and then
    ``run_forever()`` is entered, so the loop stops once that callback has
    been processed.  This is a legacy API, but it remains the recommended
    pattern for test code: it polls the selector once and runs all callbacks
    scheduled in response to I/O events.
    """
    stop_callback = loop.stop
    loop.call_soon(stop_callback)
    loop.run_forever()
84
85
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler whose stderr and access log go nowhere."""

    def get_stderr(self):
        """Return a new in-memory text stream to act as the handler's stderr."""
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        """Discard the log message."""
        return None
93
94
class SilentWSGIServer(WSGIServer):
    """WSGI server for tests.

    Each accepted request socket gets a ``request_timeout``-second timeout,
    and errors raised while handling a request are ignored.
    """

    request_timeout = 2

    def get_request(self):
        """Accept a connection and apply ``request_timeout`` to its socket."""
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        """Ignore errors raised while handling *request*."""
        return None
106
107
class SSLWSGIServerMixin:
    """Mixin that serves each accepted request over server-side TLS.

    Meant to be combined with a socketserver-style server class that
    provides ``RequestHandlerClass``.
    """

    def finish_request(self, request, client_address):
        """Wrap *request* in TLS and pass it to ``RequestHandlerClass``.

        The key and certificate are loaded from ``ssl_key.pem`` and
        ``ssl_cert.pem`` in the test directory.  The TLS socket is always
        closed when handling ends, even if the handler raises.  ``OSError``
        raised while handling the request or closing the socket is ignored.
        """
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # Name the server-side protocol explicitly; constructing
        # SSLContext() without a protocol argument is deprecated.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(certfile, keyfile)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            # The ``with`` block closes the TLS socket whether the handler
            # returns normally or raises, so it is never leaked.
            with ssock:
                self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
131
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700132
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent TCP WSGI test server that speaks TLS on every request."""
135
136
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that runs a WSGI test server on a background thread.

    The server is built from ``server_ssl_cls`` when *use_ssl* is true,
    otherwise from ``server_cls``.  It is bound to *address* and serves a
    trivial app that answers every request with ``b'Test message'``.  The
    generator yields the running server once.  When resumed or closed, it
    shuts the server down, closes it and joins the serving thread.
    """

    def app(environ, start_response):
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    if use_ssl:
        server_class = server_ssl_cls
    else:
        server_class = server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address

    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
160
161
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer variant that listens on a UNIX domain socket."""

        def server_bind(self):
            # Bind through UnixStreamServer directly: the address is a
            # filesystem path, not a (host, port) pair.  Placeholder
            # server_name/server_port values are then filled in (presumably
            # for HTTP/WSGI code that expects them to exist).
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server on a UNIX socket with a per-request socket timeout."""

        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _unix_peer = super().get_request()
            conn.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            fake_peer = ('127.0.0.1', '')
            return conn, fake_peer


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UNIX-socket WSGI server that ignores request-handling errors."""

        def handle_error(self, request, client_address):
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX-socket WSGI server that speaks TLS on every request."""


    def gen_unix_socket_path():
        """Return the name of a temporary file that has already been removed."""
        with tempfile.NamedTemporaryFile() as tmp:
            candidate = tmp.name
        return candidate


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary socket path and unlink it afterwards, best effort."""
        socket_path = gen_unix_socket_path()
        try:
            yield socket_path
        finally:
            try:
                os.unlink(socket_path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run a WSGI test server on a temporary UNIX socket for a with block."""
        with unix_socket_path() as socket_path:
            yield from _run_test_server(address=socket_path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
225
226
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run a TCP WSGI test server for the duration of a ``with`` block.

    Yields the running server; it is bound to ``(host, port)`` and uses TLS
    when *use_ssl* is true.
    """
    server_gen = _run_test_server(address=(host, port), use_ssl=use_ssl,
                                  server_cls=SilentWSGIServer,
                                  server_ssl_cls=SSLWSGIServer)
    yield from server_gen
232
233
def make_test_protocol(base):
    """Return an instance of a *base* subclass whose attributes are mocks.

    Every non-dunder name visible through ``dir(base)`` is overridden with a
    ``MockCallback(return_value=None)``.
    """
    mocked_attrs = {
        name: MockCallback(return_value=None)
        for name in dir(base)
        # skip magic names
        if not (name.startswith('__') and name.endswith('__'))
    }
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, mocked_attrs)()
242
243
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations and never reports events."""

    def __init__(self):
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Record *fileobj* and return its key (the key's fd is always 0)."""
        new_key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = new_key
        return new_key

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key; KeyError if it is unknown."""
        return self.keys.pop(fileobj)

    def select(self, timeout):
        """Report no ready file objects, whatever *timeout* is."""
        return []

    def get_map(self):
        """Return the dict of registered file objects to their keys."""
        return self.keys
262
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700263
264class TestLoop(base_events.BaseEventLoop):
265 """Loop for unittests.
266
267 It manages self time directly.
268 If something scheduled to be executed later then
269 on next loop iteration after all ready handlers done
270 generator passed to __init__ is calling.
271
272 Generator should be like this:
273
274 def gen():
275 ...
276 when = yield ...
277 ... = yield time_advance
278
Yury Selivanovb0b0e622014-02-18 22:27:48 -0500279 Value returned by yield is absolute time of next scheduled handler.
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700280 Value passed to yield is time advance to move loop's time forward.
281 """
282
283 def __init__(self, gen=None):
284 super().__init__()
285
286 if gen is None:
287 def gen():
288 yield
289 self._check_on_close = False
290 else:
291 self._check_on_close = True
292
293 self._gen = gen()
294 next(self._gen)
295 self._time = 0
Victor Stinner06847d92014-02-11 09:03:47 +0100296 self._clock_resolution = 1e-9
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700297 self._timers = []
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700298 self._selector = TestSelector()
299
300 self.readers = {}
301 self.writers = {}
302 self.reset_counters()
303
Yury Selivanov5b8d4f92016-10-05 17:48:59 -0400304 self._transports = weakref.WeakValueDictionary()
305
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700306 def time(self):
307 return self._time
308
309 def advance_time(self, advance):
310 """Move test time forward."""
311 if advance:
312 self._time += advance
313
314 def close(self):
Victor Stinner29ad0112015-01-15 00:04:21 +0100315 super().close()
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700316 if self._check_on_close:
317 try:
318 self._gen.send(0)
319 except StopIteration:
320 pass
321 else: # pragma: no cover
322 raise AssertionError("Time generator is not finished")
323
Yury Selivanov5b8d4f92016-10-05 17:48:59 -0400324 def _add_reader(self, fd, callback, *args):
Yury Selivanovff827f02014-02-18 18:02:19 -0500325 self.readers[fd] = events.Handle(callback, args, self)
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700326
Yury Selivanov5b8d4f92016-10-05 17:48:59 -0400327 def _remove_reader(self, fd):
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700328 self.remove_reader_count[fd] += 1
329 if fd in self.readers:
330 del self.readers[fd]
331 return True
332 else:
333 return False
334
335 def assert_reader(self, fd, callback, *args):
336 assert fd in self.readers, 'fd {} is not registered'.format(fd)
337 handle = self.readers[fd]
338 assert handle._callback == callback, '{!r} != {!r}'.format(
339 handle._callback, callback)
340 assert handle._args == args, '{!r} != {!r}'.format(
341 handle._args, args)
342
Yury Selivanov5b8d4f92016-10-05 17:48:59 -0400343 def _add_writer(self, fd, callback, *args):
Yury Selivanovff827f02014-02-18 18:02:19 -0500344 self.writers[fd] = events.Handle(callback, args, self)
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700345
Yury Selivanov5b8d4f92016-10-05 17:48:59 -0400346 def _remove_writer(self, fd):
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700347 self.remove_writer_count[fd] += 1
348 if fd in self.writers:
349 del self.writers[fd]
350 return True
351 else:
352 return False
353
354 def assert_writer(self, fd, callback, *args):
355 assert fd in self.writers, 'fd {} is not registered'.format(fd)
356 handle = self.writers[fd]
357 assert handle._callback == callback, '{!r} != {!r}'.format(
358 handle._callback, callback)
359 assert handle._args == args, '{!r} != {!r}'.format(
360 handle._args, args)
361
Yury Selivanov5b8d4f92016-10-05 17:48:59 -0400362 def _ensure_fd_no_transport(self, fd):
363 try:
364 transport = self._transports[fd]
365 except KeyError:
366 pass
367 else:
368 raise RuntimeError(
369 'File descriptor {!r} is used by transport {!r}'.format(
370 fd, transport))
371
372 def add_reader(self, fd, callback, *args):
373 """Add a reader callback."""
374 self._ensure_fd_no_transport(fd)
375 return self._add_reader(fd, callback, *args)
376
377 def remove_reader(self, fd):
378 """Remove a reader callback."""
379 self._ensure_fd_no_transport(fd)
380 return self._remove_reader(fd)
381
382 def add_writer(self, fd, callback, *args):
383 """Add a writer callback.."""
384 self._ensure_fd_no_transport(fd)
385 return self._add_writer(fd, callback, *args)
386
387 def remove_writer(self, fd):
388 """Remove a writer callback."""
389 self._ensure_fd_no_transport(fd)
390 return self._remove_writer(fd)
391
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700392 def reset_counters(self):
393 self.remove_reader_count = collections.defaultdict(int)
394 self.remove_writer_count = collections.defaultdict(int)
395
396 def _run_once(self):
397 super()._run_once()
398 for when in self._timers:
399 advance = self._gen.send(when)
400 self.advance_time(advance)
401 self._timers = []
402
403 def call_at(self, when, callback, *args):
404 self._timers.append(when)
405 return super().call_at(when, callback, *args)
406
407 def _process_events(self, event_list):
408 return
409
410 def _write_to_self(self):
411 pass
Victor Stinnera1254972014-02-11 11:34:30 +0100412
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500413
def MockCallback(**kwargs):
    """Return a Mock restricted to being called; *kwargs* go to mock.Mock."""
    allowed_attributes = ['__call__']
    return mock.Mock(spec=allowed_attributes, **kwargs)
Yury Selivanovff827f02014-02-18 18:02:19 -0500416
417
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern is matched with ``re.search`` (anywhere in the string) and
    ``re.S``, so ``.`` also matches newlines.  Comparisons with non-str
    objects return NotImplemented.  Python then falls back to its default
    comparison (False for ``==``) instead of raising TypeError.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which would otherwise be inherited and
        # compare the literal pattern text instead of negating __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # Equality is fuzzy, so hashing by the pattern text would violate the
    # hash/eq contract; keep instances unhashable (as overriding __eq__
    # already implies).
    __hash__ = None
Victor Stinner307bccc2014-06-12 18:39:26 +0200429
430
def get_function_source(func):
    """Return what ``events._get_function_source`` reports for *func*.

    Raises ValueError instead of returning None when no source information
    is available.
    """
    location = events._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
Victor Stinnerc73701d2014-06-18 01:36:32 +0200436
437
class TestCase(unittest.TestCase):
    """Base class for asyncio tests.

    During each test, ``events._get_running_loop`` is replaced with a
    function that always returns None.  No global event loop is left set
    after the test.
    """

    def setUp(self):
        # Keep the real implementation so tearDown() can restore it.
        self._get_running_loop = events._get_running_loop
        events._get_running_loop = lambda: None

    def tearDown(self):
        self.unpatch_get_running_loop()

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

    def unpatch_get_running_loop(self):
        """Restore the ``events._get_running_loop`` saved by setUp()."""
        events._get_running_loop = self._get_running_loop

    def set_event_loop(self, loop, *, cleanup=True):
        """Use *loop* in a test; optionally close it during test cleanup."""
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop
466
Victor Stinner1cae9ec2014-07-14 22:26:34 +0200467
@contextlib.contextmanager
def disable_logger():
    """Silence the asyncio logger for the duration of a ``with`` block.

    The level is raised above CRITICAL and the previous level is restored on
    exit, even if the block raises.  For example, it can be used to ignore
    warnings in debug mode.
    """
    saved_level = logger.level
    silenced_level = logging.CRITICAL + 1
    try:
        logger.setLevel(silenced_level)
        yield
    finally:
        logger.setLevel(saved_level)
Victor Stinnerb2614752014-08-25 23:20:52 +0200480
Yury Selivanovd5c2a622015-12-16 19:31:17 -0500481
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The result is a MagicMock spec'd on socket.socket.  Its ``family``,
    ``type`` and ``proto`` attributes are set from the arguments, and its
    ``gettimeout()`` returns 0.0.
    """
    fake = mock.MagicMock(socket.socket)
    fake.family = family
    fake.type = type
    fake.proto = proto
    fake.gettimeout.return_value = 0.0
    return fake
Victor Stinner231b4042015-01-14 00:19:09 +0100491
492
def force_legacy_ssl_support():
    """Return an unstarted patcher that forces the legacy SSL code path.

    While active, ``asyncio.sslproto._is_sslproto_available`` is replaced by
    a mock that returns False.
    """
    target = 'asyncio.sslproto._is_sslproto_available'
    return mock.patch(target, return_value=False)