blob: 42b9cd75ae87dc5cc90ac6fa68848716adf09dbd [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006import os
7import sys
8import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +02009import time
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070010import unittest
11import unittest.mock
12from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
13try:
14 import ssl
15except ImportError: # pragma: no cover
16 ssl = None
17
18from . import tasks
19from . import base_events
20from . import events
21from . import selectors
22
23
24if sys.platform == 'win32': # pragma: no cover
25 from .windows_utils import socketpair
26else:
27 from socket import socketpair # pragma: no cover
28
29
def dummy_ssl_context():
    """Return a brand-new SSLContext, or None when the ssl module is missing.

    The context is created with ``ssl.PROTOCOL_SSLv23`` and is otherwise
    left unconfigured (no certificates are loaded).
    """
    return None if ssl is None else ssl.SSLContext(ssl.PROTOCOL_SSLv23)
35
36
def run_briefly(loop):
    """Run *loop* until a trivial no-op task has completed.

    The underlying generator is always closed afterwards, even when
    ``run_until_complete`` raises.
    """
    @tasks.coroutine
    def once():
        pass

    coro_obj = once()
    noop_task = tasks.Task(coro_obj, loop=loop)
    try:
        loop.run_until_complete(noop_task)
    finally:
        # Close explicitly so the generator is not left suspended.
        coro_obj.close()
47
48
def run_until(loop, pred, timeout=None):
    """Run *loop* until ``pred()`` returns a true value.

    Args:
        loop: the event loop to drive.
        pred: zero-argument callable, re-evaluated between loop runs.
        timeout: maximum number of seconds to wait, or None to wait
            without limit.

    Returns:
        True once ``pred()`` is true, or False if *timeout* expired first.
    """
    # Sleep in short slices so a predicate that becomes true early is
    # noticed promptly, instead of only after the whole remaining timeout.
    poll_interval = 0.001
    if timeout is not None:
        # monotonic() is not affected by wall-clock adjustments.
        deadline = time.monotonic() + timeout
    while not pred():
        if timeout is None:
            run_briefly(loop)
            continue
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        loop.run_until_complete(
            tasks.sleep(min(remaining, poll_interval), loop=loop))
    return True
61
62
def run_once(loop):
    """Drive *loop* briefly by scheduling a stop and then running it.

    ``loop.stop()`` schedules ``_raise_stop_error()``, and ``run_forever()``
    then runs until that callback fires.  This won't work if the test waits
    for I/O events, because ``_raise_stop_error()`` runs before any I/O
    event callbacks.
    """
    # Order matters: the stop must be scheduled before the loop starts.
    loop.stop()
    loop.run_forever()
71
72
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Serve a trivial WSGI app from a background thread for a ``with`` block.

    Every request is answered with ``200 OK`` and the body
    ``b'Test message'``.  Request logging and server error reporting are
    silenced.

    Args:
        host: interface to bind to.
        port: port to bind to (0 lets the OS pick a free port).
        use_ssl: wrap each accepted connection in TLS, using the
            ``ssl_key.pem`` / ``ssl_cert.pem`` files from the test directory.

    Yields:
        The running WSGI server.  ``httpd.address`` is set to its
        ``server_address``.  On exit the server is shut down, closed, and
        its thread joined.
    """

    class SilentWSGIRequestHandler(WSGIRequestHandler):
        def get_stderr(self):
            return io.StringIO()

        def log_message(self, format, *args):
            pass

    class SilentWSGIServer(WSGIServer):
        def handle_error(self, request, client_address):
            pass

    class SSLWSGIServer(SilentWSGIServer):
        def finish_request(self, request, client_address):
            # The relative location of our test directory (which
            # contains the ssl key and certificate files) differs
            # between the stdlib and stand-alone asyncio.
            # Prefer our own if we can find it.
            here = os.path.join(os.path.dirname(__file__), '..', 'tests')
            if not os.path.isdir(here):
                here = os.path.join(os.path.dirname(os.__file__),
                                    'test', 'test_asyncio')
            keyfile = os.path.join(here, 'ssl_key.pem')
            certfile = os.path.join(here, 'ssl_cert.pem')
            # ssl.wrap_socket() is deprecated and was removed in Python 3.12.
            # An explicit server-side context does the same work on every
            # version; PROTOCOL_TLS_SERVER is preferred where it exists.
            protocol = getattr(ssl, 'PROTOCOL_TLS_SERVER',
                               ssl.PROTOCOL_SSLv23)
            context = ssl.SSLContext(protocol)
            context.load_cert_chain(certfile, keyfile)
            ssock = context.wrap_socket(request, server_side=True)
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            except OSError:
                # maybe socket has been closed by peer
                pass
            finally:
                # Always release the TLS socket, even when the handler fails.
                ssock.close()

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
    httpd = make_server(host, port, app,
                        server_class, SilentWSGIRequestHandler)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(target=httpd.serve_forever)
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
130
131
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose public methods are all Mocks.

    Every name in ``dir(base)`` that is not a ``__dunder__`` is replaced by
    a ``unittest.mock.Mock`` returning None.  The new class is named
    ``TestProtocol`` and has bases ``(base,) + base.__bases__``.
    """
    def _is_magic(attr_name):
        return attr_name.startswith('__') and attr_name.endswith('__')

    mocked_attrs = {
        attr_name: unittest.mock.Mock(return_value=None)
        for attr_name in dir(base)
        if not _is_magic(attr_name)
    }
    protocol_cls = type('TestProtocol', (base,) + base.__bases__,
                        mocked_attrs)
    return protocol_cls()
140
141
class TestSelector(selectors.BaseSelector):
    """Selector stub that records registrations and never reports events."""

    def __init__(self):
        # Registered file objects -> their SelectorKey.
        self.keys = {}

    @property
    def resolution(self):
        # Constant resolution value exposed by this stub.
        return 1e-3

    def register(self, fileobj, events, data=None):
        # The fd field is always 0; no real descriptor is looked up.
        new_key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys.update({fileobj: new_key})
        return new_key

    def unregister(self, fileobj):
        # Raises KeyError for a file object that was never registered.
        removed_key = self.keys.pop(fileobj)
        return removed_key

    def select(self, timeout):
        # Nothing is ever ready.
        return list()

    def get_map(self):
        # Returns the live mapping itself, not a copy.
        return self.keys
164
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700165
class TestLoop(base_events.BaseEventLoop):
    """Loop for unittests.

    It manages its own clock: ``time()`` returns a counter that only moves
    when ``advance_time()`` is called.

    Handles scheduled through ``call_at()`` are remembered.  After each loop
    iteration has run its ready handlers, the time generator passed to
    ``__init__`` is resumed once per remembered deadline.

    The generator should look like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value the ``yield`` expression evaluates to is the absolute time of
    the next scheduled handler.  The value yielded is how far to move the
    loop's clock forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        if gen is None:
            # No generator supplied: use a trivial one-shot generator and
            # skip the "generator finished" check in close().
            def gen():
                yield
            self._check_on_close = False
        else:
            self._check_on_close = True

        self._gen = gen()
        next(self._gen)  # prime the generator up to its first yield
        self._time = 0
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        # A falsy advance (e.g. None from a bare ``yield``) leaves time as is.
        if advance:
            self._time += advance

    def close(self):
        # Let the base class run its own cleanup first, so it happens even
        # if the generator check below raises.
        super().close()
        if self._check_on_close:
            try:
                self._gen.send(0)
            except StopIteration:
                pass
            else:  # pragma: no cover
                raise AssertionError("Time generator is not finished")

    def _assert_handle(self, registry, fd, callback, args):
        """Raise AssertionError unless ``registry[fd]`` wraps callback(*args).

        Explicit raises are used instead of ``assert`` statements so the
        checks still run under ``python -O``.
        """
        if fd not in registry:
            raise AssertionError('fd {} is not registered'.format(fd))
        handle = registry[fd]
        if not handle._callback == callback:
            raise AssertionError('{!r} != {!r}'.format(
                handle._callback, callback))
        if not handle._args == args:
            raise AssertionError('{!r} != {!r}'.format(
                handle._args, args))

    def add_reader(self, fd, callback, *args):
        self.readers[fd] = events.make_handle(callback, args)

    def remove_reader(self, fd):
        """Unregister the reader for *fd*; return True if one was present."""
        self.remove_reader_count[fd] += 1
        try:
            del self.readers[fd]
        except KeyError:
            return False
        return True

    def assert_reader(self, fd, callback, *args):
        self._assert_handle(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        self.writers[fd] = events.make_handle(callback, args)

    def remove_writer(self, fd):
        """Unregister the writer for *fd*; return True if one was present."""
        self.remove_writer_count[fd] += 1
        try:
            del self.writers[fd]
        except KeyError:
            return False
        return True

    def assert_writer(self, fd, callback, *args):
        self._assert_handle(self.writers, fd, callback, args)

    def reset_counters(self):
        # Per-fd counts of remove_reader()/remove_writer() calls.
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed each deadline recorded by call_at() to the time generator
        # and advance the clock by whatever it yields back.
        for when in self._timers:
            advance = self._gen.send(when)
            self.advance_time(advance)
        self._timers = []

    def call_at(self, when, callback, *args):
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass