blob: 93a21d1af736b9328cf0ab2369dff2f2f5faaf48 [file] [log] [blame]
"""Stream-related things.

High-level helpers that pair an event-loop transport with a buffered
StreamReader (incoming bytes) and a StreamWriter (outgoing bytes, with
drain()-based flow control), plus the open_connection() and
start_server() conveniences that build them.
"""

__all__ = [
    'StreamReader',
    'StreamWriter',
    'StreamReaderProtocol',
    'open_connection',
    'start_server',
]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006
7import collections
8
9from . import events
10from . import futures
11from . import protocols
12from . import tasks
13
14
# Default StreamReader limit, 64 KiB (2**16 bytes): the longest line
# readline() accepts, and half the buffered size at which StreamReader
# pauses reading from its transport.
_DEFAULT_LIMIT = 64 * 1024
16
17
@tasks.coroutine
def open_connection(host=None, port=None, *,
                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """A wrapper for create_connection() returning a (reader, writer) pair.

    The reader returned is a StreamReader instance; the writer is a
    StreamWriter instance wrapping the connection's Transport (which
    remains reachable through the writer's ``transport`` attribute).

    The arguments are all the usual arguments to create_connection()
    except protocol_factory; most common are positional host and port,
    with various optional keyword arguments following.

    Additional optional keyword arguments are loop (to set the event loop
    instance to use; defaults to events.get_event_loop()) and limit (to
    set the buffer limit passed to the StreamReader).

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    if loop is None:
        loop = events.get_event_loop()
    reader = StreamReader(limit=limit, loop=loop)
    # Hand the loop to the protocol explicitly, as start_server() does,
    # so the protocol never holds loop=None.
    protocol = StreamReaderProtocol(reader, loop=loop)
    # The factory returns this single pre-built protocol so that we keep a
    # reference to it for the StreamWriter below.
    transport, _ = yield from loop.create_connection(
        lambda: protocol, host, port, **kwds)
    writer = StreamWriter(transport, protocol, reader, loop)
    return reader, writer
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070046
47
@tasks.coroutine
def start_server(client_connected_cb, host=None, port=None, *,
                 loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server that calls back for each connected client.

    ``client_connected_cb`` is called with two arguments,
    (client_reader, client_writer): a StreamReader and a StreamWriter for
    the new connection.  It may be a plain callback function or a
    coroutine; if calling it returns a coroutine, that coroutine is
    wrapped in a Task.

    All remaining arguments are those of loop.create_server() except
    protocol_factory; most commonly host and port are given positionally,
    followed by optional keyword arguments.

    Two additional keyword-only arguments are accepted: loop (the event
    loop instance to use; defaults to events.get_event_loop()) and limit
    (the buffer limit passed to each StreamReader).

    Returns the same value as loop.create_server(), i.e. a Server object
    which can be used to stop the service.
    """
    if loop is None:
        loop = events.get_event_loop()

    def make_protocol():
        # A fresh reader/protocol pair for every accepted connection.
        stream_reader = StreamReader(limit=limit, loop=loop)
        return StreamReaderProtocol(stream_reader, client_connected_cb,
                                    loop=loop)

    server = yield from loop.create_server(make_protocol, host, port, **kwds)
    return server
82
83
class StreamReaderProtocol(protocols.Protocol):
    """Trivial helper class to adapt between Protocol and StreamReader.

    Incoming protocol events (data, EOF, connection loss) are forwarded
    to the wrapped StreamReader, and the pause_writing()/resume_writing()
    flow-control callbacks are tracked so that StreamWriter.drain() can
    wait on them.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader from accidentally
    calling inappropriate methods of the protocol.)
    """

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        self._stream_reader = stream_reader
        # Only created in connection_made() when a callback was supplied.
        self._stream_writer = None
        # Future created by StreamWriter.drain() while writing is paused.
        self._drain_waiter = None
        # True between pause_writing() and resume_writing().
        self._paused = False
        self._client_connected_cb = client_connected_cb
        self._loop = loop  # May be None; we may never need it.

    def _wake_drain_waiter(self, exc=None):
        """Detach the pending drain() future, if any, and complete it.

        The future receives ``exc`` as its exception when one is given,
        otherwise ``None`` as its result; a future that is already done
        is left alone.
        """
        waiter, self._drain_waiter = self._drain_waiter, None
        if waiter is None or waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    def connection_made(self, transport):
        self._stream_reader.set_transport(transport)
        callback = self._client_connected_cb
        if callback is None:
            return
        self._stream_writer = StreamWriter(transport, self,
                                           self._stream_reader,
                                           self._loop)
        result = callback(self._stream_reader, self._stream_writer)
        # Coroutine callbacks are scheduled rather than run inline.
        if tasks.iscoroutine(result):
            tasks.Task(result, loop=self._loop)

    def connection_lost(self, exc):
        if exc is not None:
            self._stream_reader.set_exception(exc)
        else:
            self._stream_reader.feed_eof()
        # Also wake up the writing side, which may be blocked in drain().
        if self._paused:
            self._wake_drain_waiter(exc)

    def data_received(self, data):
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()

    def pause_writing(self):
        assert not self._paused
        self._paused = True

    def resume_writing(self):
        assert self._paused
        self._paused = False
        self._wake_drain_waiter()
146
147
class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close().  It adds drain() which returns an
    optional Future on which you can wait for flow control.  It also
    adds a transport attribute which references the Transport
    directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        # drain() reads the protocol's _paused flag and stores its Future
        # in the protocol's _drain_waiter slot.
        self._protocol = protocol
        # drain() re-raises an exception recorded on this reader; None is
        # tolerated and simply skips that check.
        self._reader = reader
        self._loop = loop  # Loop for the Future created by drain().

    @property
    def transport(self):
        """The wrapped Transport."""
        return self._transport

    def write(self, data):
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    def drain(self):
        """This method has an unusual return value.

        The intended use is to write

          w.write(data)
          yield from w.drain()

        When there's nothing to wait for, drain() returns (), and the
        yield-from continues immediately.  When the transport buffer
        is full (the protocol is paused), drain() creates and returns
        a Future and the yield-from will block until that Future is
        completed, which will happen when the buffer is (partially)
        drained and the protocol is resumed.

        Raises the exception stored on the reader (e.g. one delivered by
        connection_lost()), or ConnectionResetError if the transport
        reports the connection as lost.
        """
        # Surface an error already recorded on the reading side instead
        # of waiting on a dead connection.  (This previously referenced a
        # non-existent self._writer attribute and raised AttributeError.)
        reader = self._reader
        if reader is not None and reader._exception is not None:
            raise reader._exception
        if self._transport._conn_lost:  # Uses private variable.
            raise ConnectionResetError('Connection lost')
        if not self._protocol._paused:
            return ()
        waiter = self._protocol._drain_waiter
        assert waiter is None or waiter.cancelled()
        waiter = futures.Future(loop=self._loop)
        self._protocol._drain_waiter = waiter
        return waiter
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700212
213
class StreamReader:
    """Buffer incoming bytes and hand them out through coroutine readers.

    Producers (normally StreamReaderProtocol) push data in with
    feed_data(), feed_eof() and set_exception(); consumers pull it out
    with the readline(), read() and readexactly() coroutines.  At most
    one coroutine may be waiting for data at a time.

    Flow control: when a transport has been set, reading from it is
    paused once more than 2 * limit bytes are buffered, and resumed once
    readers bring the buffer back down to limit bytes or fewer.
    """

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.
        self._limit = limit
        if loop is None:
            loop = events.get_event_loop()
        self._loop = loop
        # TODO: Use a bytearray for a buffer, like the transport.
        self._buffer = collections.deque()  # Deque of bytes objects.
        self._byte_count = 0  # Bytes in buffer.
        self._eof = False  # Whether we're done.
        self._waiter = None  # A future.
        self._exception = None  # Set by set_exception(); raised by readers.
        self._transport = None  # Set by set_transport(); for flow control.
        self._paused = False  # True while we have paused the transport.

    def exception(self):
        """Return the exception passed to set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Store exc and raise it in the coroutine waiting for data, if any."""
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def set_transport(self, transport):
        """Attach the transport used for pause/resume flow control (once)."""
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once readers have shrunk the buffer to the limit.
        if self._paused and self._byte_count <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark the end of the stream and wake up a waiting reader."""
        self._eof = True
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(True)

    def feed_data(self, data):
        """Buffer data, wake a waiting reader, and apply flow control.

        Empty data is ignored.  If a transport is set and more than
        2 * limit bytes are now buffered, reading from it is paused (a
        transport that raises NotImplementedError is forgotten instead).
        """
        if not data:
            return

        self._buffer.append(data)
        self._byte_count += len(data)

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(False)

        if (self._transport is not None and
                not self._paused and
                self._byte_count > 2*self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    @tasks.coroutine
    def _wait_for_data(self):
        """Wait until feed_data(), feed_eof() or set_exception() is called.

        Only one coroutine may wait at a time.  An exception passed to
        set_exception() while waiting propagates out of this coroutine.
        """
        assert self._waiter is None
        self._waiter = futures.Future(loop=self._loop)
        try:
            yield from self._waiter
        finally:
            self._waiter = None

    @tasks.coroutine
    def readline(self):
        """Read one line, up to and including b'\\n'.

        At EOF the remaining (possibly unterminated) data is returned, or
        b'' if nothing is left.  Raises ValueError if the line grows
        beyond the limit (the bytes consumed so far are discarded), or
        the exception given to set_exception().
        """
        if self._exception is not None:
            raise self._exception

        parts = []
        parts_size = 0
        not_enough = True

        while not_enough:
            while self._buffer and not_enough:
                data = self._buffer.popleft()
                ichar = data.find(b'\n')
                if ichar < 0:
                    parts.append(data)
                    parts_size += len(data)
                else:
                    ichar += 1
                    head, tail = data[:ichar], data[ichar:]
                    if tail:
                        # Keep the bytes after the newline for later reads.
                        self._buffer.appendleft(tail)
                    not_enough = False
                    parts.append(head)
                    parts_size += len(head)

            if parts_size > self._limit:
                self._byte_count -= parts_size
                self._maybe_resume_transport()
                raise ValueError('Line is too long')

            if self._eof:
                break

            if not_enough:
                yield from self._wait_for_data()

        line = b''.join(parts)
        self._byte_count -= parts_size
        self._maybe_resume_transport()

        return line

    @tasks.coroutine
    def read(self, n=-1):
        """Read up to n bytes; with n < 0, read everything until EOF.

        With n > 0, waits only while the buffer is empty and EOF has not
        been seen, then returns at most n of the buffered bytes (b'' at
        EOF).  read(0) returns b''.  Raises the exception given to
        set_exception().
        """
        if self._exception is not None:
            raise self._exception

        if not n:
            return b''

        if n < 0:
            # Collect everything up to EOF through bounded reads.  Simply
            # waiting for EOF would deadlock: once more than 2 * limit
            # bytes arrive, feed_data() pauses the transport, and only a
            # read that shrinks the buffer resumes it, so the EOF we are
            # waiting for may never be delivered.
            chunk_size = max(self._limit, 1)
            blocks = []
            while True:
                block = yield from self.read(chunk_size)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._byte_count and not self._eof:
            yield from self._wait_for_data()

        if self._byte_count <= n:
            # Everything buffered fits: hand it all out at once.
            data = b''.join(self._buffer)
            self._buffer.clear()
            self._byte_count = 0
            self._maybe_resume_transport()
            return data

        parts = []
        parts_bytes = 0
        while self._buffer and parts_bytes < n:
            data = self._buffer.popleft()
            data_bytes = len(data)
            if n < parts_bytes + data_bytes:
                # Split the chunk; push the unread remainder back.
                data_bytes = n - parts_bytes
                data, rest = data[:data_bytes], data[data_bytes:]
                self._buffer.appendleft(rest)

            parts.append(data)
            parts_bytes += data_bytes
            self._byte_count -= data_bytes
            self._maybe_resume_transport()

        return b''.join(parts)

    @tasks.coroutine
    def readexactly(self, n):
        """Read n bytes, or fewer if EOF is reached first.

        Raises the exception given to set_exception().
        """
        if self._exception is not None:
            raise self._exception

        # There used to be "optimized" code here.  It created its own
        # Future and waited until self._buffer had at least the n
        # bytes, then called read(n).  Unfortunately, this could pause
        # the transport if the argument was larger than the pause
        # limit (which is twice self._limit).  So now we just read()
        # into a local buffer.

        blocks = []
        while n > 0:
            block = yield from self.read(n)
            if not block:
                break
            blocks.append(block)
            n -= len(block)

        # TODO: Raise EOFError if we break before n == 0?  (That would
        # be a change in specification, but I've always had to add an
        # explicit size check to the caller.)

        return b''.join(blocks)
407 return b''.join(blocks)