blob: ccb445418d5397ec06b0f75b6a48f1dbde4c8497 [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006import os
7import sys
8import threading
Antoine Pitroud20afad2013-10-20 01:51:25 +02009import time
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070010import unittest
11import unittest.mock
12from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
13try:
14 import ssl
15except ImportError: # pragma: no cover
16 ssl = None
17
18from . import tasks
19from . import base_events
20from . import events
21from . import selectors
22
23
24if sys.platform == 'win32': # pragma: no cover
25 from .windows_utils import socketpair
26else:
27 from socket import socketpair # pragma: no cover
28
29
def dummy_ssl_context():
    """Return a throw-away SSL context, or None when ssl is unavailable.

    The module-level ``ssl`` name is None when the ``ssl`` module could
    not be imported; in that case there is nothing to build.
    """
    if ssl is not None:
        return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    return None
35
36
def run_briefly(loop):
    """Run *loop* just long enough to complete one do-nothing task.

    A trivial coroutine is wrapped in a Task bound to *loop* and the loop
    is run until that task finishes.  The coroutine object is closed
    afterwards whether or not running the loop raised.
    """
    @tasks.coroutine
    def noop():
        pass

    coro = noop()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
47
48
def run_until(loop, pred, timeout=None):
    """Keep running *loop* until ``pred()`` is true or *timeout* expires.

    Without a timeout the loop is run briefly between checks of *pred*.
    With a timeout (in seconds, measured with ``time.time()``) the loop
    sleeps for the remaining time between checks.

    Returns True once ``pred()`` is true, False if the deadline passed
    first.
    """
    deadline = None if timeout is None else time.time() + timeout
    while not pred():
        if deadline is None:
            run_briefly(loop)
            continue
        remaining = deadline - time.time()
        if remaining <= 0:
            return False
        loop.run_until_complete(tasks.sleep(remaining, loop=loop))
    return True
61
62
def run_once(loop):
    """Run a single pass of *loop* by stopping it before starting it.

    loop.stop() schedules _raise_stop_error(), and run_forever() runs
    until that callback is reached.  This won't work if the test waits
    for some I/O events, because _raise_stop_error() runs before any of
    the I/O event callbacks.
    """
    # Order matters: the stop request must be queued before the loop runs.
    loop.stop()
    loop.run_forever()
71
72
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Run a small WSGI test server in a background thread.

    The server answers every request with ``200 OK``, a ``text/plain``
    content type and the body ``b'Test message'``.  Request logging and
    server-side error reporting are silenced.

    Keyword arguments:
      host    -- interface to bind to.
      port    -- port to bind to; 0 lets the OS pick a free one.
      use_ssl -- when true, each accepted connection is wrapped in SSL
                 using the key/certificate files of the test directory.

    Yields the server object.  Its ``address`` attribute is set to the
    bound ``server_address``.  On exit the server is shut down, its
    listening socket is closed and the serving thread is joined.
    """

    class SilentWSGIRequestHandler(WSGIRequestHandler):
        def get_stderr(self):
            # Give the handler a scratch buffer instead of real stderr.
            return io.StringIO()

        def log_message(self, format, *args):
            pass

    class SilentWSGIServer(WSGIServer):
        def handle_error(self, request, client_address):
            pass

    class SSLWSGIServer(SilentWSGIServer):
        def finish_request(self, request, client_address):
            # The relative location of our test directory (which
            # contains the ssl key and certificate files) differs
            # between the stdlib and stand-alone asyncio.
            # Prefer our own if we can find it.
            here = os.path.join(os.path.dirname(__file__), '..', 'tests')
            if not os.path.isdir(here):
                here = os.path.join(os.path.dirname(os.__file__),
                                    'test', 'test_asyncio')
            keyfile = os.path.join(here, 'ssl_key.pem')
            certfile = os.path.join(here, 'ssl_cert.pem')
            # Build an explicit context rather than calling the
            # module-level ssl.wrap_socket(), which is deprecated and no
            # longer exists in Python 3.12+.
            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            context.load_cert_chain(certfile, keyfile)
            ssock = context.wrap_socket(request, server_side=True)
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            except OSError:
                # maybe socket has been closed by peer
                pass
            finally:
                # Always release the wrapped socket, even when the
                # handler failed part-way through.
                try:
                    ssock.close()
                except OSError:
                    pass

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
    httpd = make_server(host, port, app,
                        server_class, SilentWSGIRequestHandler)
    httpd.address = httpd.server_address
    # A short poll interval lets shutdown() return promptly instead of
    # waiting out the default half-second poll.
    server_thread = threading.Thread(
        target=httpd.serve_forever, kwargs={'poll_interval': 0.05})
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
130
131
def make_test_protocol(base):
    """Return an instance of a mocked-out subclass of *base*.

    Every name listed by ``dir(base)`` that is not a double-underscore
    "magic" name is replaced, on a freshly created ``TestProtocol``
    class, by its own ``unittest.mock.Mock`` returning None.  The new
    class derives from *base* followed by *base*'s own direct bases.
    """
    def is_magic(attr):
        return attr.startswith('__') and attr.endswith('__')

    namespace = {
        attr: unittest.mock.Mock(return_value=None)
        for attr in dir(base)
        if not is_magic(attr)
    }
    bases = (base,) + base.__bases__
    protocol_class = type('TestProtocol', bases, namespace)
    return protocol_class()
140
141
class TestSelector(selectors.BaseSelector):
    """In-memory selector for tests that never reports any event.

    Registrations are only recorded in ``self.keys``; no real polling
    is performed.
    """

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Record *fileobj* and return the SelectorKey created for it.

        The key's file descriptor field is always 0.
        """
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key.

        Raises KeyError if *fileobj* is not registered.
        """
        return self.keys.pop(fileobj)

    def select(self, timeout=None):
        """Return an empty list: this selector never has ready events.

        *timeout* is accepted for interface compatibility and ignored;
        it defaults to None so the method can be called as ``select()``.
        """
        return []

    def get_map(self):
        """Return the mapping of registered file objects to their keys."""
        return self.keys
160
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700161
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that manages its own clock.

    Each time passed to ``call_at()`` is remembered.  On the next loop
    iteration, once the base class's ``_run_once()`` has returned, every
    remembered time is sent into the generator given to ``__init__`` and
    the value the generator yields back is added to the loop's time.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value returned by ``yield`` is the absolute time of the next
    scheduled handler.  The value passed to ``yield`` is the amount by
    which to move the loop's time forward.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a caller-supplied generator is checked for exhaustion
        # when the loop is closed.
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        # Advance to the first ``yield`` so that send() can be used.
        next(self._gen)
        self._time = 0
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        """Return the loop's current (test-controlled) time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if not advance:
            return
        self._time += advance

    def close(self):
        """Check that a caller-supplied time generator has finished."""
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        # pragma: no cover
        raise AssertionError("Time generator is not finished")

    @staticmethod
    def _assert_registered(registry, fd, callback, args):
        # Shared check behind assert_reader() and assert_writer().
        assert fd in registry, 'fd {} is not registered'.format(fd)
        handle = registry[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def add_reader(self, fd, callback, *args):
        """Record *callback* as the reader for *fd*."""
        handle = events.make_handle(callback, args)
        self.readers[fd] = handle

    def remove_reader(self, fd):
        """Drop the reader for *fd*; return whether one was registered.

        Every call is counted in ``remove_reader_count``.
        """
        self.remove_reader_count[fd] += 1
        try:
            del self.readers[fd]
        except KeyError:
            return False
        return True

    def assert_reader(self, fd, callback, *args):
        """Assert that *fd* has a reader with this callback and args."""
        self._assert_registered(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        """Record *callback* as the writer for *fd*."""
        handle = events.make_handle(callback, args)
        self.writers[fd] = handle

    def remove_writer(self, fd):
        """Drop the writer for *fd*; return whether one was registered.

        Every call is counted in ``remove_writer_count``.
        """
        self.remove_writer_count[fd] += 1
        try:
            del self.writers[fd]
        except KeyError:
            return False
        return True

    def assert_writer(self, fd, callback, *args):
        """Assert that *fd* has a writer with this callback and args."""
        self._assert_registered(self.writers, fd, callback, args)

    def reset_counters(self):
        """Start the remove_reader/remove_writer call counts from zero."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Hand each time recorded by call_at() to the generator and move
        # the clock by whatever it answers.
        for when in self._timers:
            self.advance_time(self._gen.send(when))
        self._timers = []

    def call_at(self, when, callback, *args):
        """Remember *when* for the time generator, then schedule as usual."""
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        return

    def _write_to_self(self):
        pass
275 pass