blob: fed28d7d64a3464feea0c645b23e6c75aba77759 [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006import os
7import sys
8import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +02009import time
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070010import unittest
11import unittest.mock
12from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
13try:
14 import ssl
15except ImportError: # pragma: no cover
16 ssl = None
17
18from . import tasks
19from . import base_events
20from . import events
21from . import selectors
22
23
24if sys.platform == 'win32': # pragma: no cover
25 from .windows_utils import socketpair
26else:
27 from socket import socketpair # pragma: no cover
28
29
def dummy_ssl_context():
    """Return an SSLContext for tests, or None when ssl is unavailable.

    The context is created with ssl.PROTOCOL_SSLv23 and is otherwise left
    unconfigured.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
35
36
def run_briefly(loop):
    """Run *loop* until a do-nothing coroutine wrapped in a Task completes."""
    @tasks.coroutine
    def noop():
        pass

    coro = noop()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        # Close the generator explicitly, even if running the loop failed.
        coro.close()
47
48
def run_until(loop, pred, timeout=None):
    """Run *loop* until pred() is true or *timeout* seconds have passed.

    Returns True as soon as pred() is found to be true, or False if the
    deadline (measured with time.time()) expires first.  With timeout=None
    there is no deadline: the loop is run briefly between checks until
    pred() becomes true.
    """
    # Upper bound on a single sleep, so that pred() is re-checked promptly
    # instead of only after the whole remaining timeout has elapsed.
    poll_interval = 0.001

    if timeout is None:
        while not pred():
            run_briefly(loop)
        return True

    deadline = time.time() + timeout
    while not pred():
        remaining = deadline - time.time()
        if remaining <= 0:
            return False
        # Never sleep past the deadline.
        loop.run_until_complete(
            tasks.sleep(min(remaining, poll_interval), loop=loop))
    return True
61
62
def run_once(loop):
    """Stop *loop* and then run it, so that run_forever() returns promptly.

    loop.stop() schedules _raise_stop_error(), and run_forever() then runs
    until that callback is reached.

    This won't work if the test waits for I/O events, because
    _raise_stop_error() runs before any of the I/O event callbacks.
    """
    loop.stop()
    loop.run_forever()
71
72
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running a small WSGI test server in another thread.

    All parameters are keyword-only:

    host, port -- address passed to make_server().
    use_ssl -- when true, each accepted connection is wrapped with SSL
        using the ssl_key.pem / ssl_cert.pem files from the test
        directory.

    Yields the server object; its listening address is available as
    ``httpd.address`` (a copy of ``httpd.server_address``).  Every request
    is answered with ``200 OK`` and the body ``b'Test message'``.  On exit
    the server is shut down and closed, and its thread is joined.
    """

    class SilentWSGIRequestHandler(WSGIRequestHandler):
        """Request handler that discards its stderr output and log lines."""

        def get_stderr(self):
            return io.StringIO()

        def log_message(self, format, *args):
            pass

    class SilentWSGIServer(WSGIServer):
        """Server that ignores errors raised while handling a request."""

        def handle_error(self, request, client_address):
            pass

    class SSLWSGIServer(SilentWSGIServer):
        """Silent server that wraps each accepted connection with SSL."""

        def finish_request(self, request, client_address):
            # The relative location of our test directory (which
            # contains the ssl key and certificate files) differs
            # between the stdlib and stand-alone asyncio.
            # Prefer our own if we can find it.
            here = os.path.join(os.path.dirname(__file__), '..', 'tests')
            if not os.path.isdir(here):
                here = os.path.join(os.path.dirname(os.__file__),
                                    'test', 'test_asyncio')
            keyfile = os.path.join(here, 'ssl_key.pem')
            certfile = os.path.join(here, 'ssl_cert.pem')
            # NOTE(review): ssl.wrap_socket() is deprecated in newer Python
            # versions; confirm the supported versions before modernizing.
            ssock = ssl.wrap_socket(request,
                                    keyfile=keyfile,
                                    certfile=certfile,
                                    server_side=True)
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            except OSError:
                # maybe socket has been closed by peer
                pass
            finally:
                # Release the SSL socket even when the handler failed;
                # closing stays best-effort, as the peer may already be
                # gone.
                try:
                    ssock.close()
                except OSError:
                    pass

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
    httpd = make_server(host, port, app,
                        server_class, SilentWSGIRequestHandler)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(target=httpd.serve_forever)
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
130
131
def make_test_protocol(base):
    """Return an instance of a *base* subclass whose attributes are mocks.

    Every name listed by dir(base), except double-underscore "magic"
    names, is replaced on the new 'TestProtocol' class by a
    unittest.mock.Mock that returns None.
    """
    mocked = {
        name: unittest.mock.Mock(return_value=None)
        for name in dir(base)
        # skip magic names
        if not (name.startswith('__') and name.endswith('__'))
    }
    bases = (base,) + base.__bases__
    protocol_class = type('TestProtocol', bases, mocked)
    return protocol_class()
140
141
class TestSelector(selectors.BaseSelector):
    """Selector stub for tests: records registrations, never reports events."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    @property
    def resolution(self):
        """Resolution reported by this selector (a fixed 1e-3)."""
        return 1e-3

    def register(self, fileobj, events, data=None):
        """Record *fileobj* and return its SelectorKey.

        The key's fd field is always 0.  Registering the same file object
        again replaces the previous key.
        """
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        """Remove *fileobj* and return its key.

        Raises KeyError if *fileobj* was not registered.
        """
        return self.keys.pop(fileobj)

    def select(self, timeout=None):
        """Return an empty list: this selector never reports ready events.

        *timeout* is ignored; it defaults to None so that the method can be
        called without arguments, like BaseSelector.select().
        """
        return []

    def get_map(self):
        """Return the mapping of registered file objects to their keys."""
        return self.keys
164
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700165
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests, driven by a virtual clock.

    The loop manages its own time instead of reading a real clock.
    Every deadline passed to call_at() is remembered; at the end of each
    loop iteration, after the base class's _run_once() has returned, each
    remembered deadline is sent into the generator given to __init__ and
    the value the generator yields back is added to the loop's time.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value returned by yield is the absolute time of the next
    scheduled handler.  The value passed to yield is the amount by which
    to move the loop's time forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a caller-supplied generator is checked for exhaustion
        # in close().
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        # Advance to the first yield so the generator can accept send().
        next(self._gen)
        self._time = 0
        self._timers = []
        self._granularity = 1e-9
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        """Return the loop's current virtual time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if not advance:
            return
        self._time += advance

    def close(self):
        """Check that a caller-supplied time generator has been exhausted."""
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError(  # pragma: no cover
            "Time generator is not finished")

    def add_reader(self, fd, callback, *args):
        """Record a read handle for *fd*, replacing any earlier one."""
        self.readers[fd] = events.make_handle(callback, args)

    def remove_reader(self, fd):
        """Count the call and drop *fd*'s reader; return whether it existed."""
        self.remove_reader_count[fd] += 1
        try:
            del self.readers[fd]
        except KeyError:
            return False
        return True

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader with this callback and args."""
        self._assert_registered(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        """Record a write handle for *fd*, replacing any earlier one."""
        self.writers[fd] = events.make_handle(callback, args)

    def remove_writer(self, fd):
        """Count the call and drop *fd*'s writer; return whether it existed."""
        self.remove_writer_count[fd] += 1
        try:
            del self.writers[fd]
        except KeyError:
            return False
        return True

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer with this callback and args."""
        self._assert_registered(self.writers, fd, callback, args)

    @staticmethod
    def _assert_registered(table, fd, callback, args):
        """Shared checks behind assert_reader() and assert_writer()."""
        assert fd in table, 'fd {} is not registered'.format(fd)
        registered = table[fd]
        assert registered._callback == callback, '{!r} != {!r}'.format(
            registered._callback, callback)
        assert registered._args == args, '{!r} != {!r}'.format(
            registered._args, args)

    def reset_counters(self):
        """Reset the per-fd counts of remove_reader/remove_writer calls."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Report every deadline recorded by call_at() to the generator and
        # advance the clock by whatever it yields back.
        for deadline in self._timers:
            self.advance_time(self._gen.send(deadline))
        self._timers = []

    def call_at(self, when, callback, *args):
        """Remember *when* for the time generator, then schedule as usual."""
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # The test selector never reports events, so nothing to process.
        return

    def _write_to_self(self):
        # Deliberately does nothing.
        pass