blob: deab7c33122f066fd52d5a756c1894c9fef6d6d2 [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006import os
7import sys
8import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +02009import time
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070010import unittest
11import unittest.mock
12from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
13try:
14 import ssl
15except ImportError: # pragma: no cover
16 ssl = None
17
18from . import tasks
19from . import base_events
20from . import events
21from . import selectors
22
23
24if sys.platform == 'win32': # pragma: no cover
25 from .windows_utils import socketpair
26else:
27 from socket import socketpair # pragma: no cover
28
29
def dummy_ssl_context():
    """Return a bare SSL context for tests, or None without ssl support.

    The context is built with ``ssl.PROTOCOL_SSLv23``; no certificates,
    ciphers or verification options are configured on it.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
35
36
def run_briefly(loop):
    """Drive *loop* just long enough to complete one trivial task.

    A no-op coroutine is wrapped in a Task and the loop is run until that
    task finishes (presumably so callbacks already pending on the loop get
    a chance to run).  The coroutine object is closed afterwards, whether
    or not running the loop raised.
    """
    @tasks.coroutine
    def _noop():
        pass

    coro = _noop()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
47
48
def run_until(loop, pred, timeout=None, *, interval=0.001):
    """Run *loop* until ``pred()`` returns true or *timeout* expires.

    Args:
        loop: the event loop to drive.
        pred: zero-argument callable, re-checked between loop runs.
        timeout: maximum number of seconds to wait, or None to wait
            without limit.
        interval: when *timeout* is given, the longest single sleep
            between predicate checks (seconds).

    Returns:
        True if ``pred()`` became true, False if *timeout* expired first.
    """
    if timeout is not None:
        # Monotonic clock: unaffected by wall-clock adjustments mid-test.
        deadline = time.monotonic() + timeout
    while not pred():
        if timeout is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            # Sleep only briefly and then re-check the predicate, rather
            # than sleeping out the whole remaining timeout first.
            loop.run_until_complete(
                tasks.sleep(min(remaining, interval), loop=loop))
        else:
            run_briefly(loop)
    return True
61
62
def run_once(loop):
    """Run a single iteration of *loop*.

    ``loop.stop()`` schedules ``_raise_stop_error()``, and ``run_forever()``
    then runs until that callback fires.  This won't work if the test is
    waiting for I/O events, because ``_raise_stop_error()`` runs before
    any I/O event callbacks do.
    """
    loop.stop()
    loop.run_forever()
71
72
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run a throwaway WSGI server in a background thread.

    Every request is answered with ``200 OK`` and the plain-text body
    ``b'Test message'``.  Request logging, handler stderr output and
    server error reporting are all silenced.

    Args:
        host: interface to bind to.
        port: port to bind to; 0 lets the OS choose a free port.
        use_ssl: serve over TLS using ``ssl_key.pem`` / ``ssl_cert.pem``
            from the asyncio test directory.

    Yields:
        The server object; its ``address`` attribute is set to the bound
        ``server_address``.

    On exit the server is shut down, its socket closed and the serving
    thread joined.
    """

    class SilentWSGIRequestHandler(WSGIRequestHandler):
        def get_stderr(self):
            # Discard anything the handler writes to its error stream.
            return io.StringIO()

        def log_message(self, format, *args):
            pass

    class SilentWSGIServer(WSGIServer):
        def handle_error(self, request, client_address):
            pass

    class SSLWSGIServer(SilentWSGIServer):
        def finish_request(self, request, client_address):
            # The relative location of our test directory (which
            # contains the ssl key and certificate files) differs
            # between the stdlib and stand-alone asyncio.
            # Prefer our own if we can find it.
            here = os.path.join(os.path.dirname(__file__), '..', 'tests')
            if not os.path.isdir(here):
                here = os.path.join(os.path.dirname(os.__file__),
                                    'test', 'test_asyncio')
            keyfile = os.path.join(here, 'ssl_key.pem')
            certfile = os.path.join(here, 'ssl_cert.pem')
            # ssl.wrap_socket() is deprecated and was removed in Python
            # 3.12; an explicit server-side context with the same protocol
            # and the same key/cert pair is the equivalent.
            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            context.load_cert_chain(certfile, keyfile)
            ssock = context.wrap_socket(request, server_side=True)
            try:
                try:
                    self.RequestHandlerClass(ssock, client_address, self)
                finally:
                    # Close even when the handler fails so the TLS socket
                    # is not leaked.
                    ssock.close()
            except OSError:
                # maybe socket has been closed by peer
                pass

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
    httpd = make_server(host, port, app,
                        server_class, SilentWSGIRequestHandler)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(target=httpd.serve_forever)
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
130
131
def make_test_protocol(base):
    """Instantiate a protocol whose non-magic attributes are all mocks.

    Every name returned by ``dir(base)``, except ``__dunder__`` names, is
    bound to ``MockCallback(return_value=None)``.  The generated class is
    named ``TestProtocol`` and derives from *base* followed by base's own
    direct base classes.
    """
    namespace = {
        attr: MockCallback(return_value=None)
        for attr in dir(base)
        if not (attr.startswith('__') and attr.endswith('__'))
    }
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, namespace)()
140
141
class TestSelector(selectors.BaseSelector):
    """In-memory selector stub that records registrations only.

    ``select()`` never reports any ready objects.
    """

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Record *fileobj* and return its key (the fd field is always 0)."""
        registered = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = registered
        return registered

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key; KeyError if unknown."""
        removed = self.keys.pop(fileobj)
        return removed

    def select(self, timeout):
        """Return no events, whatever the timeout."""
        return []

    def get_map(self):
        """Return the live fileobj -> SelectorKey mapping."""
        return self.keys
