blob: de2916bfc743b205d801ecf9b892b31698823d2f [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006import os
Yury Selivanovb057c522014-02-18 12:15:06 -05007import socket
8import socketserver
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07009import sys
Yury Selivanovb057c522014-02-18 12:15:06 -050010import tempfile
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070011import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020012import time
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070013import unittest
14import unittest.mock
Yury Selivanovb057c522014-02-18 12:15:06 -050015
16from http.server import HTTPServer
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070017from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
Yury Selivanovb057c522014-02-18 12:15:06 -050018
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070019try:
20 import ssl
21except ImportError: # pragma: no cover
22 ssl = None
23
24from . import tasks
25from . import base_events
26from . import events
27from . import selectors
28
29
30if sys.platform == 'win32': # pragma: no cover
31 from .windows_utils import socketpair
32else:
33 from socket import socketpair # pragma: no cover
34
35
def dummy_ssl_context():
    """Return a plain SSLContext, or None when the ssl module is missing."""
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
41
42
def run_briefly(loop):
    """Run *loop* just long enough to complete a single no-op task."""
    @tasks.coroutine
    def once():
        pass

    coro = once()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        # Always close the underlying generator object explicitly.
        coro.close()
53
54
def run_until(loop, pred, timeout=None):
    """Run *loop* until ``pred()`` returns a true value.

    Without a *timeout*, the loop is stepped with ``run_briefly()`` until the
    predicate holds. With a *timeout* (in seconds), the loop is run in short
    sleeps. Between sleeps the predicate is re-checked, so the function
    returns as soon as ``pred()`` becomes true. The previous behaviour was to
    sleep out the whole remaining timeout in one go.

    Returns True once ``pred()`` is true, or False if *timeout* expires first.
    """
    # Upper bound on each sleep while waiting with a timeout, so the
    # predicate is re-checked frequently.
    poll_interval = 0.001
    if timeout is not None:
        # monotonic() is unaffected by wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
    while not pred():
        if timeout is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            loop.run_until_complete(
                tasks.sleep(min(remaining, poll_interval), loop=loop))
        else:
            run_briefly(loop)
    return True
67
68
def run_once(loop):
    """Run *loop* until the stop callback scheduled by ``loop.stop()`` fires.

    ``loop.stop()`` schedules ``_raise_stop_error()``, and ``run_forever()``
    runs until that callback executes. This won't work if the test waits for
    I/O events, because ``_raise_stop_error()`` runs before any I/O event
    callbacks.
    """
    stop, run = loop.stop, loop.run_forever
    stop()
    run()
77
78
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGIRequestHandler that discards its error stream and request log."""

    def log_message(self, format, *args):
        # Suppress per-request logging entirely.
        return None

    def get_stderr(self):
        # Hand the handler a fresh, throwaway in-memory error stream.
        return io.StringIO()
86
87
class SilentWSGIServer(WSGIServer):
    """WSGIServer that ignores errors raised while handling a request."""

    def handle_error(self, request, client_address):
        # Deliberately a no-op so failing requests do not clutter test output.
        return None
92
93
class SSLWSGIServerMixin:
    """Server mixin that wraps each accepted connection in server-side TLS.

    The key and certificate are read from ``ssl_key.pem`` and
    ``ssl_cert.pem`` in the asyncio test directory.
    """

    def finish_request(self, request, client_address):
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        # Build the context explicitly instead of calling the module-level
        # ssl.wrap_socket() helper, which is deprecated and was removed in
        # Python 3.12. This mirrors that helper's defaults.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.load_cert_chain(certfile, keyfile)
        ssock = context.wrap_socket(request, server_side=True)
        try:
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            finally:
                # Release the TLS socket even when the handler fails;
                # previously it was only closed on success.
                ssock.close()
        except OSError:
            # maybe socket has been closed by peer
            pass
117
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700118
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """Silent TCP WSGI server whose connections are wrapped in TLS."""
121
122
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Generator that serves a fixed WSGI app from a helper thread.

    Uses *server_ssl_cls* when *use_ssl* is true, otherwise *server_cls*,
    binds it to *address*, and yields the server. The server has an extra
    ``address`` attribute copied from ``server_address``. On exit, the server
    is shut down and closed, and the serving thread is joined.
    """
    def app(environ, start_response):
        # Minimal WSGI application: always 200 with a fixed plain-text body.
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    if use_ssl:
        server_class = server_ssl_cls
    else:
        server_class = server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address

    # Serve from a separate thread so the server does not interfere with
    # event handling in the main thread.
    serving = threading.Thread(target=httpd.serve_forever)
    serving.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        serving.join()
145
146
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
        """HTTPServer bound to a UNIX socket path instead of a TCP port."""

        def server_bind(self):
            # Call UnixStreamServer's bind directly rather than via super();
            # presumably because HTTPServer.server_bind assumes a TCP
            # (host, port) address.
            socketserver.UnixStreamServer.server_bind(self)
            # Fake TCP-style attributes for code that expects them.
            self.server_name, self.server_port = '127.0.0.1', 80

    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
        """WSGI server listening on a UNIX socket."""

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            self.setup_environ()

        def get_request(self):
            conn, _peer_path = super().get_request()
            # The stdlib code expects get_request() to return a socket and
            # a (host, port) tuple. For UNIX sockets, the second value is a
            # path instead, so return fake data that is good enough for the
            # tests.
            return conn, ('127.0.0.1', '')

    class SilentUnixWSGIServer(UnixWSGIServer):
        """UNIX WSGI server that ignores request-handling errors."""

        def handle_error(self, request, client_address):
            return None

    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """Silent UNIX WSGI server whose connections are wrapped in TLS."""

    def gen_unix_socket_path():
        """Return a unique temporary file name to use as a socket path.

        The temporary file itself is removed when the ``with`` block exits;
        only its name is returned.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            name = tmp.name
        return name

    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary socket path and try to unlink it afterwards."""
        sock_path = gen_unix_socket_path()
        try:
            yield sock_path
        finally:
            try:
                os.unlink(sock_path)
            except OSError:
                # Best effort: the path may never have been created.
                pass

    @contextlib.contextmanager
    def run_test_unix_server(*, use_ssl=False):
        """Run a silent test WSGI server on a temporary UNIX socket path."""
        with unix_socket_path() as sock_path:
            yield from _run_test_server(
                address=sock_path,
                use_ssl=use_ssl,
                server_cls=SilentUnixWSGIServer,
                server_ssl_cls=UnixSSLWSGIServer,
            )
207
208
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run a silent test WSGI server on (*host*, *port*) and yield it.

    Delegates to ``_run_test_server`` using the TCP server classes.
    """
    yield from _run_test_server(
        address=(host, port),
        use_ssl=use_ssl,
        server_cls=SilentWSGIServer,
        server_ssl_cls=SSLWSGIServer,
    )
214
215
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose attributes are mock callbacks.

    Every attribute name reported by ``dir(base)``, except dunder names, is
    overridden with a ``MockCallback`` that returns None.
    """
    def is_dunder(attr):
        return attr.startswith('__') and attr.endswith('__')

    mocked = {attr: MockCallback(return_value=None)
              for attr in dir(base) if not is_dunder(attr)}
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, mocked)()
224
225
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations and never reports I/O."""

    def __init__(self):
        # Registered file object -> its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # The fd field is always 0; no real descriptor is looked up.
        key = self.keys[fileobj] = selectors.SelectorKey(
            fileobj, 0, events, data)
        return key

    def unregister(self, fileobj):
        # Raises KeyError for unknown file objects (dict.pop semantics).
        return self.keys.pop(fileobj)

    def select(self, timeout):
        # Nothing is ever ready.
        return []

    def get_map(self):
        return self.keys
244
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700245
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests with a virtual clock.

    The loop manages its own time: ``time()`` returns an internal counter.
    Every deadline passed to ``call_at()`` is remembered. After each loop
    iteration has run its ready handles, the remembered deadlines are sent,
    one by one, into the time generator given to ``__init__``. Each value
    the generator yields back is used to move the clock forward.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of the next
    scheduled handle; the value yielded is how far to advance the loop's
    time.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is not None:
            # A user-supplied generator must be exhausted by close().
            self._check_on_close = True
        else:
            def gen():
                yield
            self._check_on_close = False

        self._gen = gen()
        next(self._gen)  # advance the generator to its first yield
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance*; falsy values are ignored."""
        if not advance:
            return
        self._time += advance

    def close(self):
        """Check that a user-supplied time generator has been exhausted.

        Raises AssertionError if the generator still yields another value.
        """
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError("Time generator is not finished")  # pragma: no cover

    def add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args)

    def remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        return self._discard_registered_fd(self.readers, fd)

    def assert_reader(self, fd, callback, *args):
        self._assert_registered_fd(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args)

    def remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        return self._discard_registered_fd(self.writers, fd)

    def assert_writer(self, fd, callback, *args):
        self._assert_registered_fd(self.writers, fd, callback, args)

    @staticmethod
    def _discard_registered_fd(registry, fd):
        """Remove *fd* from *registry*; return True if it was present."""
        try:
            del registry[fd]
        except KeyError:
            return False
        return True

    @staticmethod
    def _assert_registered_fd(registry, fd, callback, args):
        """Assert that *fd* is registered with exactly *callback* and *args*."""
        assert fd in registry, 'fd {} is not registered'.format(fd)
        handle = registry[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def reset_counters(self):
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed each deadline scheduled since the last iteration to the time
        # generator and advance the clock by what it yields back.
        for deadline in self._timers:
            self.advance_time(self._gen.send(deadline))
        self._timers = []

    def call_at(self, when, callback, *args):
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # No real I/O is ever reported by TestSelector, so nothing to do.
        return None

    def _write_to_self(self):
        # No self-pipe wakeup is needed for the virtual-clock loop.
        pass
Victor Stinner9af4a242014-02-11 11:34:30 +0100361
Yury Selivanovb057c522014-02-18 12:15:06 -0500362
def MockCallback(**kwargs):
    """Create a Mock whose spec allows only ``__call__``.

    Keyword arguments (e.g. ``return_value``) are forwarded to
    ``unittest.mock.Mock``.
    """
    callable_spec = ['__call__']
    return unittest.mock.Mock(spec=callable_spec, **kwargs)