blob: 99e3839f456858ade3911dbce9637dfa2bc0499a [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Victor Stinner1cae9ec2014-07-14 22:26:34 +02006import logging
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07007import os
Yury Selivanovff827f02014-02-18 18:02:19 -05008import re
Yury Selivanov88a5bf02014-02-18 12:15:06 -05009import socket
10import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070011import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050012import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070013import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020014import time
Victor Stinnerc73701d2014-06-18 01:36:32 +020015import unittest
Yury Selivanov5b8d4f92016-10-05 17:48:59 -040016import weakref
17
Victor Stinner24ba2032014-02-26 10:25:02 +010018from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050019
20from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010021from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050022
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070023try:
24 import ssl
25except ImportError: # pragma: no cover
26 ssl = None
27
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070028from . import base_events
Yury Selivanov0f3c9762015-11-20 12:57:34 -050029from . import compat
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070030from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010031from . import futures
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070032from . import selectors
Victor Stinnere6a53792014-03-06 01:00:36 +010033from . import tasks
Victor Stinnerf951d282014-06-29 00:46:45 +020034from .coroutines import coroutine
Victor Stinner1cae9ec2014-07-14 22:26:34 +020035from .log import logger
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070036
37
38if sys.platform == 'win32': # pragma: no cover
39 from .windows_utils import socketpair
40else:
41 from socket import socketpair # pragma: no cover
42
43
def dummy_ssl_context():
    """Return an SSL context with no certificates loaded.

    The context is built with ``ssl.PROTOCOL_SSLv23``.  When the ``ssl``
    module could not be imported, None is returned instead.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
49
50
def run_briefly(loop):
    """Give *loop* a short run by completing a trivial coroutine task.

    Running the loop until the no-op task finishes lets callbacks that
    are already scheduled have a chance to execute.
    """
    @coroutine
    def once():
        pass

    coro = once()
    task = loop.create_task(coro)
    # Don't log a warning if the task is not done after
    # run_until_complete(): that happens when the loop is stopped or
    # when a task raises a BaseException.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
64
65
def run_until(loop, pred, timeout=30):
    """Run *loop* in small steps until ``pred()`` returns a true value.

    Between checks of *pred*, the loop is run until a ~1 ms sleep
    completes, so scheduled callbacks get a chance to run.

    :param loop: event loop to drive.
    :param pred: zero-argument callable polled before every step.
    :param timeout: maximum number of seconds to wait, or None to wait
        indefinitely.
    :raises futures.TimeoutError: if *timeout* elapses before *pred*
        becomes true.
    """
    # Compute the deadline only when a timeout was given: ``None`` means
    # "no limit" and must not be added to a float.  A monotonic clock is
    # used so wall-clock adjustments cannot shorten or stretch the wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020074
75
def run_once(loop):
    """Legacy helper that runs a single pass through the event loop.

    ``loop.stop`` is scheduled with ``call_soon`` before ``run_forever()``
    is entered, so the loop polls the selector once, runs the callbacks
    that became ready, and then returns.  This is the recommended pattern
    for test code.
    """
    stop = loop.stop
    loop.call_soon(stop)
    loop.run_forever()
85
86
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards error output and access logs."""

    def get_stderr(self):
        # Give the WSGI machinery a throwaway buffer instead of sys.stderr.
        return io.StringIO()

    def log_message(self, format, *args):
        # Suppress the per-request log line entirely.
        return None
94
95
class SilentWSGIServer(WSGIServer):
    """WSGI server that puts a timeout on every accepted connection and
    ignores request-handling errors instead of reporting them."""

    # Timeout (seconds) applied to each accepted connection socket.
    request_timeout = 2

    def get_request(self):
        conn, addr = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, addr

    def handle_error(self, request, client_address):
        # Deliberately do nothing rather than print the traceback.
        return None
107
108
class SSLWSGIServerMixin:
    """Server mixin that wraps each accepted connection in server-side TLS.

    The key and certificate are loaded from the asyncio test directory
    for every request.
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        context = ssl.SSLContext()
        context.load_cert_chain(certfile, keyfile)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            # Close the TLS socket whether or not the handler fails, so a
            # connection aborted by the peer does not leak the socket.
            with ssock:
                self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
132
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700133
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent TCP WSGI test server whose connections are wrapped in TLS."""
136
137
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator running a WSGI test server on a background thread.

    Builds ``server_ssl_cls`` (when *use_ssl*) or ``server_cls`` bound to
    *address*, serving a fixed plain-text app.  The server is yielded with
    an extra ``address`` attribute copied from ``server_address``.  When the
    generator is resumed or closed, the server is shut down and closed and
    the serving thread is joined.  Intended to be delegated to with
    ``yield from`` inside a context manager.
    """
    def app(environ, start_response):
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Serve from a separate thread so the server does not interfere
    # with event handling in the main thread.
    factory = server_ssl_cls if use_ssl else server_cls
    httpd = factory(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address

    def serve():
        httpd.serve_forever(poll_interval=0.05)

    worker = threading.Thread(target=serve)
    worker.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        worker.join()
161
162
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server bound to a UNIX-domain stream socket."""

        def server_bind(self):
            # Bind the UNIX socket, then supply the host/port attributes
            # that HTTP code expects, using fixed placeholder values.
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name, self.server_port = '127.0.0.1', 80

    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server on a UNIX socket with a per-connection timeout."""

        # Timeout (seconds) applied to each accepted connection socket.
        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer = super().get_request()
            conn.settimeout(self.request_timeout)
            # Code in the stdlib expects get_request() to return a socket
            # and a (host, port) tuple, but for UNIX sockets the second
            # value is a path.  Return fake data that is sufficient to get
            # the tests going.
            return conn, ('127.0.0.1', '')

    class SilentUnixWSGIServer(UnixWSGIServer):
        """UNIX-socket WSGI server that ignores request-handling errors."""

        def handle_error(self, request, client_address):
            return None

    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """TLS-wrapping variant of SilentUnixWSGIServer."""

    def gen_unix_socket_path():
        """Return a fresh temporary file name to use as a socket path.

        The temporary file is removed again before the name is returned,
        so the path is free for a socket to be bound there.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            name = tmp.name
        return name

    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary socket path and unlink it afterwards."""
        sock_path = gen_unix_socket_path()
        try:
            yield sock_path
        finally:
            try:
                os.unlink(sock_path)
            except OSError:
                # Nothing to remove (never bound or already deleted).
                pass

    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run a silent WSGI test server on a temporary UNIX socket."""
        with unix_socket_path() as sock_path:
            yield from _run_test_server(
                address=sock_path, use_ssl=use_ssl,
                server_cls=SilentUnixWSGIServer,
                server_ssl_cls=UnixSSLWSGIServer)
226
227
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run a silent WSGI test server bound to ``(host, port)``.

    With the default ``port=0`` the OS picks a free port; the bound
    address is available as the yielded server's ``address`` attribute.
    """
    tcp_address = (host, port)
    yield from _run_test_server(address=tcp_address, use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
233
234
def make_test_protocol(base):
    """Return an instance of a subclass of *base* with every method mocked.

    Each name in ``dir(base)`` that is not a dunder name is replaced with a
    ``MockCallback`` returning None.  The generated class is named
    ``TestProtocol`` and derives from *base* followed by *base*'s bases.
    """
    def is_dunder(name):
        return name.startswith('__') and name.endswith('__')

    namespace = {name: MockCallback(return_value=None)
                 for name in dir(base) if not is_dunder(name)}
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, namespace)()
243
244
class TestSelector(selectors.BaseSelector):
    """In-memory selector stub that records registrations and never
    reports any ready events."""

    def __init__(self):
        # Registered file object -> SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # The fd field is always 0: no real descriptor is ever looked up.
        self.keys[fileobj] = new_key = selectors.SelectorKey(
            fileobj, 0, events, data)
        return new_key

    def unregister(self, fileobj):
        removed = self.keys[fileobj]
        del self.keys[fileobj]
        return removed

    def select(self, timeout):
        # Nothing is ever ready.
        return []

    def get_map(self):
        return self.keys
263
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700264
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests with a clock driven by a generator.

    ``time()`` returns an internal counter that only moves when
    ``advance_time()`` is called.  Every deadline passed to ``call_at()``
    is recorded; after each loop iteration the recorded deadlines are sent,
    in order, into the time generator given to ``__init__``, the value it
    yields back is used to advance the clock, and the record is cleared.

    A time generator looks like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of the next
    scheduled handler; the value yielded is how far to move the loop's
    clock forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a caller-supplied generator must be exhausted by close().
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        next(self._gen)  # advance the generator to its first yield
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance* (no-op when falsy)."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and, if a generator was supplied, check that
        it has run to completion."""
        super().close()
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError("Time generator is not finished")  # pragma: no cover

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def _remove_reader(self, fd):
        # Every removal attempt is counted, even for unknown fds.
        self.remove_reader_count[fd] += 1
        try:
            del self.readers[fd]
        except KeyError:
            return False
        return True

    @staticmethod
    def _assert_handle_registered(registry, fd, callback, args):
        # Shared check for assert_reader()/assert_writer().
        assert fd in registry, 'fd {} is not registered'.format(fd)
        handle = registry[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def assert_reader(self, fd, callback, *args):
        self._assert_handle_registered(self.readers, fd, callback, args)

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def _remove_writer(self, fd):
        # Every removal attempt is counted, even for unknown fds.
        self.remove_writer_count[fd] += 1
        try:
            del self.writers[fd]
        except KeyError:
            return False
        return True

    def assert_writer(self, fd, callback, *args):
        self._assert_handle_registered(self.writers, fd, callback, args)

    def _ensure_fd_no_transport(self, fd):
        transport = self._transports.get(fd)
        if transport is not None:
            raise RuntimeError(
                'File descriptor {!r} is used by transport {!r}'.format(
                    fd, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        for when in self._timers:
            self.advance_time(self._gen.send(when))
        self._timers = []

    def call_at(self, when, callback, *args):
        # Remember the deadline so _run_once() can feed it to the generator.
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        pass

    def _write_to_self(self):
        pass
412 pass
Victor Stinnera1254972014-02-11 11:34:30 +0100413
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500414
def MockCallback(**kwargs):
    """Return a ``mock.Mock`` whose spec only allows it to be called.

    Keyword arguments (for example ``return_value``) are forwarded to
    ``mock.Mock``.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
Yury Selivanovff827f02014-02-18 18:02:19 -0500417
418
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    The pattern is searched for anywhere in the other string, with
    ``re.S`` so that ``.`` also matches newlines.  Comparing with a
    non-string object is unequal rather than an error.  Instances are
    unhashable because this equality is incompatible with ``str.__hash__``.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            # re.search() would raise TypeError on non-str input; let
            # Python fall back to its default (identity) comparison.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which would otherwise be inherited
        # and compare the literal pattern text, contradicting __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # Defining __eq__ already disables hashing; state it explicitly.
    __hash__ = None
Victor Stinner307bccc2014-06-12 18:39:26 +0200430
431
def get_function_source(func):
    """Return what ``events._get_function_source`` reports for *func*.

    Presumably a ``(filename, lineno)`` pair; confirm against
    ``events._get_function_source``.

    :raises ValueError: if no source information is available.
    """
    location = events._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
Victor Stinnerc73701d2014-06-18 01:36:32 +0200437
438
class TestCase(unittest.TestCase):
    """Base test case for asyncio tests.

    Hides any running loop during the test, resets the global event loop
    afterwards, and checks that no exception state leaked out of the test.
    """

    def set_event_loop(self, loop, *, cleanup=True):
        """Prepare *loop* for use in the test without installing it.

        The global event loop is cleared so that code under test must
        receive *loop* explicitly.  When *cleanup* is true, *loop* is
        closed at test cleanup.
        """
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def setUp(self):
        # Make events._get_running_loop() report no running loop for the
        # duration of the test; the original is restored in tearDown().
        def no_running_loop():
            return None

        self._get_running_loop = events._get_running_loop
        events._get_running_loop = no_running_loop

    def tearDown(self):
        events._get_running_loop = self._get_running_loop

        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        self.assertEqual(sys.exc_info(), (None, None, None))

    if not compat.PY34:
        # Python 3.3 compatibility: provide a no-op subTest().
        def subTest(self, *args, **kwargs):
            class _NullContext:
                def __enter__(self):
                    pass

                def __exit__(self, *exc):
                    pass

            return _NullContext()
474
Victor Stinner1cae9ec2014-07-14 22:26:34 +0200475
@contextlib.contextmanager
def disable_logger():
    """Context manager that silences the asyncio logger.

    While active, the logger level is set just above CRITICAL, so the
    logger does not emit records.  The previous level is restored on exit,
    even if the body raises.  It can be used, for example, to ignore
    warnings in debug mode.
    """
    saved_level = logger.level
    silent_level = logging.CRITICAL + 1
    try:
        logger.setLevel(silent_level)
        yield
    finally:
        logger.setLevel(saved_level)
Victor Stinnerb2614752014-08-25 23:20:52 +0200488
Yury Selivanovd5c2a622015-12-16 19:31:17 -0500489
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The mock is spec'd on ``socket.socket`` and carries the given
    ``proto``, ``type`` and ``family`` attributes.  Its ``gettimeout()``
    returns 0.0, which is what a socket in non-blocking mode reports.
    """
    sock = mock.MagicMock(socket.socket)
    for attr_name, attr_value in (('proto', proto),
                                  ('type', type),
                                  ('family', family)):
        setattr(sock, attr_name, attr_value)
    sock.gettimeout.return_value = 0.0
    return sock
Victor Stinner231b4042015-01-14 00:19:09 +0100499
500
def force_legacy_ssl_support():
    """Return a ``mock.patch`` patcher that disables the sslproto path.

    While the patcher is active (as a decorator or context manager),
    ``asyncio.sslproto._is_sslproto_available`` is replaced by a mock
    returning False.
    """
    target = 'asyncio.sslproto._is_sslproto_available'
    return mock.patch(target, return_value=False)