blob: d9c7ae2d114d11fe1ef6233a29d0e2430000716b [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006import os
Yury Selivanovff827f02014-02-18 18:02:19 -05007import re
Yury Selivanov88a5bf02014-02-18 12:15:06 -05008import socket
9import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070010import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050011import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070012import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020013import time
Victor Stinnerc73701d2014-06-18 01:36:32 +020014import unittest
Victor Stinner24ba2032014-02-26 10:25:02 +010015from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050016
17from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010018from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050019
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070020try:
21 import ssl
22except ImportError: # pragma: no cover
23 ssl = None
24
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070025from . import base_events
26from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010027from . import futures
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070028from . import selectors
Victor Stinnere6a53792014-03-06 01:00:36 +010029from . import tasks
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070030
31
32if sys.platform == 'win32': # pragma: no cover
33 from .windows_utils import socketpair
34else:
35 from socket import socketpair # pragma: no cover
36
37
def dummy_ssl_context():
    """Return a default-configured SSLContext, or None if ssl is unavailable.

    The context is built with ``ssl.PROTOCOL_SSLv23``. No certificates
    are loaded into it.
    """
    return None if ssl is None else ssl.SSLContext(ssl.PROTOCOL_SSLv23)
43
44
def run_briefly(loop):
    """Run *loop* until a single no-op coroutine task has completed.

    This is presumably used to give already-scheduled callbacks a chance
    to run. The coroutine object is always closed afterwards, even if
    ``run_until_complete`` raises.
    """
    @tasks.coroutine
    def noop():
        pass

    coro = noop()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
55
56
def run_until(loop, pred, timeout=30):
    """Run *loop* in short steps until ``pred()`` returns a true value.

    Between checks of *pred*, the loop runs a 1 ms ``tasks.sleep``.

    :param loop: event loop to drive.
    :param pred: zero-argument callable polled before each step.
    :param timeout: seconds to wait before giving up, or None to wait
        indefinitely.
    :raises futures.TimeoutError: if *timeout* elapses before *pred*
        becomes true.
    """
    # Only compute a deadline when a timeout was given, so that
    # timeout=None (wait forever) does not evaluate ``now + None``.
    # time.monotonic() is immune to wall-clock adjustments.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and deadline - time.monotonic() <= 0:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020065
66
def run_once(loop):
    """Run a single iteration of *loop* via ``stop()`` + ``run_forever()``.

    loop.stop() schedules _raise_stop_error(), and run_forever() runs
    until that callback fires. This won't work if the test waits for
    I/O events, because _raise_stop_error() runs before any I/O event
    callbacks.
    """
    loop.stop()
    loop.run_forever()
75
76
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGIRequestHandler that keeps the test console quiet.

    ``get_stderr()`` hands back a fresh, throw-away in-memory buffer, and
    ``log_message()`` discards every log line.
    """

    def get_stderr(self):
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        return None
84
85
class SilentWSGIServer(WSGIServer):
    """WSGIServer whose ``handle_error`` ignores request errors silently."""

    def handle_error(self, request, client_address):
        return None
90
91
class SSLWSGIServerMixin:
    """Server mixin that wraps each accepted request socket in TLS.

    ``finish_request`` loads the test key/certificate pair, performs the
    server side of the TLS wrap, and then runs ``self.RequestHandlerClass``
    on the wrapped socket.
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # ssl.wrap_socket() is deprecated and was removed in Python 3.12.
        # Build an explicit context instead. PROTOCOL_SSLv23 is the protocol
        # ssl.wrap_socket() used by default.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)
        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Close on every path. Closing only after a successful handler
            # run would leak the TLS socket whenever the handler raised.
            ssock.close()
115
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700116
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent TCP WSGI test server whose connections are wrapped in TLS."""
119
120
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that serves a tiny WSGI app from a background thread.

    It yields the running server object, which carries an extra
    ``address`` attribute equal to its ``server_address``. When the
    generator is resumed, closed or thrown into, the server is shut down,
    its socket is closed and the serving thread is joined. Callers drive
    it with ``yield from`` inside a ``contextlib.contextmanager``.

    :param address: address to bind, in whatever form *server_cls* accepts.
    :param use_ssl: pick *server_ssl_cls* instead of *server_cls*.
    """

    def app(environ, start_response):
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Serve from a separate thread so the test server does not interfere
    # with event handling in the main thread.
    factory = server_ssl_cls if use_ssl else server_cls
    httpd = factory(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    worker = threading.Thread(target=httpd.serve_forever)
    worker.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        worker.join()
143
144
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer that listens on a UNIX-domain socket path."""

        def server_bind(self):
            # Bind as a plain UNIX stream server, then supply the
            # server_name/server_port attributes the HTTP code expects.
            # A UNIX socket has neither, so placeholders are used.
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80

    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGIServer variant that binds to a UNIX-domain socket."""

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            # Build the base WSGI environ after the placeholders are set.
            self.setup_environ()

        def get_request(self):
            conn, _peer = super().get_request()
            # The stdlib code expects get_request() to return a socket plus
            # a (host, port) tuple. For UNIX sockets the second value is a
            # path, so hand back fake data that is enough for the tests.
            return conn, ('127.0.0.1', '')

    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer whose ``handle_error`` ignores request errors."""

        def handle_error(self, request, client_address):
            return None

    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX-socket WSGI test server whose connections use TLS."""

    def gen_unix_socket_path():
        # Borrow a unique name from a temporary file. The file itself is
        # removed when the with-block exits, leaving only the path.
        with tempfile.NamedTemporaryFile() as tmp:
            name = tmp.name
        return name

    @contextlib.contextmanager
    def unix_socket_path():
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            # Best-effort cleanup: the socket file may never have been
            # created, or may already be gone.
            try:
                os.unlink(path)
            except OSError:
                pass

    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        with unix_socket_path() as path:
            yield from _run_test_server(address=path,
                                        use_ssl=use_ssl,
                                        server_cls=SilentUnixWSGIServer,
                                        server_ssl_cls=UnixSSLWSGIServer)
205
206
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager that runs the WSGI test server on a TCP address.

    With the default ``port=0``, the actual bound address is available as
    ``server.address`` on the yielded server.
    """
    bind_to = (host, port)
    yield from _run_test_server(address=bind_to, use_ssl=use_ssl,
                                server_cls=SilentWSGIServer,
                                server_ssl_cls=SSLWSGIServer)
212
213
def make_test_protocol(base):
    """Instantiate a dynamic subclass of *base* whose attributes are mocks.

    Every name reported by ``dir(base)`` that is not a dunder name is
    overridden with a ``MockCallback`` returning None. The new class
    derives from *base* followed by *base*'s own direct bases.
    """
    def is_dunder(attr):
        return attr.startswith('__') and attr.endswith('__')

    namespace = {attr: MockCallback(return_value=None)
                 for attr in dir(base) if not is_dunder(attr)}
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, namespace)()
222
223
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations and never reports events."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # The fd field is always 0: no real file descriptor is looked up.
        entry = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = entry
        return entry

    def unregister(self, fileobj):
        # Raises KeyError for an unknown fileobj, like dict lookup does.
        entry = self.keys[fileobj]
        del self.keys[fileobj]
        return entry

    def select(self, timeout):
        # Nothing is ever ready.
        return []

    def get_map(self):
        return self.keys
242
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700243
class TestLoop(base_events.BaseEventLoop):
    """Loop for unittests whose clock is controlled by the test.

    ``time()`` returns an internal counter that only moves when
    ``advance_time()`` is called. Every ``call_at()`` records the
    requested time. After each loop iteration (``_run_once()``), each
    recorded time is sent to the generator passed to ``__init__``, and the
    value it yields back is used to advance the clock.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of a handler
    scheduled with ``call_at()`` during the last iteration. The value
    yielded is how far to move the loop's time forward. The first value
    the generator yields is consumed (and ignored) when the loop is
    created.

    If a generator was supplied, ``close()`` checks that it has finished
    and raises AssertionError otherwise.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # NOTE: this default generator yields only once (consumed
            # below), so scheduling a timer without supplying *gen* makes
            # _run_once() raise StopIteration.
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward.

        Falsy advances (0 or None, e.g. from a bare ``yield``) leave the
        clock unchanged.
        """
        if advance:
            self._time += advance

    def close(self):
        # Let BaseEventLoop.close() release the base loop's resources
        # before checking the time generator. Skipping it left the loop
        # never marked closed.
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def remove_reader(self, fd):
        """Unregister *fd*; return True if it was registered."""
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def _assert_handle(self, handles, fd, callback, args):
        # Shared check for assert_reader()/assert_writer(): *fd* must be
        # registered in *handles* with exactly this callback and args.
        assert fd in handles, 'fd {} is not registered'.format(fd)
        handle = handles[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def assert_reader(self, fd, callback, *args):
        self._assert_handle(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def remove_writer(self, fd):
        """Unregister *fd*; return True if it was registered."""
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        self._assert_handle(self.writers, fd, callback, args)

    def reset_counters(self):
        # Per-fd counts of remove_reader()/remove_writer() calls.
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed every time recorded by call_at() during this iteration to
        # the time generator, advancing the clock by each reply.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # The test selector never reports events; nothing to dispatch.
        return

    def _write_to_self(self):
        # No-op: this test loop has no wake-up mechanism to poke.
        pass
Victor Stinnera1254972014-02-11 11:34:30 +0100359
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500360
def MockCallback(**kwargs):
    """Return a ``mock.Mock`` that can only be called.

    The spec is restricted to ``__call__``, so reading any other attribute
    raises AttributeError. Extra keyword arguments (e.g. ``return_value``)
    are forwarded to ``mock.Mock``.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
Yury Selivanovff827f02014-02-18 18:02:19 -0500363
364
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    a regex comparison between strings is needed. An instance compares
    equal to any str in which the pattern is found by ``re.search`` with
    ``re.S``, so ``.`` also matches newlines. Comparisons with non-str
    objects return NotImplemented instead of raising TypeError, and
    ``!=`` is the exact negation of ``==``.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))
    """

    # A pattern equals many different strings, so no hash can be consistent
    # with __eq__. Keep instances explicitly unhashable.
    __hash__ = None

    def __eq__(self, other):
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # Without this override, str.__ne__ is inherited and compares the
        # literal text, making ``!=`` disagree with the fuzzy ``==``.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
Victor Stinner307bccc2014-06-12 18:39:26 +0200376
377
def get_function_source(func):
    """Return what ``events._get_function_source`` reports for *func*.

    Raises ValueError instead of returning None when no source
    information is available.
    """
    source = events._get_function_source(func)
    if source is not None:
        return source
    raise ValueError("unable to get the source of %r" % (func,))
Victor Stinnerc73701d2014-06-18 01:36:32 +0200383
384
class TestCase(unittest.TestCase):
    """unittest.TestCase with helpers for tests that drive event loops."""

    def set_event_loop(self, loop, *, cleanup=True):
        """Prepare *loop* for use by the current test.

        The global event loop is reset to None so that code under test
        must receive *loop* explicitly. If *cleanup* is true,
        ``loop.close`` is registered with ``addCleanup``.

        :raises AssertionError: if *loop* is None.
        """
        if loop is None:
            # Explicit check: a bare ``assert`` disappears under python -O.
            raise AssertionError('loop must not be None')
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    def tearDown(self):
        events.set_event_loop(None)
        # Chain to the parent so cooperative mixins get their tearDown too.
        super().tearDown()