blob: ac7680de45387d165b03820f5994d6b268d1b4be [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Victor Stinner1cae9ec2014-07-14 22:26:34 +02006import logging
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07007import os
Yury Selivanovff827f02014-02-18 18:02:19 -05008import re
Yury Selivanov88a5bf02014-02-18 12:15:06 -05009import socket
10import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070011import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050012import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070013import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020014import time
Victor Stinnerc73701d2014-06-18 01:36:32 +020015import unittest
Victor Stinner24ba2032014-02-26 10:25:02 +010016from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050017
18from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010019from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050020
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070021try:
22 import ssl
23except ImportError: # pragma: no cover
24 ssl = None
25
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070026from . import base_events
27from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010028from . import futures
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070029from . import selectors
Victor Stinnere6a53792014-03-06 01:00:36 +010030from . import tasks
Victor Stinnerf951d282014-06-29 00:46:45 +020031from .coroutines import coroutine
Victor Stinner1cae9ec2014-07-14 22:26:34 +020032from .log import logger
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070033
34
35if sys.platform == 'win32': # pragma: no cover
36 from .windows_utils import socketpair
37else:
38 from socket import socketpair # pragma: no cover
39
40
def dummy_ssl_context():
    """Return a bare ``ssl.SSLContext`` for tests, or None without ssl.

    The context uses PROTOCOL_SSLv23 and has no certificates loaded.
    None is returned when the interpreter has no ``ssl`` module.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
46
47
def run_briefly(loop):
    """Drive *loop* until a trivial, no-op coroutine task has completed.

    The coroutine object is always closed afterwards, even if running
    the loop fails.
    """
    @coroutine
    def once():
        pass

    coro = once()
    task = loop.create_task(coro)
    # Don't log a warning if the task is not done after run_until_complete().
    # It occurs if the loop is stopped or if a task raises a BaseException.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
61
62
def run_until(loop, pred, timeout=30):
    """Run *loop* in short steps until ``pred()`` returns a true value.

    Args:
        loop: event loop to drive; each step runs a 1 ms sleep on it.
        pred: zero-argument callable polled before every step.
        timeout: maximum number of seconds to wait, or None to wait
            indefinitely.

    Raises:
        futures.TimeoutError: if *timeout* elapses before *pred* is true.
    """
    # Only compute a deadline when there is a timeout: the previous code did
    # ``time.time() + timeout`` unconditionally, which raised TypeError for
    # timeout=None even though the loop below explicitly supports None.
    # A monotonic clock keeps the deadline immune to wall-clock changes.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020071
72
def run_once(loop):
    """Run *loop* for a single iteration.

    loop.stop() schedules _raise_stop_error(), and run_forever() then runs
    until that callback executes.  This does not work if the test waits
    for I/O events, because _raise_stop_error() runs before any I/O event
    callbacks.
    """
    # The stop request must be queued before the loop starts so that
    # run_forever() returns right after the iteration that executes it.
    loop.stop()
    loop.run_forever()
81
82
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that keeps the test console quiet.

    Access-log lines are dropped, and get_stderr() hands out a throwaway
    in-memory buffer instead of the process's stderr.
    """

    def log_message(self, format, *args):
        # Discard the access-log line entirely.
        return None

    def get_stderr(self):
        # A fresh buffer per call; whatever is written to it is never read.
        return io.StringIO()
90
91
class SilentWSGIServer(WSGIServer):
    """WSGIServer that ignores errors raised while serving a request."""

    def handle_error(self, request, client_address):
        # Deliberately swallow the error instead of printing a traceback.
        return None
96
97
class SSLWSGIServerMixin:
    """Server mixin that wraps each accepted connection in server-side TLS.

    Meant to be combined with a socketserver-style server class that
    provides ``RequestHandlerClass``.  The key and certificate are read
    from the asyncio test directory.
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # ssl.wrap_socket() is deprecated and was removed in Python 3.12;
        # build the server context explicitly.  PROTOCOL_SSLv23 is the
        # protocol ssl.wrap_socket() used by default, so behaviour matches.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)
        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Close the TLS socket on every path; previously an OSError
            # from the handler skipped close() and leaked the connection.
            ssock.close()
121
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700122
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Quiet TCP WSGI server whose connections are wrapped in TLS."""
125
126
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator shared by the run_test_*server() context managers.

    Builds a server (``server_ssl_cls`` when *use_ssl*, else
    ``server_cls``) bound to *address*, serves a fixed plain-text WSGI app
    from a background thread, and yields the server.  The server's bound
    address is exposed as ``server.address``.  When the generator resumes
    or is closed, the server is shut down, closed and its thread joined.
    """

    def app(environ, start_response):
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Serve from a separate thread so the server does not interfere with
    # event handling in the main thread.
    if use_ssl:
        factory = server_ssl_cls
    else:
        factory = server_cls
    httpd = factory(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    worker = threading.Thread(target=httpd.serve_forever)
    worker.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        worker.join()
149
150
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer that listens on a UNIX-domain socket path."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # A UNIX socket address is a filesystem path rather than a
            # (host, port) pair, so use fixed placeholder values here.
            self.server_name = '127.0.0.1'
            self.server_port = 80

    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGIServer variant of UnixHTTPServer."""

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            # Code in the stdlib expects get_request() to return a socket
            # and a (host, port) tuple.  For UNIX sockets the second value
            # is a path, so return fake data that is sufficient for the
            # tests to run.
            conn, _unused_addr = super().get_request()
            return conn, ('127.0.0.1', '')

    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that ignores errors raised while serving."""

        def handle_error(self, request, client_address):
            return None

    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Quiet UNIX-socket WSGI server whose connections use TLS."""

    def gen_unix_socket_path():
        """Return a unique temporary file name that no longer exists.

        A named temporary file is created and removed again, so only its
        name is kept.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            name = tmp.name
        return name

    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a fresh socket path; remove it afterwards if it exists."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            # Best effort: the path may never have been created.
            with contextlib.suppress(OSError):
                os.unlink(path)

    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Context manager running a test WSGI server on a UNIX socket."""
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
210 server_ssl_cls=UnixSSLWSGIServer)
211
212
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running a test WSGI server on TCP (host, port).

    With the default port 0 the OS picks a free port; the bound address
    is available as ``server.address`` on the yielded server.
    """
    address = (host, port)
    yield from _run_test_server(address=address, use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
218
219
def make_test_protocol(base):
    """Instantiate a mock-backed subclass of *base*.

    Every name returned by ``dir(base)``, except dunder names, is replaced
    by a MockCallback returning None.  The generated class is named
    'TestProtocol' and has *base* followed by base's own bases as bases.
    """
    def is_dunder(attr):
        return attr.startswith('__') and attr.endswith('__')

    namespace = {attr: MockCallback(return_value=None)
                 for attr in dir(base) if not is_dunder(attr)}
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, namespace)()
228
229
class TestSelector(selectors.BaseSelector):
    """In-memory selector: tracks registrations, never reports events."""

    def __init__(self):
        # Registered file object -> its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Record *fileobj* under a dummy fd of 0 and return its key."""
        self.keys[fileobj] = key = selectors.SelectorKey(
            fileobj=fileobj, fd=0, events=events, data=data)
        return key

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key (KeyError if unknown)."""
        key = self.keys.pop(fileobj)
        return key

    def select(self, timeout):
        """Never report any ready file object."""
        return []

    def get_map(self):
        """Return the live registration mapping itself."""
        return self.keys
247 return self.keys
248
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700249
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests whose clock is driven by a generator.

    time() never reads a real clock.  It returns an internal value that
    starts at 0 and only moves through advance_time().

    Every call_at() records the absolute time of the new timer.  After
    each loop iteration, each recorded time is sent into the generator
    passed to __init__.  The value the generator yields back is then used
    to advance the loop's time.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of a scheduled
    handle; the value yielded is how far to move the loop's time forward.

    When a generator is supplied, close() sends it 0 and raises
    AssertionError unless the generator finishes.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # Default generator: primed below, then exhausted by the first
            # send().  NOTE: with it, recording any timer would make
            # _run_once() propagate StopIteration, so only omit *gen* in
            # tests that never schedule timers.
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Advance the generator to its first yield so send() can be used.
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        # Absolute times recorded by call_at() since the last iteration.
        self._timers = []
        self._selector = TestSelector()

        # fd -> Handle registered through add_reader() / add_writer().
        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        """Return the loop's simulated time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance* (no-op when it is falsy)."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop, then check that the time generator finished."""
        # Chain to BaseEventLoop.close() so the base class's own cleanup
        # runs; this override used to skip it entirely.
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def add_reader(self, fd, callback, *args):
        """Record a read callback for *fd*; nothing is actually watched."""
        self.readers[fd] = events.Handle(callback, args, self)

    def remove_reader(self, fd):
        """Drop the read callback for *fd*; return True if one existed.

        Every call is counted in ``remove_reader_count``, even for fds
        that were never registered.
        """
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a read handle for ``callback(*args)``."""
        assert fd in self.readers, 'fd {} is not registered'.format(fd)
        handle = self.readers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def add_writer(self, fd, callback, *args):
        """Record a write callback for *fd*; nothing is actually watched."""
        self.writers[fd] = events.Handle(callback, args, self)

    def remove_writer(self, fd):
        """Drop the write callback for *fd*; return True if one existed.

        Every call is counted in ``remove_writer_count``, even for fds
        that were never registered.
        """
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a write handle for ``callback(*args)``."""
        assert fd in self.writers, 'fd {} is not registered'.format(fd)
        handle = self.writers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def reset_counters(self):
        """Reset the per-fd remove_reader / remove_writer call counters."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report each recorded timer to the generator, and advance the
        # clock by the amount it yields back.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        # Remember the timer so _run_once() can report it to the generator.
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # The test selector never reports events, so there is nothing to do.
        return

    def _write_to_self(self):
        # Intentionally a no-op for the test loop.
        pass
364 pass
Victor Stinnera1254972014-02-11 11:34:30 +0100365
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500366
def MockCallback(**kwargs):
    """Create a Mock that is only usable as a plain callable.

    The spec lists only ``__call__``, so attributes outside the spec
    (apart from Mock's own API) raise AttributeError.  Keyword arguments
    such as ``return_value`` are forwarded to ``mock.Mock``.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
Yury Selivanovff827f02014-02-18 18:02:19 -0500369
370
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern is applied with re.search() and re.S, so it may match
    anywhere in the other string, and '.' also matches newlines.
    Comparisons with non-str objects return NotImplemented, so Python
    falls back to its default comparison.  ``!=`` is the exact negation
    of ``==``.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            # re.search() would raise TypeError on non-str input.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which would otherwise be inherited and
        # compare the pattern text literally instead of negating __eq__.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal

    # Defining __eq__ already disables hashing; stated explicitly because
    # one pattern compares equal to many different strings.
    __hash__ = None
Victor Stinner307bccc2014-06-12 18:39:26 +0200382
383
def get_function_source(func):
    """Return events._get_function_source(func), insisting on a result.

    Raises ValueError when no source information is available for *func*.
    """
    source = events._get_function_source(func)
    if source is not None:
        return source
    raise ValueError("unable to get the source of %r" % (func,))
388 return source
Victor Stinnerc73701d2014-06-18 01:36:32 +0200389
390
class TestCase(unittest.TestCase):
    """unittest.TestCase with helpers for tests that create their own loops."""

    def set_event_loop(self, loop, *, cleanup=True):
        """Adopt *loop* for this test without installing it globally.

        The global event loop is reset to None, so code under test must be
        given the loop explicitly.  When *cleanup* is true, ``loop.close``
        is registered as a test cleanup.
        """
        assert loop is not None
        events.set_event_loop(None)
        if not cleanup:
            return
        self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen*; it is closed at cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def tearDown(self):
        # Never leave an event loop installed between tests.
        events.set_event_loop(None)
405 events.set_event_loop(None)
Victor Stinner1cae9ec2014-07-14 22:26:34 +0200406
407
@contextlib.contextmanager
def disable_logger():
    """Context manager that silences the asyncio logger.

    While active, the logger's level is set above CRITICAL so no record at
    a standard level passes it.  The previous level is restored on exit,
    even if the body raises.  For example, it can be used to ignore
    warnings in debug mode.
    """
    previous = logger.level
    silenced = logging.CRITICAL + 1
    try:
        logger.setLevel(silenced)
        yield
    finally:
        logger.setLevel(previous)
419 logger.setLevel(old_level)
Victor Stinnerb2614752014-08-25 23:20:52 +0200420
def mock_nonblocking_socket():
    """Return a Mock specced on socket.socket that looks non-blocking.

    A socket timeout of 0.0 means non-blocking mode, so gettimeout() is
    configured to return 0.0.
    """
    return mock.Mock(spec=socket.socket,
                     **{'gettimeout.return_value': 0.0})
425 return sock