blob: e06ac06eee33ea128ead6b08c32e74cae4344637 [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Victor Stinner1cae9ec2014-07-14 22:26:34 +02006import logging
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07007import os
Yury Selivanovff827f02014-02-18 18:02:19 -05008import re
Yury Selivanov88a5bf02014-02-18 12:15:06 -05009import socket
10import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070011import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050012import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070013import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020014import time
Victor Stinnerc73701d2014-06-18 01:36:32 +020015import unittest
Victor Stinner24ba2032014-02-26 10:25:02 +010016from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050017
18from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010019from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050020
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070021try:
22 import ssl
23except ImportError: # pragma: no cover
24 ssl = None
25
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070026from . import base_events
27from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010028from . import futures
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070029from . import selectors
Victor Stinnere6a53792014-03-06 01:00:36 +010030from . import tasks
Victor Stinnerf951d282014-06-29 00:46:45 +020031from .coroutines import coroutine
Victor Stinner1cae9ec2014-07-14 22:26:34 +020032from .log import logger
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070033
34
35if sys.platform == 'win32': # pragma: no cover
36 from .windows_utils import socketpair
37else:
38 from socket import socketpair # pragma: no cover
39
40
def dummy_ssl_context():
    """Return a new SSL context for tests, or None without the ssl module.

    The context is created with ``ssl.PROTOCOL_SSLv23``.
    """
    if ssl is None:
        return None
    return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
46
47
def run_briefly(loop):
    """Run *loop* until a do-nothing task created on it has completed."""

    @coroutine
    def once():
        pass

    coro = once()
    task = loop.create_task(coro)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        # Always close the generator, even if the loop was interrupted.
        coro.close()
61
62
def run_until(loop, pred, timeout=30):
    """Run *loop* in 1 ms slices until ``pred()`` returns a true value.

    *timeout* is the maximum time to wait, in seconds.  ``None`` disables
    the deadline and waits for as long as it takes.

    Raises futures.TimeoutError if the deadline passes while ``pred()``
    is still false.
    """
    # Only compute a deadline when a timeout was requested; adding None to
    # a float would raise TypeError before the loop is even entered.
    deadline = None if timeout is None else time.time() + timeout
    while not pred():
        if deadline is not None and deadline - time.time() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020071
72
def run_once(loop):
    """Legacy API to run once through the event loop.

    A call to ``loop.stop()`` is scheduled first and the loop is then
    started, so it polls the selector once, runs the callbacks scheduled
    in response to I/O events, and stops.  This is the recommended
    pattern for test code.
    """
    stop_loop = loop.stop
    loop.call_soon(stop_loop)
    loop.run_forever()
82
83
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards error output and log messages."""

    def get_stderr(self):
        # Hand out a throwaway in-memory buffer on every call.
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        # Intentionally drop every log line.
        return None
91
92
class SilentWSGIServer(WSGIServer):
    """WSGI server that sets a timeout on accepted sockets and ignores
    request-handling errors."""

    # Timeout, in seconds, applied to every accepted connection.
    request_timeout = 2

    def get_request(self):
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        # Errors raised while handling a request are deliberately ignored.
        return None
104
105
class SSLWSGIServerMixin:
    """Server mixin that serves each accepted connection over SSL.

    The key and certificate are read from ``ssl_key.pem`` and
    ``ssl_cert.pem`` in the asyncio test directory.
    """

    def finish_request(self, request, client_address):
        """Wrap *request* in a server-side SSL socket and handle it.

        OSError raised by the request handler or while closing the SSL
        socket is ignored.
        """
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # Use an SSLContext rather than the module-level ssl.wrap_socket(),
        # which is deprecated and absent from recent Python versions.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)
        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Close the SSL socket even when the handler failed, so the
            # connection is not leaked.
            try:
                ssock.close()
            except OSError:
                # maybe socket has been closed by peer
                pass
129
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700130
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer whose requests are finished by SSLWSGIServerMixin."""
133
134
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that runs a test WSGI server while it is suspended.

    A server of class *server_ssl_cls* (when *use_ssl* is true) or
    *server_cls* is bound to *address* and served from a separate thread.
    The server object is yielded with an extra ``address`` attribute; when
    the generator is resumed or closed the server is shut down and the
    thread joined.
    """

    def app(environ, start_response):
        # Minimal WSGI application: always answer with a fixed body.
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    if use_ssl:
        chosen_cls = server_ssl_cls
    else:
        chosen_cls = server_cls
    httpd = chosen_cls(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    worker = threading.Thread(
        target=httpd.serve_forever, kwargs={'poll_interval': 0.05})
    worker.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        worker.join()
158
159
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server bound to a UNIX stream socket."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # Placeholder name and port for a server that has neither.
            self.server_name, self.server_port = '127.0.0.1', 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server on a UNIX socket with a timeout on each connection."""

        # Timeout, in seconds, applied to every accepted connection.
        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer = super().get_request()
            conn.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            fake_client_addr = ('127.0.0.1', '')
            return conn, fake_client_addr


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that ignores request-handling errors."""

        def handle_error(self, request, client_address):
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer whose requests are finished by
        SSLWSGIServerMixin."""


    def gen_unix_socket_path():
        """Return the name of a named temporary file that is closed
        before this function returns."""
        with tempfile.NamedTemporaryFile() as tmp:
            sock_path = tmp.name
        return sock_path


    @contextlib.contextmanager
    def unix_socket_path():
        """Context manager yielding a path from gen_unix_socket_path().

        The path is unlinked on exit; an OSError from the unlink is ignored.
        """
        sock_path = gen_unix_socket_path()
        try:
            yield sock_path
        finally:
            try:
                os.unlink(sock_path)
            except OSError:
                # Nothing to remove (or it cannot be removed): best effort.
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Context manager running a test WSGI server on a UNIX socket."""
        with unix_socket_path() as path:
            server_gen = _run_test_server(address=path, use_ssl=use_ssl,
                                          server_cls=SilentUnixWSGIServer,
                                          server_ssl_cls=UnixSSLWSGIServer)
            yield from server_gen
223
224
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running a test WSGI server on ``(host, port)``.

    SSLWSGIServer is used when *use_ssl* is true, SilentWSGIServer
    otherwise.
    """
    address = (host, port)
    server_gen = _run_test_server(address=address, use_ssl=use_ssl,
                                  server_cls=SilentWSGIServer,
                                  server_ssl_cls=SSLWSGIServer)
    yield from server_gen
230
231
def make_test_protocol(base):
    """Return an instance of a 'TestProtocol' class derived from *base*.

    Every attribute name listed by ``dir(base)``, except double-underscore
    magic names, is replaced by a MockCallback returning None.
    """

    def is_magic(name):
        return name.startswith('__') and name.endswith('__')

    stubs = {
        name: MockCallback(return_value=None)
        for name in dir(base)
        if not is_magic(name)
    }
    protocol_cls = type('TestProtocol', (base,) + base.__bases__, stubs)
    return protocol_cls()
240
241
class TestSelector(selectors.BaseSelector):
    """Selector stub that records registrations and never reports events."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        entry = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = entry
        return entry

    def unregister(self, fileobj):
        removed = self.keys.pop(fileobj)
        return removed

    def select(self, timeout):
        # No I/O event is ever ready.
        return []

    def get_map(self):
        return self.keys
260
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700261
class TestLoop(base_events.BaseEventLoop):
    """Loop for unittests.

    The loop manages its own time directly.  When something is scheduled
    with call_at(), the scheduled time is sent, at the end of the loop
    iteration, to the generator passed to __init__.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    Value returned by yield is absolute time of next scheduled handler.
    Value passed to yield is time advance to move loop's time forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a generator supplied by the caller is checked on close().
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if not advance:
            return
        self._time += advance

    def close(self):
        super().close()
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        # The generator still had steps left to run.
        raise AssertionError(  # pragma: no cover
            "Time generator is not finished")

    @staticmethod
    def _assert_registered(registry, fd, callback, args):
        # Shared check behind assert_reader() and assert_writer().
        assert fd in registry, 'fd {} is not registered'.format(fd)
        handle = registry[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        was_registered = fd in self.readers
        if was_registered:
            del self.readers[fd]
        return was_registered

    def assert_reader(self, fd, callback, *args):
        self._assert_registered(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        was_registered = fd in self.writers
        if was_registered:
            del self.writers[fd]
        return was_registered

    def assert_writer(self, fd, callback, *args):
        self._assert_registered(self.writers, fd, callback, args)

    def reset_counters(self):
        # Per-fd counts of remove_reader() / remove_writer() calls.
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report every time recorded by call_at() to the generator and
        # apply the advance it answers with.
        for scheduled_at in self._timers:
            self.advance_time(self._gen.send(scheduled_at))
        self._timers = []

    def call_at(self, when, callback, *args):
        self._timers.append(when)
        handle = super().call_at(when, callback, *args)
        return handle

    def _process_events(self, event_list):
        return None

    def _write_to_self(self):
        return None
Victor Stinnera1254972014-02-11 11:34:30 +0100378
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500379
def MockCallback(**kwargs):
    """Return a mock.Mock whose spec is limited to ``__call__``.

    Keyword arguments are passed through to mock.Mock.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
Yury Selivanovff827f02014-02-18 18:02:19 -0500382
383
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern is searched for (re.search with re.S) in the other
    string.  Comparing with a non-str object is never equal.  Instances
    are not hashable because __eq__ is overridden.
    """
    def __eq__(self, other):
        # re.search() raises TypeError for non-str input; defer to the
        # default comparison instead so the result is simply "not equal".
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str supplies its own exact-match __ne__; override it so that
        # != is always the negation of the fuzzy ==.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal
Victor Stinner307bccc2014-06-12 18:39:26 +0200395
396
def get_function_source(func):
    """Return the result of ``events._get_function_source(func)``.

    Raises ValueError when that helper returns None.
    """
    location = events._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
Victor Stinnerc73701d2014-06-18 01:36:32 +0200402
403
class TestCase(unittest.TestCase):
    """unittest.TestCase with helpers for event-loop tests."""

    def set_event_loop(self, loop, *, cleanup=True):
        """Clear the current event loop; close *loop* on cleanup if asked."""
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if not cleanup:
            return
        self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def tearDown(self):
        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        no_exception = (None, None, None)
        self.assertEqual(sys.exc_info(), no_exception)
423
Victor Stinner1cae9ec2014-07-14 22:26:34 +0200424
@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    The logger level is raised above CRITICAL for the duration of the
    block and restored afterwards.  For example, it can be used to
    ignore warnings in debug mode.
    """
    saved_level = logger.level
    silent_level = logging.CRITICAL + 1
    try:
        logger.setLevel(silent_level)
        yield
    finally:
        logger.setLevel(saved_level)
Victor Stinnerb2614752014-08-25 23:20:52 +0200437
def mock_nonblocking_socket():
    """Create a mock of a non-blocking socket.

    The mock is specced on socket.socket and its gettimeout() returns 0.0.
    """
    fake_sock = mock.Mock(socket.socket)
    fake_sock.gettimeout.return_value = 0.0
    return fake_sock
Victor Stinner231b4042015-01-14 00:19:09 +0100443
444
def force_legacy_ssl_support():
    """Return a mock.patch object for the sslproto availability check.

    It replaces ``asyncio.sslproto._is_sslproto_available`` with a mock
    returning False.
    """
    target = 'asyncio.sslproto._is_sslproto_available'
    return mock.patch(target, return_value=False)