blob: 8cee95b84f95726dfb028182996e5cc1697dd5cf [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Victor Stinner1cae9ec2014-07-14 22:26:34 +02006import logging
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07007import os
Yury Selivanovff827f02014-02-18 18:02:19 -05008import re
Yury Selivanov88a5bf02014-02-18 12:15:06 -05009import socket
10import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070011import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050012import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070013import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020014import time
Victor Stinnerc73701d2014-06-18 01:36:32 +020015import unittest
Victor Stinner24ba2032014-02-26 10:25:02 +010016from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050017
18from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010019from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050020
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070021try:
22 import ssl
23except ImportError: # pragma: no cover
24 ssl = None
25
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070026from . import base_events
27from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010028from . import futures
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070029from . import selectors
Victor Stinnere6a53792014-03-06 01:00:36 +010030from . import tasks
Victor Stinnerf951d282014-06-29 00:46:45 +020031from .coroutines import coroutine
Victor Stinner1cae9ec2014-07-14 22:26:34 +020032from .log import logger
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070033
34
35if sys.platform == 'win32': # pragma: no cover
36 from .windows_utils import socketpair
37else:
38 from socket import socketpair # pragma: no cover
39
40
def dummy_ssl_context():
    """Return a bare SSLContext for tests, or None without ssl support.

    The module-level ``ssl`` name is None when the ssl module could not
    be imported; in that case no context can be built.
    """
    if ssl is None:
        return None
    return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
46
47
def run_briefly(loop):
    """Run *loop* until a do-nothing task created on it has completed."""
    @coroutine
    def once():
        pass

    coro = once()
    task = loop.create_task(coro)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        # Always close the coroutine object, even if the run was interrupted.
        coro.close()
61
62
def run_until(loop, pred, timeout=30):
    """Run *loop* in short slices until ``pred()`` returns a true value.

    loop: event loop used to run a 1 ms sleep between checks of *pred*.
    pred: zero-argument callable; polling stops once it returns true.
    timeout: maximum time to wait in seconds, or None to wait without
        a deadline.

    Raises futures.TimeoutError if the deadline passes while ``pred()``
    is still false.
    """
    # Only compute a deadline when a timeout was requested; adding None to
    # time.time() would raise TypeError before the loop is ever entered.
    deadline = None if timeout is None else time.time() + timeout
    while not pred():
        if deadline is not None and deadline - time.time() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020071
72
def run_once(loop):
    """Stop *loop* and then run it, so that it runs briefly and returns.

    loop.stop() schedules _raise_stop_error(), and run_forever() runs
    until that callback is executed.  This does not work if the test
    waits for I/O events, because _raise_stop_error() runs before any
    of the I/O event callbacks.
    """
    loop.stop()
    loop.run_forever()
81
82
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that emits no error or log output."""

    def get_stderr(self):
        # Hand out a fresh in-memory text buffer as the error stream.
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        # Request logging is intentionally a no-op.
        return None
90
91
class SilentWSGIServer(WSGIServer):
    """WSGI server that sets a timeout on accepted sockets and hides errors."""

    # Value passed to settimeout() on every accepted connection.
    request_timeout = 2

    def get_request(self):
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        # Errors raised while handling a request are deliberately ignored.
        return None
103
104
class SSLWSGIServerMixin:
    """Server mixin that wraps each accepted connection in server-side SSL."""

    def finish_request(self, request, client_address):
        """Wrap *request* in SSL and run the request handler class on it.

        request: the accepted socket to wrap.
        client_address: passed through to the request handler class.

        OSError raised while handling or closing the wrapped socket is
        swallowed; errors raised while wrapping (for example a failed
        handshake) propagate to the caller.
        """
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # Module-level ssl.wrap_socket() is deprecated and was removed in
        # Python 3.12; build an explicit context instead.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)
        ssock = context.wrap_socket(request, server_side=True)
        try:
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            finally:
                # Close the SSL socket even when the handler fails, so it
                # is not left open until garbage collection.
                ssock.close()
        except OSError:
            # maybe socket has been closed by peer
            pass
128
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700129
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer combined with the SSL request-wrapping mixin."""
132
133
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that serves a fixed WSGI app on *address* while suspended.

    The server class is *server_ssl_cls* when *use_ssl* is true, otherwise
    *server_cls*.  The server object is yielded once; when the generator
    is resumed or closed the server is shut down, closed, and its thread
    joined.
    """
    def app(environ, start_response):
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    if use_ssl:
        factory = server_ssl_cls
    else:
        factory = server_cls
    httpd = factory(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    worker = threading.Thread(
        target=httpd.serve_forever, kwargs={'poll_interval': 0.05})
    worker.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        worker.join()
157
158
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTP server bound to a UNIX domain socket."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # Fixed placeholder host name and port for the server.
            self.server_name, self.server_port = '127.0.0.1', 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server listening on a UNIX domain socket."""

        # Value passed to settimeout() on every accepted connection.
        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer_path = super().get_request()
            conn.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            fake_peer = ('127.0.0.1', '')
            return conn, fake_peer


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that ignores errors raised by request handling."""

        def handle_error(self, request, client_address):
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer combined with the SSL request-wrapping mixin."""


    def gen_unix_socket_path():
        """Return the name of a temporary file that has already been closed."""
        with tempfile.NamedTemporaryFile() as tmp:
            path = tmp.name
        return path


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a generated socket path and try to unlink it afterwards."""
        sock_path = gen_unix_socket_path()
        try:
            yield sock_path
        finally:
            try:
                os.unlink(sock_path)
            except OSError:
                # Best effort: failure to remove the path is ignored.
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Context manager running the test WSGI server on a UNIX socket."""
        with unix_socket_path() as sock_path:
            yield from _run_test_server(address=sock_path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
222
223
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running the test WSGI server on a TCP address."""
    tcp_address = (host, port)
    yield from _run_test_server(address=tcp_address, use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
229
230
def make_test_protocol(base):
    """Return an instance of a *base* subclass with mocked attributes.

    Every name reported by ``dir(base)`` that is not a dunder name is
    replaced, in a new 'TestProtocol' class, by a MockCallback returning
    None.
    """
    def is_magic(name):
        return name.startswith('__') and name.endswith('__')

    namespace = {
        name: MockCallback(return_value=None)
        for name in dir(base)
        if not is_magic(name)
    }
    return type('TestProtocol', (base,) + base.__bases__, namespace)()
239
240
class TestSelector(selectors.BaseSelector):
    """Selector stub that records registrations and never reports events."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Record *fileobj* and return its key; the fd field is always 0."""
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        """Remove and return the key of *fileobj*; KeyError if unknown."""
        return self.keys.pop(fileobj)

    def select(self, timeout=None):
        """Return no ready events, whatever *timeout* is.

        *timeout* defaults to None so the signature matches
        BaseSelector.select() and the method can be called without
        arguments.
        """
        return []

    def get_map(self):
        """Return the mapping of registered file objects to keys."""
        return self.keys
259
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700260
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests with a manually driven clock.

    The loop keeps its own time value.  Each time recorded by call_at()
    is sent to the generator passed to __init__ at the end of the next
    loop iteration, and the value the generator yields back is used to
    advance the loop's time.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    Value returned by yield is absolute time of next scheduled handler.
    Value passed to yield is time advance to move loop's time forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a caller-supplied generator is checked for exhaustion
        # when the loop is closed.
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if not advance:
            return
        self._time += advance

    def close(self):
        super().close()
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError(  # pragma: no cover
            "Time generator is not finished")

    def _assert_registered_handle(self, table, fd, callback, args):
        # Shared check used by assert_reader() and assert_writer().
        assert fd in table, 'fd {} is not registered'.format(fd)
        handle = table[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        if fd not in self.readers:
            return False
        del self.readers[fd]
        return True

    def assert_reader(self, fd, callback, *args):
        self._assert_registered_handle(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        if fd not in self.writers:
            return False
        del self.writers[fd]
        return True

    def assert_writer(self, fd, callback, *args):
        self._assert_registered_handle(self.writers, fd, callback, args)

    def reset_counters(self):
        # Per-fd counts of remove_reader() / remove_writer() calls.
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report each recorded time to the generator and advance the
        # clock by whatever it yields back.
        for when in self._timers:
            self.advance_time(self._gen.send(when))
        self._timers = []

    def call_at(self, when, callback, *args):
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        return None

    def _write_to_self(self):
        return None
Victor Stinnera1254972014-02-11 11:34:30 +0100377
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500378
def MockCallback(**kwargs):
    """Return a Mock whose spec is only ``__call__``.

    Keyword arguments are forwarded to ``mock.Mock``.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
Yury Selivanovff827f02014-02-18 18:02:19 -0500381
382
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern is searched for (``re.search`` with ``re.S``) in the
    other string.  Comparing with a non-str object yields "not equal"
    instead of raising TypeError.  Because __eq__ is overridden without
    __hash__, instances are not hashable.
    """
    def __eq__(self, other):
        if not isinstance(other, str):
            # Let Python fall back to its default comparison rather than
            # failing inside re.search() on a non-string operand.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which would otherwise perform a
        # plain string comparison and contradict the fuzzy __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
Victor Stinner307bccc2014-06-12 18:39:26 +0200394
395
def get_function_source(func):
    """Return what events._get_function_source() reports for *func*.

    Raises ValueError when that helper returns None.
    """
    location = events._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
Victor Stinnerc73701d2014-06-18 01:36:32 +0200401
402
class TestCase(unittest.TestCase):
    """Base test case that keeps the current event loop unset."""

    def set_event_loop(self, loop, *, cleanup=True):
        """Unset the current event loop; close *loop* on cleanup if asked."""
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if not cleanup:
            return
        self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def tearDown(self):
        events.set_event_loop(None)

        # Detect CPython bug #23353: ensure that yield/yield-from is not used
        # in an except block of a generator
        no_exception = (None, None, None)
        self.assertEqual(sys.exc_info(), no_exception)
422
Victor Stinner1cae9ec2014-07-14 22:26:34 +0200423
@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    For example, it can be used to ignore warnings in debug mode.
    The previous logger level is restored on exit.
    """
    saved_level = logger.level
    # One above CRITICAL, so that even critical records are filtered out.
    silent_level = logging.CRITICAL + 1
    try:
        logger.setLevel(silent_level)
        yield
    finally:
        logger.setLevel(saved_level)
435 logger.setLevel(old_level)
Victor Stinnerb2614752014-08-25 23:20:52 +0200436
def mock_nonblocking_socket():
    """Create a mock of a non-blocking socket.

    The mock is specced on socket.socket and its gettimeout() returns 0.0.
    """
    return mock.Mock(socket.socket, **{'gettimeout.return_value': 0.0})
Victor Stinner231b4042015-01-14 00:19:09 +0100442
443
def force_legacy_ssl_support():
    """Return a patcher that makes ``_is_sslproto_available`` return False.

    The target is ``asyncio.sslproto._is_sslproto_available``; the patch
    is only created here, not started.
    """
    target = 'asyncio.sslproto._is_sslproto_available'
    return mock.patch(target, return_value=False)