blob: 91bbedbafa4626b6919e6e422eb7710e330c5a8a [file] [log] [blame]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07001"""Utilities shared by tests."""
2
3import collections
4import contextlib
5import io
6import unittest.mock
7import os
8import sys
9import threading
10import unittest
11import unittest.mock
12from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
13try:
14 import ssl
15except ImportError: # pragma: no cover
16 ssl = None
17
18from . import tasks
19from . import base_events
20from . import events
21from . import selectors
22
23
24if sys.platform == 'win32': # pragma: no cover
25 from .windows_utils import socketpair
26else:
27 from socket import socketpair # pragma: no cover
28
29
def dummy_ssl_context():
    """Build a throwaway SSL context for tests.

    Returns:
        A new ``ssl.SSLContext`` created with ``ssl.PROTOCOL_SSLv23``, or
        ``None`` when the ``ssl`` module could not be imported.
    """
    return None if ssl is None else ssl.SSLContext(ssl.PROTOCOL_SSLv23)
35
36
def run_briefly(loop):
    """Run *loop* until a task wrapping an empty coroutine has finished.

    The coroutine object is closed afterwards, even when running the loop
    raises.  (Presumably used to let callbacks that are already scheduled
    run once — confirm against callers.)
    """
    @tasks.coroutine
    def once():
        pass

    coro = once()
    task = tasks.Task(coro, loop=loop)
    try:
        loop.run_until_complete(task)
    finally:
        # Always close the coroutine object, whatever the loop did.
        coro.close()
47
48
def run_once(loop):
    """Run roughly one iteration of *loop* by stopping it, then running it.

    ``loop.stop()`` schedules ``_raise_stop_error()``, and ``run_forever()``
    keeps going until that callback runs.  This does not work when the test
    is waiting for I/O events, because ``_raise_stop_error()`` runs before
    any I/O event callbacks.
    """
    # Order matters: the stop request must be queued before the loop starts.
    loop.stop()
    loop.run_forever()
