blob: 6abcaf1d37970a5de2ce0b8bb3ea140683a1a882 [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006import os
Yury Selivanovff827f02014-02-18 18:02:19 -05007import re
Yury Selivanov88a5bf02014-02-18 12:15:06 -05008import socket
9import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070010import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050011import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070012import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020013import time
Victor Stinnerc73701d2014-06-18 01:36:32 +020014import unittest
Victor Stinner24ba2032014-02-26 10:25:02 +010015from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050016
17from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010018from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050019
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070020try:
21 import ssl
22except ImportError: # pragma: no cover
23 ssl = None
24
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070025from . import base_events
26from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010027from . import futures
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070028from . import selectors
Victor Stinnere6a53792014-03-06 01:00:36 +010029from . import tasks
Victor Stinnerf951d282014-06-29 00:46:45 +020030from .coroutines import coroutine
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070031
32
33if sys.platform == 'win32': # pragma: no cover
34 from .windows_utils import socketpair
35else:
36 from socket import socketpair # pragma: no cover
37
38
def dummy_ssl_context():
    """Return a freshly created, unconfigured SSLContext.

    The context uses ssl.PROTOCOL_SSLv23 and has no certificates loaded.
    When the ssl module could not be imported, None is returned instead.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
44
45
def run_briefly(loop):
    """Give *loop* a brief chance to run.

    A task wrapping a no-op coroutine is created on the loop, and the loop
    is run until that task completes.  The coroutine object is always
    closed afterwards.
    """
    @coroutine
    def once():
        pass

    coro = once()
    task = loop.create_task(coro)
    # The task may still be pending after run_until_complete() (e.g. the
    # loop was stopped or a BaseException escaped); don't warn about that
    # when the task is destroyed.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
59
60
def run_until(loop, pred, timeout=30):
    """Run *loop* in short steps until ``pred()`` returns a true value.

    Each step runs the loop until a 1 ms tasks.sleep() completes, and
    *pred* is polled before every step (so it is called at least once).

    *timeout* is the maximum number of seconds to wait, or None to wait
    indefinitely.  Raises futures.TimeoutError if the timeout elapses
    before ``pred()`` becomes true.
    """
    # Only compute a deadline when there is a limit: adding None to a
    # float would raise TypeError.  time.monotonic() is used so that
    # wall-clock adjustments cannot shorten or stretch the wait.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020069
70
def run_once(loop):
    """Run *loop* briefly: schedule a stop, then run until it takes effect.

    loop.stop() schedules _raise_stop_error(), and run_forever() runs
    until that callback executes.  This won't work if the test waits for
    I/O events, because _raise_stop_error() runs before any I/O event
    callbacks.
    """
    loop.stop()
    loop.run_forever()
79
80
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards log messages and error output."""

    def log_message(self, format, *args):
        """Drop access-log messages instead of writing them anywhere."""

    def get_stderr(self):
        """Hand out a throwaway in-memory stream as the error stream."""
        return io.StringIO()
88
89
class SilentWSGIServer(WSGIServer):
    """WSGIServer whose per-request error hook does nothing."""

    def handle_error(self, request, client_address):
        """Ignore the error instead of reporting it."""
        return None
94
95
class SSLWSGIServerMixin:
    """Mixin for socketserver-style servers that serves requests over TLS.

    finish_request() wraps each accepted connection in a server-side TLS
    socket (using the test key/certificate files) and hands that socket to
    self.RequestHandlerClass.
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # ssl.wrap_socket() is deprecated (removed in Python 3.12); this
        # builds the same kind of context it used to create internally.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)
        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Release the TLS socket even when the handler raised, so it
            # is not leaked; a failing close is ignored as before.
            try:
                ssock.close()
            except OSError:
                pass
119
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700120
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent WSGI server that serves every request over TLS."""
123
124
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator serving a trivial WSGI app on *address* from a thread.

    server_ssl_cls is used when *use_ssl* is true, server_cls otherwise.
    The running server is yielded once, with an ``address`` attribute set
    to its bound address; when the generator is resumed or closed the
    server is shut down and closed, and the serving thread is joined.
    """

    def app(environ, start_response):
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Serve from a separate thread so the server does not interfere with
    # event handling in the main thread.
    if use_ssl:
        server_class = server_ssl_cls
    else:
        server_class = server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    thread = threading.Thread(target=httpd.serve_forever)
    thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        thread.join()
147
148
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer listening on a UNIX domain socket."""

        def server_bind(self):
            # Bind through the UNIX-stream implementation explicitly, then
            # fill in placeholder name/port values, since a socket path
            # carries neither a host name nor a port.
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name, self.server_port = '127.0.0.1', 80

    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGIServer listening on a UNIX domain socket."""

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _ = super().get_request()
            # Code in the stdlib expects get_request() to return a socket
            # and a (host, port) tuple, but for UNIX sockets the second
            # value is a path.  Return fake data that is sufficient to get
            # the tests going.
            return conn, ('127.0.0.1', '')

    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer whose per-request error hook does nothing."""

        def handle_error(self, request, client_address):
            return None

    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX-socket WSGI server that serves requests over TLS."""

    def gen_unix_socket_path():
        """Return a fresh temporary file name to use as a socket path.

        The temporary file is deleted again when the ``with`` block exits,
        so only the name is returned.  Nothing reserves that name, so
        another process could create the same path in the meantime.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            return tmp.name

    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary socket path and unlink it (if present) on exit."""
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass

    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run a test WSGI server bound to a temporary UNIX socket path."""
        with unix_socket_path() as path:
            yield from _run_test_server(
                address=path,
                use_ssl=use_ssl,
                server_cls=SilentUnixWSGIServer,
                server_ssl_cls=UnixSSLWSGIServer,
            )
208 server_ssl_cls=UnixSSLWSGIServer)
209
210
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run a test WSGI server bound to (*host*, *port*) in a thread.

    The yielded server exposes its actual bound address as ``address``.
    """
    address = (host, port)
    yield from _run_test_server(address=address,
                                use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
215 server_ssl_cls=SSLWSGIServer)
216
217
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose attributes are mock callbacks.

    Every name in dir(base) except ``__dunder__`` names is replaced by a
    MockCallback(return_value=None).  The new class derives from *base*
    followed by base's own direct bases.
    """
    namespace = {
        attr: MockCallback(return_value=None)
        for attr in dir(base)
        if not (attr.startswith('__') and attr.endswith('__'))
    }
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, namespace)()
226
227
class TestSelector(selectors.BaseSelector):
    """In-memory selector stub: records registrations, never reports events."""

    def __init__(self):
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # The fd field is always 0; no fileno() lookup is done.
        key = self.keys[fileobj] = selectors.SelectorKey(fileobj, 0,
                                                         events, data)
        return key

    def unregister(self, fileobj):
        key = self.keys.pop(fileobj)
        return key

    def select(self, timeout):
        return []  # nothing is ever reported as ready

    def get_map(self):
        return self.keys
