blob: d7d844249c4fd0ee22c3de9da15279da2e52f889 [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
6import unittest.mock
7import os
8import sys
9import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020010import time
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070011import unittest
12import unittest.mock
13from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
14try:
15 import ssl
16except ImportError: # pragma: no cover
17 ssl = None
18
19from . import tasks
20from . import base_events
21from . import events
22from . import selectors
23
24
25if sys.platform == 'win32': # pragma: no cover
26 from .windows_utils import socketpair
27else:
28 from socket import socketpair # pragma: no cover
29
30
def dummy_ssl_context():
    """Build a default ``ssl.SSLContext`` for tests.

    Returns None instead when the ``ssl`` module could not be imported.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
36
37
def run_briefly(loop):
    """Drive *loop* until a single no-op task has completed.

    The coroutine object is closed afterwards even if
    ``run_until_complete`` raised.
    """
    @tasks.coroutine
    def once():
        pass

    coro = once()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
48
49
def run_until(loop, pred, timeout=None):
    """Run *loop* until ``pred()`` returns a true value.

    Without a *timeout*, the loop is stepped with ``run_briefly`` until
    *pred* holds; this never returns if *pred* stays false.

    With a *timeout* (in seconds), each pass sleeps on the loop for the
    whole remaining time budget and then re-checks *pred*. A predicate
    that becomes true during a sleep is therefore only noticed once that
    sleep ends.

    Returns True once *pred* is satisfied, or False if the deadline passes
    first.
    """
    if timeout is not None:
        # Use the monotonic clock so a wall-clock adjustment (NTP, manual
        # change) cannot stretch or shrink the deadline.
        deadline = time.monotonic() + timeout
    while not pred():
        if timeout is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            loop.run_until_complete(tasks.sleep(remaining, loop=loop))
        else:
            run_briefly(loop)
    return True
62
63
def run_once(loop):
    """Run a single pass of *loop*.

    ``loop.stop()`` schedules ``_raise_stop_error()``, and
    ``run_forever()`` runs until that callback fires. This does not work
    when the test is waiting for I/O events, because
    ``_raise_stop_error()`` runs before any I/O event callbacks.
    """
    loop.stop()
    loop.run_forever()
72
73
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Serve a trivial WSGI app from a background thread during a ``with``.

    Every request gets ``200 OK`` with a ``text/plain`` body of
    ``b'Test message'``. Request logging, the WSGI error stream and server
    error reporting are all silenced.

    Keyword arguments:
    host -- interface to bind (default ``'127.0.0.1'``).
    port -- port to bind (default 0, passed straight to ``make_server``).
    use_ssl -- wrap each accepted connection in TLS, using the
        ``sample.key`` / ``sample.crt`` files from the tests directory.

    Yields the server object. Its ``address`` attribute is set to
    ``server_address``. On exit the server is shut down, its socket is
    closed and the serving thread is joined.
    """

    class SilentWSGIRequestHandler(WSGIRequestHandler):
        def get_stderr(self):
            # Hand the WSGI app a throwaway error stream.
            return io.StringIO()

        def log_message(self, format, *args):
            pass

    class SilentWSGIServer(WSGIServer):
        def handle_error(self, request, client_address):
            pass

    class SSLWSGIServer(SilentWSGIServer):
        def finish_request(self, request, client_address):
            # The relative location of our test directory (which
            # contains the sample key and certificate files) differs
            # between the stdlib and stand-alone Tulip/asyncio.
            # Prefer our own if we can find it.
            here = os.path.join(os.path.dirname(__file__), '..', 'tests')
            if not os.path.isdir(here):
                here = os.path.join(os.path.dirname(os.__file__),
                                    'test', 'test_asyncio')
            keyfile = os.path.join(here, 'sample.key')
            certfile = os.path.join(here, 'sample.crt')
            # Equivalent of the removed ssl.wrap_socket(): the same protocol
            # and a server-side handshake, with no client cert verification.
            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            context.load_cert_chain(certfile, keyfile)
            ssock = context.wrap_socket(request, server_side=True)
            try:
                # ``with`` closes the TLS socket even when the handler
                # fails, so an aborted connection does not leak it.
                with ssock:
                    self.RequestHandlerClass(ssock, client_address, self)
            except OSError:
                # maybe socket has been closed by peer
                pass

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
    httpd = make_server(host, port, app,
                        server_class, SilentWSGIRequestHandler)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(target=httpd.serve_forever)
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
131
132
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose non-dunder attributes are mocks.

    Every name in ``dir(base)`` except ``__magic__`` names is replaced by a
    ``unittest.mock.Mock`` returning None. The new class is named
    ``'TestProtocol'`` and inherits from *base* followed by *base*'s own
    bases.
    """
    mocked = {
        attr: unittest.mock.Mock(return_value=None)
        for attr in dir(base)
        if not (attr.startswith('__') and attr.endswith('__'))
    }
    bases = (base,) + base.__bases__
    protocol_cls = type('TestProtocol', bases, mocked)
    return protocol_cls()
141
142
class TestSelector(selectors.BaseSelector):
    """In-memory selector stub: remembers registrations, reports no events."""

    def __init__(self):
        # fileobj -> SelectorKey
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Store a key for *fileobj* (with fd 0) and return it."""
        self.keys[fileobj] = selectors.SelectorKey(fileobj, 0, events, data)
        return self.keys[fileobj]

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key; KeyError if unknown."""
        removed = self.keys[fileobj]
        del self.keys[fileobj]
        return removed

    def select(self, timeout):
        """Always return an empty event list, whatever *timeout* is."""
        return list()

    def get_map(self):
        """Return the live fileobj -> key mapping."""
        return self.keys
161
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700162
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that manages its own clock.

    ``time()`` returns an internal counter rather than the real clock. Every
    deadline passed to ``call_at()`` is remembered. After each call to the
    base class's ``_run_once()``, each remembered deadline is sent into the
    generator given to ``__init__``. The value the generator yields back is
    added to the loop's time.

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of the next
    scheduled handler. The value yielded is how far to move the loop's time
    forward.

    Reader/writer callbacks are stored in the ``readers`` / ``writers``
    dicts instead of a real selector, and ``assert_reader`` /
    ``assert_writer`` check them.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a caller-supplied generator must be exhausted at close().
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        next(self._gen)  # advance the generator to its first ``yield``
        self._time = 0
        self._timers = []
        self._selector = TestSelector()

        # fd -> handle registered via add_reader() / add_writer()
        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance* (falsy values are ignored)."""
        if not advance:
            return
        self._time += advance

    def close(self):
        """Check that a caller-supplied time generator has run to completion.

        Raises AssertionError if the generator still yields after being
        sent 0.
        """
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError("Time generator is not finished")  # pragma: no cover

    def add_reader(self, fd, callback, *args):
        self.readers[fd] = events.make_handle(callback, args)

    def remove_reader(self, fd):
        """Unregister *fd*'s reader; return whether one was registered.

        Every call is counted in ``remove_reader_count``, even a miss.
        """
        self.remove_reader_count[fd] += 1
        try:
            del self.readers[fd]
        except KeyError:
            return False
        return True

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader with exactly *callback* and *args*."""
        assert fd in self.readers, 'fd {} is not registered'.format(fd)
        registered = self.readers[fd]
        assert registered._callback == callback, '{!r} != {!r}'.format(
            registered._callback, callback)
        assert registered._args == args, '{!r} != {!r}'.format(
            registered._args, args)

    def add_writer(self, fd, callback, *args):
        self.writers[fd] = events.make_handle(callback, args)

    def remove_writer(self, fd):
        """Unregister *fd*'s writer; return whether one was registered.

        Every call is counted in ``remove_writer_count``, even a miss.
        """
        self.remove_writer_count[fd] += 1
        try:
            del self.writers[fd]
        except KeyError:
            return False
        return True

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer with exactly *callback* and *args*."""
        assert fd in self.writers, 'fd {} is not registered'.format(fd)
        registered = self.writers[fd]
        assert registered._callback == callback, '{!r} != {!r}'.format(
            registered._callback, callback)
        assert registered._args == args, '{!r} != {!r}'.format(
            registered._args, args)

    def reset_counters(self):
        """Start fresh per-fd counts of remove_reader/remove_writer calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report every deadline collected so far to the time generator and
        # advance the clock by whatever it answers.
        for when in self._timers:
            self.advance_time(self._gen.send(when))
        self._timers = []

    def call_at(self, when, callback, *args):
        # Remember the deadline so _run_once() can hand it to the generator.
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # Intentionally does nothing with the event list.
        return None

    def _write_to_self(self):
        # Intentionally a no-op in the test loop.
        pass
276 pass