blob: 9c3656ac2bea23a626995c5838be7be26052ff08 [file] [log] [blame]
"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006import os
Yury Selivanovff827f02014-02-18 18:02:19 -05007import re
Yury Selivanov88a5bf02014-02-18 12:15:06 -05008import socket
9import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070010import sys
Yury Selivanov88a5bf02014-02-18 12:15:06 -050011import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070012import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020013import time
Victor Stinner24ba2032014-02-26 10:25:02 +010014from unittest import mock
Yury Selivanov88a5bf02014-02-18 12:15:06 -050015
16from http.server import HTTPServer
Victor Stinnerda492a82014-02-20 10:37:27 +010017from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
Yury Selivanov88a5bf02014-02-18 12:15:06 -050018
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070019try:
20 import ssl
21except ImportError: # pragma: no cover
22 ssl = None
23
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070024from . import base_events
25from . import events
Victor Stinnere6a53792014-03-06 01:00:36 +010026from . import futures
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070027from . import selectors
Victor Stinnere6a53792014-03-06 01:00:36 +010028from . import tasks
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070029
30
31if sys.platform == 'win32': # pragma: no cover
32 from .windows_utils import socketpair
33else:
34 from socket import socketpair # pragma: no cover
35
36
def dummy_ssl_context():
    """Return a plain SSLv23 ``ssl.SSLContext``, or None without ssl.

    The module-level name ``ssl`` is None when the ssl module could not
    be imported; in that case there is nothing to build.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
42
43
def run_briefly(loop):
    """Run *loop* until a freshly created do-nothing task has completed."""

    @tasks.coroutine
    def once():
        pass

    coro = once()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        # Close the coroutine object even if run_until_complete() raised.
        coro.close()
54
55
def run_until(loop, pred, timeout=30):
    """Run *loop* in 1 ms slices until ``pred()`` returns a true value.

    Args:
        loop: event loop; only ``run_until_complete()`` is used.
        pred: zero-argument callable, polled before every slice.
        timeout: maximum number of seconds to wait, or None to wait
            without any limit.

    Raises:
        futures.TimeoutError: *timeout* seconds have elapsed and
            ``pred()`` is still false.
    """
    # Build the deadline only when a limit was requested, so that
    # timeout=None means "no limit" instead of failing on None + float.
    deadline = None if timeout is None else time.time() + timeout
    while not pred():
        if deadline is not None and time.time() >= deadline:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
Antoine Pitroud20afad2013-10-20 01:51:25 +020064
65
def run_once(loop):
    """Call loop.stop() and then loop.run_forever().

    loop.stop() schedules _raise_stop_error(), and run_forever() runs
    until the _raise_stop_error() callback is reached.  This won't work
    if the test waits for some I/O events, because _raise_stop_error()
    runs before any of the I/O event callbacks.
    """
    # The stop request has to be queued before the loop starts running.
    loop.stop()
    loop.run_forever()
74
75
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards log messages.

    get_stderr() hands out a throwaway in-memory buffer and
    log_message() ignores whatever it is given.
    """

    def get_stderr(self):
        # A new buffer per call; nothing in this class reads it back.
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        # Intentionally drop the message.
        return None
83
84
class SilentWSGIServer(WSGIServer):
    """WSGIServer whose handle_error() hook does nothing."""

    def handle_error(self, request, client_address):
        # Intentionally ignore the error.
        return None
89
90
class SSLWSGIServerMixin:
    """Server mixin that handles every accepted connection over SSL."""

    def finish_request(self, request, client_address):
        """Wrap *request* in a server-side SSL socket and handle it.

        The key and certificate are read from ``ssl_key.pem`` and
        ``ssl_cert.pem`` in the test directory located below.  OSError
        raised while handling or closing the connection is ignored.
        """
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # Use an explicit SSLContext: the module-level ssl.wrap_socket()
        # helper is deprecated and no longer exists in recent Pythons.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)
        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Always release the SSL socket, including when the handler
            # failed part-way through.
            try:
                ssock.close()
            except OSError:
                pass
