blob: 94054e70234cf766ff522e3d9eedda6d61ae83be [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006import os
Yury Selivanovff827f02014-02-18 18:02:19 -05007import re
Yury Selivanov88a5bf02014-02-18 12:15:06 -05008import socket
9import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070010import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050011import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070012import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020013import time
Victor Stinnerc73701d2014-06-18 01:36:32 +020014import unittest
Victor Stinner24ba2032014-02-26 10:25:02 +010015from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050016
17from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010018from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050019
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070020try:
21 import ssl
22except ImportError: # pragma: no cover
23 ssl = None
24
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070025from . import base_events
26from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010027from . import futures
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070028from . import selectors
Victor Stinnere6a53792014-03-06 01:00:36 +010029from . import tasks
Victor Stinnerf951d282014-06-29 00:46:45 +020030from .coroutines import coroutine
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070031
32
33if sys.platform == 'win32': # pragma: no cover
34 from .windows_utils import socketpair
35else:
36 from socket import socketpair # pragma: no cover
37
38
def dummy_ssl_context():
    """Return a bare SSL context, or None if the ssl module is unavailable.

    The context is created with ``ssl.PROTOCOL_SSLv23`` and is otherwise
    left unconfigured (no certificates are loaded).
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
44
45
def run_briefly(loop):
    """Drive *loop* through a brief run by completing a no-op task.

    A trivial coroutine is wrapped in a Task and run with
    ``loop.run_until_complete()``. The coroutine object is closed
    afterwards even if running the loop raised.
    """
    @coroutine
    def once():
        pass

    coro_obj = once()
    task = tasks.Task(coro_obj, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        coro_obj.close()
56
57
def run_until(loop, pred, timeout=30):
    """Run *loop* in small steps until ``pred()`` returns a true value.

    Each step runs the loop until a 1 ms ``tasks.sleep()`` completes, then
    *pred* is checked again.

    :param loop: event loop to drive.
    :param pred: zero-argument callable, checked before every step.
    :param timeout: maximum number of seconds to wait, or None to wait
        without limit.
    :raises futures.TimeoutError: if *timeout* elapses before *pred*
        becomes true.
    """
    # Only compute a deadline when a timeout was given, so that
    # timeout=None really means "no limit". monotonic() is unaffected by
    # wall-clock adjustments during the test.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020066
67
def run_once(loop):
    """Let *loop* run briefly by scheduling a stop, then calling run_forever().

    loop.stop() schedules _raise_stop_error(), and run_forever() runs until
    that callback fires. This won't work if the test waits for some I/O
    events, because _raise_stop_error() runs before any I/O event
    callbacks.
    """
    loop.stop()
    loop.run_forever()
76
77
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler whose error stream and log messages are discarded."""

    def get_stderr(self):
        # A fresh in-memory buffer each call; nothing written to it is kept.
        return io.StringIO()

    def log_message(self, format, *args):
        # Override the base-class logger with a no-op.
        return None
85
86
class SilentWSGIServer(WSGIServer):
    """WSGI server whose per-request error hook does nothing."""

    def handle_error(self, request, client_address):
        # Ignore the error instead of delegating to the base-class handler.
        return None
91
92
class SSLWSGIServerMixin:
    """Mixin that serves each accepted connection over server-side TLS.

    The key and certificate are loaded from the asyncio test directory
    (``ssl_key.pem`` / ``ssl_cert.pem``).
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # ssl.wrap_socket() is deprecated (removed in Python 3.12); an
        # explicit context with the same protocol and key pair is equivalent.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)
        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Close on every path, including when the handler raised
            # OSError, so the TLS socket is never leaked.
            ssock.close()
116
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700117
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent TCP WSGI server that wraps every connection in TLS."""
120
121
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator running a WSGI test server in a background thread.

    Yields the running server, which has an extra ``address`` attribute
    equal to its ``server_address``. When resumed or closed, it shuts the
    server down, closes it and joins the serving thread. It is meant to be
    delegated to with ``yield from`` inside a
    ``contextlib.contextmanager`` function.

    :param address: address passed to the server class constructor.
    :param use_ssl: use *server_ssl_cls* instead of *server_cls*.
    """

    def app(environ, start_response):
        # Minimal WSGI app: every request gets the same plain-text reply.
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    if use_ssl:
        factory = server_ssl_cls
    else:
        factory = server_cls
    server = factory(address, SilentWSGIRequestHandler)
    server.set_app(app)
    server.address = server.server_address
    worker = threading.Thread(target=server.serve_forever)
    worker.start()
    try:
        yield server
    finally:
        server.shutdown()
        server.server_close()
        worker.join()
144
145
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer listening on a UNIX socket path."""

        def server_bind(self):
            # Bind through UnixStreamServer directly: HTTPServer.server_bind()
            # assumes a (host, port) address, but here it is a path.
            # Fixed placeholder values stand in for the name and port.
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80

    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGIServer variant listening on a UNIX socket path."""

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            # WSGIServer's own server_bind() is bypassed above, so set up
            # the base WSGI environ explicitly.
            self.setup_environ()

        def get_request(self):
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            conn, _peer = super().get_request()
            return conn, ('127.0.0.1', '')

    class SilentUnixWSGIServer(UnixWSGIServer):
        """UNIX-socket WSGI server whose per-request error hook does nothing."""

        def handle_error(self, request, client_address):
            return None

    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX-socket WSGI server that wraps connections in TLS."""

    def gen_unix_socket_path():
        """Return a fresh temporary file name to use as a UNIX socket path.

        The temporary file itself is deleted when the ``with`` block exits,
        so only its name is handed back.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            path = tmp.name
        return path

    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary UNIX socket path and unlink it afterwards.

        OSError from the final unlink (e.g. the file was never created) is
        ignored.
        """
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass

    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run a WSGI test server bound to a temporary UNIX socket path."""
        with unix_socket_path() as path:
            yield from _run_test_server(address=path, use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
206
207
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run a WSGI test server on the TCP address (*host*, *port*).

    The yielded server exposes its bound address as ``address``.
    """
    tcp_address = (host, port)
    yield from _run_test_server(address=tcp_address, use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
213
214
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose non-dunder names are mocks.

    Every name reported by ``dir(base)`` that is not a ``__dunder__`` name
    is replaced by a ``MockCallback(return_value=None)``. The generated
    class lists *base* followed by *base*'s own bases.
    """
    namespace = {
        attr: MockCallback(return_value=None)
        for attr in dir(base)
        if not (attr.startswith('__') and attr.endswith('__'))
    }
    return type('TestProtocol', (base,) + base.__bases__, namespace)()
223
224
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations and never reports events."""

    def __init__(self):
        self.keys = {}  # fileobj -> SelectorKey

    def register(self, fileobj, events, data=None):
        # The fd field is always 0: no real file descriptor is looked up.
        key = selectors.SelectorKey(fileobj=fileobj, fd=0,
                                    events=events, data=data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        # Raises KeyError for an unknown fileobj, like dict.pop().
        removed = self.keys.pop(fileobj)
        return removed

    def select(self, timeout):
        # Never reports any ready file objects.
        ready = []
        return ready

    def get_map(self):
        # Returns the live registration dict, not a copy.
        return self.keys
243
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700244
class TestLoop(base_events.BaseEventLoop):
    """Loop for unittests.

    It manages its own time directly: the clock starts at 0 and only moves
    forward when the time generator passed to __init__ says so. After each
    loop iteration (once all ready handlers are done), the generator is
    resumed once for every call_at() made since the previous iteration.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from yield is the absolute time of the next
    scheduled handler.
    The value yielded is the time advance used to move the loop's time
    forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # No time generator supplied: use a trivial one and skip the
            # "generator finished" check in close().
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)  # advance the generator to its first yield
        self._time = 0
        self._clock_resolution = 1e-9
        # `when` values passed to call_at() since the last _run_once().
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}  # fd -> events.Handle registered by add_reader()
        self.writers = {}  # fd -> events.Handle registered by add_writer()
        self.reset_counters()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        # Run BaseEventLoop.close() so the base loop's own cleanup is not
        # skipped.
        super().close()

        if self._check_on_close:
            # A finished test generator must raise StopIteration when
            # resumed; anything else means the test left time steps unused.
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def remove_reader(self, fd):
        # Count every removal attempt, even for unregistered fds.
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        assert fd in self.readers, 'fd {} is not registered'.format(fd)
        handle = self.readers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def remove_writer(self, fd):
        # Count every removal attempt, even for unregistered fds.
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        assert fd in self.writers, 'fd {} is not registered'.format(fd)
        handle = self.writers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def reset_counters(self):
        # fd -> number of remove_reader()/remove_writer() calls.
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed each newly scheduled time to the generator and advance the
        # clock by whatever it yields back.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        # Remember the time so _run_once() can report it to the generator.
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass
Victor Stinnera1254972014-02-11 11:34:30 +0100360
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500361
def MockCallback(**kwargs):
    """Return a mock.Mock whose spec only allows it to be called.

    ``spec=['__call__']`` makes access to other attributes raise
    AttributeError; *kwargs* (e.g. ``return_value``) are forwarded to
    mock.Mock.
    """
    call_only = ['__call__']
    return mock.Mock(spec=call_only, **kwargs)
Yury Selivanovff827f02014-02-18 18:02:19 -0500364
365
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    Equality means ``re.search(pattern, other, re.S)`` finds a match, so the
    pattern may match anywhere in *other* and ``.`` also matches newlines.
    Comparing with a non-str returns NotImplemented, letting Python fall
    back to its default comparison instead of raising TypeError.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # Without this, str.__ne__ is inherited and compares the raw pattern
        # text, so ``!=`` would disagree with the fuzzy ``==`` above.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    # No hash can be consistent with a fuzzy equality, so instances stay
    # unhashable (defining __eq__ already implies this; made explicit).
    __hash__ = None
Victor Stinner307bccc2014-06-12 18:39:26 +0200377
378
def get_function_source(func):
    """Return what ``events._get_function_source()`` reports for *func*.

    :raises ValueError: if that helper returns None for *func*.
    """
    source = events._get_function_source(func)
    if source is not None:
        return source
    raise ValueError("unable to get the source of %r" % (func,))
Victor Stinnerc73701d2014-06-18 01:36:32 +0200384
385
class TestCase(unittest.TestCase):
    """unittest.TestCase that keeps the global asyncio event loop unset."""

    def set_event_loop(self, loop, *, cleanup=True):
        """Use *loop* for this test without installing it globally.

        The global event loop is reset to None. When *cleanup* is true,
        ``loop.close`` is registered with addCleanup().
        """
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        test_loop = TestLoop(gen)
        self.set_event_loop(test_loop)
        return test_loop

    def tearDown(self):
        # Leave no global event loop behind after the test.
        events.set_event_loop(None)