57
58
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Serve a tiny WSGI app in a background thread for the duration of a ``with``.

    Every request is answered with ``200 OK``, a ``Content-type: text/plain``
    header and the body ``b'Test message'``.  Request logging, handler stderr
    output and per-request error reporting are all silenced.

    Args:
        host: Address to bind to.
        port: Port to bind to; ``0`` lets the OS pick a free port.
        use_ssl: When true, each accepted connection is wrapped in
            server-side TLS using ``sample.key`` / ``sample.crt`` from the
            test directory.

    Yields:
        The server object returned by ``make_server``.  Its extra ``address``
        attribute is a copy of ``server_address``.  On exit the server is
        shut down, its thread is joined and its listening socket is closed.
    """

    class SilentWSGIRequestHandler(WSGIRequestHandler):
        def get_stderr(self):
            # Send wsgiref's error output to a throwaway buffer, not stderr.
            return io.StringIO()

        def log_message(self, format, *args):
            pass

    class SilentWSGIServer(WSGIServer):
        def handle_error(self, request, client_address):
            pass

    class SSLWSGIServer(SilentWSGIServer):
        def finish_request(self, request, client_address):
            # The relative location of our test directory (which
            # contains the sample key and certificate files) differs
            # between the stdlib and stand-alone Tulip/asyncio.
            # Prefer our own if we can find it.
            here = os.path.join(os.path.dirname(__file__), '..', 'tests')
            if not os.path.isdir(here):
                here = os.path.join(os.path.dirname(os.__file__),
                                    'test', 'test_asyncio')
            keyfile = os.path.join(here, 'sample.key')
            certfile = os.path.join(here, 'sample.crt')
            # ssl.wrap_socket() was deprecated in 3.7 and removed in 3.12.
            # An explicit server-side context does the same job and works on
            # every Python 3 version that has SSLContext.
            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            context.load_cert_chain(certfile, keyfile)
            ssock = context.wrap_socket(request, server_side=True)
            try:
                self.RequestHandlerClass(ssock, client_address, self)
            except OSError:
                # maybe socket has been closed by peer
                pass
            finally:
                # Close on the error path too, so the TLS socket never leaks.
                ssock.close()

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
    httpd = make_server(host, port, app,
                        server_class, SilentWSGIRequestHandler)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(target=httpd.serve_forever)
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        server_thread.join()
        # shutdown() only stops serve_forever(); the listening socket must be
        # released explicitly, otherwise it leaks on every use.
        httpd.server_close()
114 server_thread.join()
115
116
def make_test_protocol(base):
    """Instantiate a subclass of *base* whose ordinary attributes are mocks.

    Every name reported by ``dir(base)`` that is not a ``__dunder__`` name is
    replaced, at class level, by a ``unittest.mock.Mock`` returning ``None``.
    The generated class is named ``TestProtocol`` and lists *base* followed
    by *base*'s own direct bases as its bases.
    """
    mocked = {
        attr: unittest.mock.Mock(return_value=None)
        for attr in dir(base)
        # Leave magic names alone.
        if not (attr.startswith('__') and attr.endswith('__'))
    }
    bases = (base,) + base.__bases__
    protocol_cls = type('TestProtocol', bases, mocked)
    return protocol_cls()
125
126
127class TestSelector(selectors.BaseSelector):
128
129 def select(self, timeout):
130 return []
131
132
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests whose clock is driven by a generator.

    The loop keeps its own time: ``time()`` returns an internal counter
    that starts at 0 and only moves through ``advance_time()``.  Every
    deadline passed to ``call_at()`` is recorded.  After each
    ``_run_once()`` iteration, the recorded deadlines are sent one by one
    to the time generator given to ``__init__``, and whatever the
    generator yields back is added to the clock.

    The generator should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value a ``yield`` expression evaluates to is the absolute time of
    the next scheduled handler.  The value the generator yields is how far
    to move the loop's time forward.

    When a generator is supplied, ``close()`` checks that it has finished.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a caller-supplied generator is checked for exhaustion on close.
        self._check_on_close = gen is not None
        if gen is None:
            # NOTE: this fallback is exhausted right after priming below, so
            # sending it a call_at() deadline raises StopIteration.
            def gen():
                yield

        self._gen = gen()
        next(self._gen)  # prime the generator up to its first yield
        self._time = 0
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance*; falsy values are ignored."""
        if not advance:
            return
        self._time += advance

    def close(self):
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError("Time generator is not finished")  # pragma: no cover

    @staticmethod
    def _assert_registered(registry, fd, callback, args):
        # Shared body of assert_reader() / assert_writer().
        assert fd in registry, 'fd {} is not registered'.format(fd)
        handle = registry[fd]
        assert handle._callback == callback, '{!r} != {!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{!r} != {!r}'.format(
            handle._args, args)

    def add_reader(self, fd, callback, *args):
        handle = events.make_handle(callback, args)
        self.readers[fd] = handle

    def remove_reader(self, fd):
        """Unregister *fd* as a reader; return whether it was registered.

        Every call, successful or not, is counted in ``remove_reader_count``.
        """
        self.remove_reader_count[fd] += 1
        try:
            del self.readers[fd]
        except KeyError:
            return False
        return True

    def assert_reader(self, fd, callback, *args):
        self._assert_registered(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        handle = events.make_handle(callback, args)
        self.writers[fd] = handle

    def remove_writer(self, fd):
        """Unregister *fd* as a writer; return whether it was registered.

        Every call, successful or not, is counted in ``remove_writer_count``.
        """
        self.remove_writer_count[fd] += 1
        try:
            del self.writers[fd]
        except KeyError:
            return False
        return True

    def assert_writer(self, fd, callback, *args):
        self._assert_registered(self.writers, fd, callback, args)

    def reset_counters(self):
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Feed each recorded deadline to the time generator and apply the
        # advance it yields back; then forget the recorded deadlines.
        for when in self._timers:
            self.advance_time(self._gen.send(when))
        self._timers = []

    def call_at(self, when, callback, *args):
        # Record the deadline so _run_once() can report it to the generator.
        self._timers.append(when)
        return super().call_at(when, callback, *args)

    def _process_events(self, event_list):
        # There is no real I/O; nothing to dispatch.
        return None

    def _write_to_self(self):
        # No self-pipe to wake up.
        pass
246 pass