114
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700115
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer combined with SSLWSGIServerMixin."""
118
119
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that runs a WSGI test server while it is suspended.

    A server of class *server_ssl_cls* (when *use_ssl* is true) or
    *server_cls* is created on *address* and served from a background
    thread.  The server object is yielded once; when the generator is
    resumed or closed, the server is shut down, its socket is closed
    and the thread is joined.
    """

    def app(environ, start_response):
        # Every request gets the same plain-text answer.
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    if use_ssl:
        factory = server_ssl_cls
    else:
        factory = server_cls

    httpd = factory(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    worker = threading.Thread(target=httpd.serve_forever)
    worker.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        worker.join()
142
143
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer that listens on a UNIX stream socket."""

        def server_bind(self):
            socketserver.UnixStreamServer.server_bind(self)
            # Fixed placeholder values for the server name and port.
            self.server_name, self.server_port = '127.0.0.1', 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGIServer that listens on a UNIX stream socket."""

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            request, _ = super().get_request()
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            fake_client_address = ('127.0.0.1', '')
            return request, fake_client_address


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer whose handle_error() hook does nothing."""

        def handle_error(self, request, client_address):
            # Intentionally ignore the error.
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer combined with SSLWSGIServerMixin."""


    def gen_unix_socket_path():
        """Return the name of a temporary file that was just closed.

        Only the name is kept; the NamedTemporaryFile itself is closed
        on leaving the ``with`` block.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            return tmp.name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a path from gen_unix_socket_path(); unlink it on exit.

        An OSError raised by the unlink is ignored.
        """
        socket_path = gen_unix_socket_path()
        try:
            yield socket_path
        finally:
            try:
                os.unlink(socket_path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run the test WSGI server on a temporary UNIX socket path."""
        with unix_socket_path() as socket_path:
            server = _run_test_server(address=socket_path, use_ssl=use_ssl,
                                      server_cls=SilentUnixWSGIServer,
                                      server_ssl_cls=UnixSSLWSGIServer)
            yield from server
204
205
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run the test WSGI server on the TCP address ``(host, port)``.

    SSLWSGIServer is used when *use_ssl* is true, SilentWSGIServer
    otherwise.
    """
    address = (host, port)
    server = _run_test_server(address=address, use_ssl=use_ssl,
                              server_cls=SilentWSGIServer,
                              server_ssl_cls=SSLWSGIServer)
    yield from server
211
212
def make_test_protocol(base):
    """Return an instance of a new 'TestProtocol' class built on *base*.

    Every name listed by ``dir(base)`` that is not of the ``__magic__``
    form is overridden with its own MockCallback returning None.
    """

    def is_magic(name):
        return name.startswith('__') and name.endswith('__')

    namespace = {name: MockCallback(return_value=None)
                 for name in dir(base)
                 if not is_magic(name)}
    bases = (base,) + base.__bases__
    protocol_cls = type('TestProtocol', bases, namespace)
    return protocol_cls()
221
222
class TestSelector(selectors.BaseSelector):
    """Selector stub: it records registrations and never reports events."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Record *fileobj* and return its SelectorKey.

        The key's ``fd`` field is always 0; no real descriptor is used.
        """
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key.

        Raises:
            KeyError: *fileobj* is not registered.
        """
        return self.keys.pop(fileobj)

    def select(self, timeout=None):
        """Return an empty list: no registered object is ever ready.

        *timeout* is ignored.  It defaults to None, matching
        ``BaseSelector.select``, so the call also works without it.
        """
        return []

    def get_map(self):
        """Return the mapping of registered file objects to their keys."""
        return self.keys
241
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700242
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that keeps its own clock.

    time() returns an internal counter instead of the real time.  Each
    call_at() records the requested time; after every _run_once() pass
    the recorded times are sent, one by one, into the generator given
    to __init__, and whatever the generator yields back is added to
    the clock.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value returned by yield is the absolute time of the next
    scheduled handler.  The value passed to yield is the amount by
    which the loop's time is moved forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a caller-supplied generator is checked in close().
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        # Advance the generator to its first yield so send() can be used.
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        """Return the loop's current test time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if not advance:
            return
        self._time += advance

    def close(self):
        """Check that a caller-supplied time generator is exhausted."""
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError(  # pragma: no cover
            "Time generator is not finished")

    def add_reader(self, fd, callback, *args):
        """Store a Handle for *callback* as the reader of *fd*."""
        handle = events.Handle(callback, args, self)
        self.readers[fd] = handle

    def remove_reader(self, fd):
        """Count the call, then drop the reader of *fd* if there is one.

        Returns True when a reader was removed, False otherwise.
        """
        self.remove_reader_count[fd] += 1
        was_registered = fd in self.readers
        if was_registered:
            del self.readers[fd]
        return was_registered

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader for *callback* called with *args*."""
        assert fd in self.readers, 'fd {} is not registered'.format(fd)
        registered = self.readers[fd]
        assert registered._callback == callback, '{!r} != {!r}'.format(
            registered._callback, callback)
        assert registered._args == args, '{!r} != {!r}'.format(
            registered._args, args)

    def add_writer(self, fd, callback, *args):
        """Store a Handle for *callback* as the writer of *fd*."""
        handle = events.Handle(callback, args, self)
        self.writers[fd] = handle

    def remove_writer(self, fd):
        """Count the call, then drop the writer of *fd* if there is one.

        Returns True when a writer was removed, False otherwise.
        """
        self.remove_writer_count[fd] += 1
        was_registered = fd in self.writers
        if was_registered:
            del self.writers[fd]
        return was_registered

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer for *callback* called with *args*."""
        assert fd in self.writers, 'fd {} is not registered'.format(fd)
        registered = self.writers[fd]
        assert registered._callback == callback, '{!r} != {!r}'.format(
            registered._callback, callback)
        assert registered._args == args, '{!r} != {!r}'.format(
            registered._args, args)

    def reset_counters(self):
        """Start counting remove_reader()/remove_writer() calls from zero."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report each time recorded by call_at() to the generator and
        # move the clock by the amount it answers with.
        for when in self._timers:
            self.advance_time(self._gen.send(when))
        self._timers = []

    def call_at(self, when, callback, *args):
        """Record *when* for the time generator, then schedule as usual."""
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # Nothing to do with the events.
        return

    def _write_to_self(self):
        # Nothing to do.
        pass
Victor Stinnera1254972014-02-11 11:34:30 +0100358
Yury Selivanov88a5bf02014-02-18 12:15:06 -0500359
def MockCallback(**kwargs):
    """Return a ``mock.Mock`` whose spec contains only ``__call__``.

    Keyword arguments are passed on to ``mock.Mock``.
    """
    callback = mock.Mock(spec=['__call__'], **kwargs)
    return callback
Yury Selivanovff827f02014-02-18 18:02:19 -0500362
363
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern is applied with re.search() and the re.S flag.  A
    comparison with something that is not a str never matches.
    """

    def __eq__(self, other):
        if not isinstance(other, str):
            # Let Python fall back to its default comparison instead of
            # letting re.search() raise TypeError.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, so without this override ``!=``
        # would compare the raw text and could disagree with ``==``.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result