blob: 2a8a241fd5ed913ef200c9e5de6c19ae145dde5f [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006import os
Yury Selivanov569efa22014-02-18 18:02:19 -05007import re
Yury Selivanovb057c522014-02-18 12:15:06 -05008import socket
9import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070010import sys
Yury Selivanovb057c522014-02-18 12:15:06 -050011import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070012import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020013import time
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070014import unittest
15import unittest.mock
Yury Selivanovb057c522014-02-18 12:15:06 -050016
17from http.server import HTTPServer
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070018from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
Yury Selivanovb057c522014-02-18 12:15:06 -050019
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070020try:
21 import ssl
22except ImportError: # pragma: no cover
23 ssl = None
24
25from . import tasks
26from . import base_events
27from . import events
28from . import selectors
29
30
31if sys.platform == 'win32': # pragma: no cover
32 from .windows_utils import socketpair
33else:
34 from socket import socketpair # pragma: no cover
35
36
def dummy_ssl_context():
    """Return a bare SSL context, or None if the ssl module is missing.

    The context is created with ``PROTOCOL_SSLv23`` and nothing else is
    configured on it (no certificates are loaded).
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
42
43
def run_briefly(loop):
    """Run a single no-op task on *loop* until it completes.

    A trivial coroutine is wrapped in a Task and driven with
    ``run_until_complete()``. Presumably this gives callbacks that are
    already scheduled a chance to run. The coroutine object is closed
    afterwards whether or not running it succeeded.
    """
    @tasks.coroutine
    def noop():
        pass

    coro = noop()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
54
55
def run_until(loop, pred, timeout=None):
    """Run *loop* until ``pred()`` returns a true value.

    Without a *timeout*, the loop is stepped with ``run_briefly()`` until
    the predicate holds.

    With a *timeout* (in seconds), the loop sleeps in short slices and
    re-checks ``pred()`` after each one. The call therefore returns as
    soon as the predicate holds, not only after the whole timeout has
    elapsed.

    Returns True once ``pred()`` is true, or False if *timeout* expires
    first.
    """
    # Short sleep slice between predicate checks when a timeout is given.
    poll_interval = 0.001
    if timeout is not None:
        # monotonic() is unaffected by system clock adjustments.
        deadline = time.monotonic() + timeout
    while not pred():
        if timeout is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            # Never sleep past the deadline. Sleeping the full remaining
            # time (as before) made every call block until the timeout,
            # with pred() only re-checked afterwards.
            loop.run_until_complete(
                tasks.sleep(min(remaining, poll_interval), loop=loop))
        else:
            run_briefly(loop)
    return True
68
69
def run_once(loop):
    """Run *loop* for a single iteration and return.

    ``loop.stop()`` schedules ``_raise_stop_error()``. ``run_forever()``
    then runs until that callback fires.

    This won't work if the test waits for I/O events, because
    ``_raise_stop_error()`` runs before any I/O event callbacks do.
    """
    loop.stop()
    loop.run_forever()
78
79
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that discards its error stream and access log."""

    def get_stderr(self):
        # A fresh in-memory buffer per call, so nothing reaches real stderr.
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        # Drop the per-request log line entirely.
        return None
87
88
class SilentWSGIServer(WSGIServer):
    """WSGIServer that stays quiet about errors raised while handling requests."""

    def handle_error(self, request, client_address):
        # Deliberately ignore the error instead of reporting it.
        return None
93
94
class SSLWSGIServerMixin:
    """Mixin for socketserver-based servers that serves requests over TLS.

    ``finish_request`` wraps each accepted request socket server-side
    using the test key/certificate pair. It then hands the TLS socket to
    ``self.RequestHandlerClass``.
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which contains the
        # ssl key and certificate files) differs between the stdlib and
        # stand-alone asyncio. Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')

        # ssl.wrap_socket() is deprecated (and removed in Python 3.12).
        # An explicit context with the same protocol and key pair does the
        # same job.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)
        ssock = context.wrap_socket(request, server_side=True)
        try:
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            finally:
                # Always release the TLS socket. Previously an OSError
                # raised by the handler skipped close() and leaked it.
                ssock.close()
        except OSError:
            # maybe socket has been closed by peer
            pass
118
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700119
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer whose requests are served through SSLWSGIServerMixin."""
122
123
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Serve a trivial WSGI app on *address* from a background thread.

    This is a generator meant to be delegated to (``yield from``) by a
    ``contextlib.contextmanager`` function. It yields the running server,
    which has an extra ``address`` attribute equal to ``server_address``.
    When the generator is resumed or closed, it shuts the server down,
    closes its socket and joins the serving thread.

    *server_ssl_cls* is used when *use_ssl* is true, otherwise
    *server_cls*. Either one is instantiated with
    ``SilentWSGIRequestHandler``.
    """

    def app(environ, start_response):
        # Minimal WSGI application: every request gets the same body.
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    if use_ssl:
        factory = server_ssl_cls
    else:
        factory = server_cls
    httpd = factory(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address

    # Serve from a separate thread so the server does not interfere with
    # event handling in the main thread.
    worker = threading.Thread(target=httpd.serve_forever)
    worker.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        worker.join()
146
147
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer variant that listens on a UNIX domain socket."""

        def server_bind(self):
            # Bind the UNIX socket, then provide fixed placeholder values
            # for the host/port attributes that HTTP server code expects.
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name, self.server_port = '127.0.0.1', 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGIServer variant that listens on a UNIX domain socket."""

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            # Prepare the WSGI environment after binding.
            self.setup_environ()

        def get_request(self):
            # Code in the stdlib expects get_request() to return a socket
            # and a (host, port) tuple. For UNIX sockets the second value
            # is a path instead, so substitute fake data that is
            # sufficient to get the tests going.
            sock, _unused_path = super().get_request()
            return sock, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that ignores errors raised while handling requests."""

        def handle_error(self, request, client_address):
            return None


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer whose requests go through SSLWSGIServerMixin."""


    def gen_unix_socket_path():
        """Return a fresh temporary file name to use as a socket path.

        The temporary file itself is removed again when the ``with``
        block exits; only its name is returned.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            name = tmp.name
        return name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary socket path and unlink it afterwards.

        Any OSError raised while unlinking (for example, the file was
        never created) is ignored.
        """
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run the test WSGI server on a temporary UNIX socket path."""
        with unix_socket_path() as path:
            yield from _run_test_server(
                address=path,
                use_ssl=use_ssl,
                server_cls=SilentUnixWSGIServer,
                server_ssl_cls=UnixSSLWSGIServer,
            )
208
209
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run the test WSGI server on a TCP ``(host, port)`` address.

    The yielded server's ``address`` attribute holds the address it is
    actually bound to, which matters when the default port 0 is used.
    """
    yield from _run_test_server(
        address=(host, port),
        use_ssl=use_ssl,
        server_cls=SilentWSGIServer,
        server_ssl_cls=SSLWSGIServer,
    )
215
216
def make_test_protocol(base):
    """Return an instance of a subclass of *base* with mocked attributes.

    Every name listed by ``dir(base)`` that is not a ``__dunder__`` name
    is overridden with a ``MockCallback`` that returns None. The generated
    class ``TestProtocol`` has *base* followed by base's own direct bases
    as its bases.
    """
    overrides = {
        name: MockCallback(return_value=None)
        for name in dir(base)
        if not (name.startswith('__') and name.endswith('__'))
    }
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, overrides)()
225
226
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations but reports no events."""

    def __init__(self):
        # Registered file object -> SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Record *fileobj* and return its key (the fd field is always 0)."""
        self.keys[fileobj] = new_key = selectors.SelectorKey(
            fileobj, 0, events, data)
        return new_key

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key; raises KeyError if unknown."""
        return self.keys.pop(fileobj)

    def select(self, timeout):
        """Report no ready file objects, whatever *timeout* is."""
        return []

    def get_map(self):
        """Return the live file-object-to-key mapping."""
        return self.keys
245
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700246
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests whose clock is driven by a generator.

    The loop keeps its own time (``self._time``, starting at 0) instead of
    reading a real clock. Every time passed to ``call_at`` is remembered.
    After each loop iteration, the remembered times are sent, in order,
    into the generator given to ``__init__``. Each value the generator
    yields back is added to the loop time (falsy values are ignored).

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of the next
    scheduled handler. The value yielded is the amount by which to move
    the loop's time forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a user-supplied generator is checked for completion in close().
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        next(self._gen)  # advance the generator to its first yield
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        # fd -> Handle registered via add_reader()/add_writer().
        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        """Return the simulated loop time."""
        return self._time

    def advance_time(self, advance):
        """Move the simulated time forward by *advance* (no-op if falsy)."""
        if not advance:
            return
        self._time += advance

    def close(self):
        """Verify that a user-supplied time generator has finished.

        One more value is sent to the generator. If that does not make it
        stop, AssertionError is raised.
        """
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError("Time generator is not finished")  # pragma: no cover

    def add_reader(self, fd, callback, *args):
        """Record a reader Handle for *fd*."""
        self.readers[fd] = events.Handle(callback, args, self)

    def remove_reader(self, fd):
        """Drop the reader for *fd*; return True if one was registered.

        Every call is counted in ``remove_reader_count``.
        """
        self.remove_reader_count[fd] += 1
        try:
            del self.readers[fd]
        except KeyError:
            return False
        return True

    def assert_reader(self, fd, callback, *args):
        """Assert *fd* has a reader with exactly this callback and args."""
        assert fd in self.readers, 'fd {} is not registered'.format(fd)
        registered = self.readers[fd]
        assert registered._callback == callback, '{!r} != {!r}'.format(
            registered._callback, callback)
        assert registered._args == args, '{!r} != {!r}'.format(
            registered._args, args)

    def add_writer(self, fd, callback, *args):
        """Record a writer Handle for *fd*."""
        self.writers[fd] = events.Handle(callback, args, self)

    def remove_writer(self, fd):
        """Drop the writer for *fd*; return True if one was registered.

        Every call is counted in ``remove_writer_count``.
        """
        self.remove_writer_count[fd] += 1
        try:
            del self.writers[fd]
        except KeyError:
            return False
        return True

    def assert_writer(self, fd, callback, *args):
        """Assert *fd* has a writer with exactly this callback and args."""
        assert fd in self.writers, 'fd {} is not registered'.format(fd)
        registered = self.writers[fd]
        assert registered._callback == callback, '{!r} != {!r}'.format(
            registered._callback, callback)
        assert registered._args == args, '{!r} != {!r}'.format(
            registered._args, args)

    def reset_counters(self):
        """Reset the per-fd remove_reader/remove_writer call counters."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        # Run a normal iteration. Then report every time recorded by
        # call_at() to the time generator, advancing the clock by each
        # reply.
        super()._run_once()
        for when in self._timers:
            self.advance_time(self._gen.send(when))
        self._timers = []

    def call_at(self, when, callback, *args):
        # Remember the time so _run_once() can pass it to the generator.
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # TestSelector never reports events; nothing to process.
        return

    def _write_to_self(self):
        # Waking the loop is a no-op for this test loop.
        pass
Victor Stinner9af4a242014-02-11 11:34:30 +0100362
Yury Selivanovb057c522014-02-18 12:15:06 -0500363
def MockCallback(**kwargs):
    """Return a Mock restricted to being called.

    The spec only allows ``__call__``, so reading any other attribute
    raises AttributeError. Keyword arguments (e.g. ``return_value``) are
    forwarded to ``unittest.mock.Mock``.
    """
    callable_only = ['__call__']
    return unittest.mock.Mock(spec=callable_only, **kwargs)
Yury Selivanov569efa22014-02-18 18:02:19 -0500366
367
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere a regex
    comparison between strings is needed. The pattern is applied with
    ``re.search`` (so it may match anywhere in the other string) and with
    ``re.S`` (so '.' also matches newlines).

    For instance:
        mock_call.assert_called_with(MockPattern('spam.*ham'))
    """

    def __eq__(self, other):
        # Only strings can be searched. For anything else defer to the
        # default comparison instead of letting re.search raise TypeError.
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # str defines its own __ne__, which compares literally and would
        # contradict the fuzzy __eq__; derive != from __eq__ instead.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # Overriding __eq__ already makes instances unhashable. State it
    # explicitly: a pattern and the strings it matches compare equal but
    # could never share a hash.
    __hash__ = None