160
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700161
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that manages its own (virtual) time.

    Time only moves when the loop is told to move it.  Each time recorded
    by ``call_at()`` is handed to the generator passed to ``__init__``.
    This happens on the next loop iteration, after the ready handlers
    have run.

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value a ``yield`` expression evaluates to (what the loop sends
    in) is the absolute time of the next scheduled handler.  The value the
    generator yields back is the amount by which to move the loop's time
    forward.

    When a generator is supplied, ``close()`` sends it a final 0 and
    expects it to finish at that point; otherwise AssertionError is raised.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # NOTE(review): the default generator stops after the priming
            # next(); any call_at() on such a loop will make _run_once()
            # raise StopIteration from send().
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        # Prime the generator so it can receive send() calls.
        next(self._gen)
        self._time = 0
        # Presumably used by BaseEventLoop when comparing scheduled times;
        # kept tiny so virtual time behaves (almost) exactly.
        self._clock_resolution = 1e-9
        # Times passed to call_at() since the last _run_once().
        self._timers = []
        self._selector = TestSelector()

        # fd -> Handle for registered reader / writer callbacks.
        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        """Return the loop's virtual time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop and check that the time generator has finished.

        Raises:
            AssertionError: if a generator was supplied and it does not
                stop after being sent a final 0.
        """
        # Let BaseEventLoop.close() release whatever the base loop owns
        # before the generator check.
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def add_reader(self, fd, callback, *args):
        """Register *callback(*args)* as the reader for *fd*."""
        self.readers[fd] = events.Handle(callback, args)

    def remove_reader(self, fd):
        """Unregister *fd*'s reader; return True if one was registered.

        Every call is counted in ``remove_reader_count[fd]``.
        """
        self.remove_reader_count[fd] += 1
        if fd in self.readers:
            del self.readers[fd]
            return True
        else:
            return False

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader with this callback and args."""
        assert fd in self.readers, 'fd {} is not registered'.format(fd)
        handle = self.readers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def add_writer(self, fd, callback, *args):
        """Register *callback(*args)* as the writer for *fd*."""
        self.writers[fd] = events.Handle(callback, args)

    def remove_writer(self, fd):
        """Unregister *fd*'s writer; return True if one was registered.

        Every call is counted in ``remove_writer_count[fd]``.
        """
        self.remove_writer_count[fd] += 1
        if fd in self.writers:
            del self.writers[fd]
            return True
        else:
            return False

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer with this callback and args."""
        assert fd in self.writers, 'fd {} is not registered'.format(fd)
        handle = self.writers[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def reset_counters(self):
        """Zero the per-fd remove_reader / remove_writer call counters."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        """Run one base-loop iteration, then advance virtual time.

        Each time recorded by call_at() since the previous iteration is
        sent to the generator.  The loop's clock then moves forward by the
        value the generator yields back.
        """
        super()._run_once()
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        """Record *when* for the time generator, then schedule normally."""
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # Events are ignored; TestSelector.select() never reports any.
        return

    def _write_to_self(self):
        # Intentionally a no-op for the test loop.
        pass
Victor Stinner9af4a242014-02-11 11:34:30 +0100277
def MockCallback(**kwargs):
    """Return a Mock whose spec permits only being called.

    Because the spec lists just ``__call__``, reading any other attribute
    that is not part of Mock itself raises AttributeError.  Keyword
    arguments (for example ``return_value``) are forwarded to Mock.
    """
    callable_only = ['__call__']
    return unittest.mock.Mock(spec=callable_only, **kwargs)