blob: c278dd1773c25a4a0405abef4975c632a2e25436 [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
6import unittest.mock
7import os
8import sys
9import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +020010import time
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070011import unittest
12import unittest.mock
13from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
14try:
15 import ssl
16except ImportError: # pragma: no cover
17 ssl = None
18
19from . import tasks
20from . import base_events
21from . import events
22from . import selectors
23
24
25if sys.platform == 'win32': # pragma: no cover
26 from .windows_utils import socketpair
27else:
28 from socket import socketpair # pragma: no cover
29
30
def dummy_ssl_context():
    """Return a bare SSL context for tests, or None if ``ssl`` is missing.

    The context is built with ``ssl.PROTOCOL_SSLv23`` and is otherwise left
    unconfigured (no certificates or CA files are loaded).
    """
    return None if ssl is None else ssl.SSLContext(ssl.PROTOCOL_SSLv23)
36
37
def run_briefly(loop):
    """Run *loop* just long enough to complete one no-op task.

    A do-nothing coroutine is wrapped in a ``tasks.Task`` bound to *loop*
    and run to completion.  The coroutine's generator is closed afterwards,
    even if running the loop raised.
    """
    @tasks.coroutine
    def noop():
        pass

    coro = noop()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        # Always close the generator so it is not left suspended.
        coro.close()
48
49
def run_until(loop, pred, timeout=None, *, poll_interval=0.001):
    """Run *loop* until ``pred()`` returns a true value.

    :param loop: the event loop to run.
    :param pred: zero-argument callable, re-evaluated between loop runs.
    :param timeout: maximum number of seconds to wait, or None to wait
        indefinitely.
    :param poll_interval: with a timeout, the longest single sleep between
        two evaluations of *pred* (seconds).
    :returns: True once *pred* is satisfied, False if *timeout* elapsed
        first.
    """
    if timeout is not None:
        # monotonic() is immune to wall-clock adjustments, unlike time().
        deadline = time.monotonic() + timeout
    while not pred():
        if timeout is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            # Sleep in short slices so *pred* is re-checked promptly,
            # instead of only after the whole remaining timeout has passed.
            nap = min(remaining, poll_interval)
            loop.run_until_complete(tasks.sleep(nap, loop=loop))
        else:
            run_briefly(loop)
    return True
62
63
def run_once(loop):
    """Run *loop* briefly by stopping it first and then running it.

    ``loop.stop()`` schedules ``_raise_stop_error()``, and ``run_forever()``
    then runs until that callback fires.  This won't work if the test waits
    for I/O events, because ``_raise_stop_error()`` runs before any I/O
    event callbacks.
    """
    # Order matters: the stop request must be queued before the loop starts.
    loop.stop()
    loop.run_forever()
72
73
74@contextlib.contextmanager
75def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
76
77 class SilentWSGIRequestHandler(WSGIRequestHandler):
78 def get_stderr(self):
79 return io.StringIO()
80
81 def log_message(self, format, *args):
82 pass
83
84 class SilentWSGIServer(WSGIServer):
85 def handle_error(self, request, client_address):
86 pass
87
88 class SSLWSGIServer(SilentWSGIServer):
89 def finish_request(self, request, client_address):
90 # The relative location of our test directory (which
91 # contains the sample key and certificate files) differs
92 # between the stdlib and stand-alone Tulip/asyncio.
93 # Prefer our own if we can find it.
94 here = os.path.join(os.path.dirname(__file__), '..', 'tests')
95 if not os.path.isdir(here):
96 here = os.path.join(os.path.dirname(os.__file__),
97 'test', 'test_asyncio')
98 keyfile = os.path.join(here, 'sample.key')
99 certfile = os.path.join(here, 'sample.crt')
100 ssock = ssl.wrap_socket(request,
101 keyfile=keyfile,
102 certfile=certfile,
103 server_side=True)
104 try:
105 self.RequestHandlerClass(ssock, client_address, self)
106 ssock.close()
107 except OSError:
108 # maybe socket has been closed by peer
109 pass
110
111 def app(environ, start_response):
112 status = '200 OK'
113 headers = [('Content-type', 'text/plain')]
114 start_response(status, headers)
115 return [b'Test message']
116
117 # Run the test WSGI server in a separate thread in order not to
118 # interfere with event handling in the main thread
119 server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
120 httpd = make_server(host, port, app,
121 server_class, SilentWSGIRequestHandler)
122 httpd.address = httpd.server_address
123 server_thread = threading.Thread(target=httpd.serve_forever)
124 server_thread.start()
125 try:
126 yield httpd
127 finally:
128 httpd.shutdown()
Antoine Pitroua7a150c2013-10-20 23:26:23 +0200129 httpd.server_close()
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700130 server_thread.join()
131
132
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose public attributes are mocks.

    Every name reported by ``dir(base)`` that is not a dunder name is
    replaced on a new ``TestProtocol`` class by a
    ``unittest.mock.Mock(return_value=None)``.  Dunder names are left
    untouched.  The new class lists *base* followed by *base*'s own bases.
    """
    def is_magic(attr):
        return attr.startswith('__') and attr.endswith('__')

    namespace = {attr: unittest.mock.Mock(return_value=None)
                 for attr in dir(base) if not is_magic(attr)}
    bases = (base,) + base.__bases__
    protocol_cls = type('TestProtocol', bases, namespace)
    return protocol_cls()
141
142
143class TestSelector(selectors.BaseSelector):
144
145 def select(self, timeout):
146 return []
147
148
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that keeps its own virtual clock.

    ``time()`` returns an internal counter that starts at 0 and only
    changes through ``advance_time()``; the real clock is never consulted.

    The clock is driven by the generator function passed to ``__init__``.
    The generator is primed with ``next()`` on construction.  After every
    loop iteration, each deadline recorded by ``call_at()`` and not yet
    reported is sent into the generator, and the value it yields back is
    added to the loop's time.  When a generator was supplied, ``close()``
    sends it a final 0 and requires it to be exhausted.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of the next
    scheduled handler; the value yielded is how far to move the loop's
    time forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a caller-supplied generator is checked for exhaustion on close.
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        next(self._gen)  # run the generator up to its first yield
        self._time = 0
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        """Return the current virtual time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance*; a falsy value is a no-op."""
        if not advance:
            return
        self._time += advance

    def close(self):
        """Verify that a caller-supplied time generator has finished.

        :raises AssertionError: if the generator yields again after being
            sent a final 0.
        """
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError("Time generator is not finished")  # pragma: no cover

    def add_reader(self, fd, callback, *args):
        """Record a reader handle for *fd* (no real polling happens)."""
        handle = events.make_handle(callback, args)
        self.readers[fd] = handle

    def remove_reader(self, fd):
        """Forget the reader for *fd*, counting every call.

        :returns: True if a reader was registered for *fd*, else False.
        """
        self.remove_reader_count[fd] += 1
        if fd not in self.readers:
            return False
        del self.readers[fd]
        return True

    @staticmethod
    def _assert_registered(registry, fd, callback, args):
        # Shared check behind assert_reader() and assert_writer().
        assert fd in registry, 'fd {} is not registered'.format(fd)
        handle = registry[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def assert_reader(self, fd, callback, *args):
        """Assert *fd* has a reader registered with *callback* and *args*."""
        self._assert_registered(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        """Record a writer handle for *fd* (no real polling happens)."""
        handle = events.make_handle(callback, args)
        self.writers[fd] = handle

    def remove_writer(self, fd):
        """Forget the writer for *fd*, counting every call.

        :returns: True if a writer was registered for *fd*, else False.
        """
        self.remove_writer_count[fd] += 1
        if fd not in self.writers:
            return False
        del self.writers[fd]
        return True

    def assert_writer(self, fd, callback, *args):
        """Assert *fd* has a writer registered with *callback* and *args*."""
        self._assert_registered(self.writers, fd, callback, args)

    def reset_counters(self):
        """Zero the per-fd remove_reader()/remove_writer() call counters."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report every recorded deadline to the time generator and let it
        # advance the clock, then start a fresh batch.
        for when in self._timers:
            self.advance_time(self._gen.send(when))
        self._timers = []

    def call_at(self, when, callback, *args):
        """Record *when* for the time generator, then schedule normally."""
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        """No real I/O events exist in the test loop; nothing to process."""

    def _write_to_self(self):
        """No self-pipe wakeup is needed in the test loop."""
262 pass