246
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700247
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests, with a manually controlled clock.

    time() returns an internal counter (starting at 0) instead of reading
    a real clock.  Every absolute time passed to call_at() is recorded.
    After each loop iteration, each recorded time is sent to the
    time generator given to __init__, and the value the generator yields
    back is added to the loop's time.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of the next
    scheduled handler.  The value yielded is the time advance that moves
    the loop's time forward.

    add_reader()/add_writer() only record handles in the ``readers`` and
    ``writers`` dicts, which assert_reader()/assert_writer() inspect.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            def gen():
                yield
            # Only a caller-supplied generator is checked for completion
            # in close().
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)  # run up to the first yield so send() can be used
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        """Return the loop's virtual time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance* (falsy values are ignored)."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and verify the time generator has finished.

        The base-class close() runs first so its cleanup is not skipped.
        When a generator was passed to __init__, it is sent one final 0
        and must then be exhausted.

        Raises AssertionError if that generator is not finished.
        """
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def remove_reader(self, fd):
        """Forget the reader for *fd*; return whether one was registered.

        Every call is counted in remove_reader_count[fd].
        """
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has a reader callback(*args)."""
        self._assert_handle(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def remove_writer(self, fd):
        """Forget the writer for *fd*; return whether one was registered.

        Every call is counted in remove_writer_count[fd].
        """
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Raise AssertionError unless *fd* has a writer callback(*args)."""
        self._assert_handle(self.writers, fd, callback, args)

    def _assert_handle(self, handles, fd, callback, args):
        # Shared check for assert_reader()/assert_writer().  Uses explicit
        # raises rather than `assert` so it still works under `python -O`.
        if fd not in handles:
            raise AssertionError('fd {} is not registered'.format(fd))
        handle = handles[fd]
        if handle._callback != callback:
            raise AssertionError('{!r} != {!r}'.format(
                handle._callback, callback))
        if handle._args != args:
            raise AssertionError('{!r} != {!r}'.format(
                handle._args, args))

    def reset_counters(self):
        """Reset the per-fd remove_reader()/remove_writer() call counts."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        # Run one regular iteration, then report every time recorded by
        # call_at() since the previous iteration to the time generator and
        # advance the clock by whatever it answers.
        super()._run_once()
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        # Remember the requested time so _run_once() can pass it on to the
        # time generator.
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass
Victor Stinnera1254972014-02-11 11:34:30 +0100363
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500364
def MockCallback(**kwargs):
    """Return a mock.Mock specced to ['__call__'].

    Accessing attributes outside that spec raises AttributeError.  Keyword
    arguments (e.g. return_value) are forwarded to mock.Mock.
    """
    call_only_spec = ['__call__']
    return mock.Mock(spec=call_only_spec, **kwargs)
Yury Selivanovff827f02014-02-18 18:02:19 -0500367
368
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    a regex comparison between strings is needed.  An instance equals
    any str in which re.search() finds the pattern; re.S is used, so '.'
    also matches newlines.  Non-str operands compare unequal.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    Instances are unhashable: defining __eq__ without __hash__ sets
    __hash__ to None, which suits a non-transitive fuzzy equality.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            # Let Python fall back to its default (identity) comparison
            # instead of letting re.search() raise TypeError.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which would otherwise be inherited and
        # compare the raw pattern text, contradicting __eq__.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal
Victor Stinner307bccc2014-06-12 18:39:26 +0200380
381
def get_function_source(func):
    """Return what events._get_function_source() reports for *func*.

    Raises ValueError when that helper returns None.
    """
    source = events._get_function_source(func)
    if source is not None:
        return source
    raise ValueError("unable to get the source of %r" % (func,))
Victor Stinnerc73701d2014-06-18 01:36:32 +0200387
388
class TestCase(unittest.TestCase):
    """unittest.TestCase with helpers for tests that pass loops explicitly."""

    def set_event_loop(self, loop, *, cleanup=True):
        """Prepare *loop* for a test.

        The global event loop is reset to None so that code under test has
        to receive the loop explicitly.  When *cleanup* is true,
        loop.close() is registered with addCleanup().

        Raises AssertionError if *loop* is None.
        """
        # Explicit raise rather than `assert`, so the check survives -O.
        if loop is None:
            raise AssertionError('loop must not be None')
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def tearDown(self):
        # Don't leave a global event loop behind for the next test.
        events.set_event_loop(None)
        super().tearDown()