blob: bd77cabb116af35b7fde0ad2b00ac4d38f6147c9 [file] [log] [blame]
"""Stream-related things.

High-level reader/writer wrappers around transports and protocols.
"""

__all__ = [
    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
    'open_connection', 'start_server', 'IncompleteReadError',
]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006
7import collections
8
9from . import events
10from . import futures
11from . import protocols
12from . import tasks
13
14
15_DEFAULT_LIMIT = 2**16
16
class IncompleteReadError(EOFError):
    """
    Incomplete read error. Attributes:

    - partial: read bytes string before the end of stream was reached
    - expected: total number of expected bytes
    """
    def __init__(self, partial, expected):
        super().__init__("%s bytes read on a total of %s expected bytes"
                         % (len(partial), expected))
        self.partial = partial
        self.expected = expected

    def __reduce__(self):
        # BaseException pickles/copies itself by calling
        # type(self)(*self.args), but self.args only holds the formatted
        # message, which does not fit this __init__ signature.  Rebuild
        # from the original constructor arguments instead.
        return type(self), (self.partial, self.expected)
29
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070030
@tasks.coroutine
def open_connection(host=None, port=None, *,
                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Open a connection and return it as a (reader, writer) pair.

    The pair is a StreamReader and a StreamWriter sharing a single
    StreamReaderProtocol.

    All arguments of loop.create_connection() are accepted except
    protocol_factory; typically host and port are given positionally,
    followed by optional keyword arguments.

    Two extra keyword-only arguments are supported: ``loop`` selects the
    event loop (events.get_event_loop() when omitted) and ``limit`` is
    forwarded as the StreamReader buffer limit.

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    if loop is None:
        loop = events.get_event_loop()
    stream_reader = StreamReader(limit=limit, loop=loop)
    reader_protocol = StreamReaderProtocol(stream_reader, loop=loop)

    def protocol_factory():
        # Every call hands back the single protocol built above.
        return reader_protocol

    transport, _ = yield from loop.create_connection(
        protocol_factory, host, port, **kwds)
    stream_writer = StreamWriter(transport, reader_protocol,
                                 stream_reader, loop)
    return stream_reader, stream_writer
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070059
60
@tasks.coroutine
def start_server(client_connected_cb, host=None, port=None, *,
                 loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server that calls back for each connecting client.

    ``client_connected_cb`` receives two arguments, (client_reader,
    client_writer): a StreamReader and a StreamWriter for the new
    connection.  It may be a plain callback function or a coroutine
    function; if calling it returns a coroutine, that coroutine is
    wrapped in a Task.

    The remaining arguments are those of loop.create_server() except
    protocol_factory; typically host and port are given positionally,
    followed by optional keyword arguments.  Two extra keyword-only
    arguments are supported: ``loop`` selects the event loop
    (events.get_event_loop() when omitted) and ``limit`` is forwarded as
    the buffer limit of every StreamReader.

    Returns whatever loop.create_server() returns, i.e. a Server object
    which can be used to stop the service.
    """
    if loop is None:
        loop = events.get_event_loop()

    def _make_protocol():
        # A fresh reader/protocol pair for every accepted connection.
        return StreamReaderProtocol(StreamReader(limit=limit, loop=loop),
                                    client_connected_cb,
                                    loop=loop)

    server = yield from loop.create_server(_make_protocol, host, port,
                                           **kwds)
    return server
95
96
class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must check for error conditions and then call
    _make_drain_waiter(), which will return either () or a Future
    depending on the paused state.
    """

    def __init__(self, loop=None):
        self._loop = loop  # May be None; we may never need it.
        # True between pause_writing() and resume_writing().
        self._paused = False
        # Future handed out by _make_drain_waiter() while paused, if any.
        self._drain_waiter = None

    def pause_writing(self):
        """Record that writing is paused; drain() will now hand out a Future."""
        assert not self._paused
        self._paused = True

    def resume_writing(self):
        """Clear the paused state and complete a pending drain() waiter."""
        assert self._paused
        self._paused = False
        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def connection_lost(self, exc):
        """Wake a writer blocked in drain() when the connection goes away.

        The pending waiter (only present while paused) completes with
        None when *exc* is None, and fails with *exc* otherwise.
        """
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    def _make_drain_waiter(self):
        """Return () when not paused, else a new Future for drain() to await.

        The Future is completed by resume_writing() or connection_lost().
        Only one waiter may be outstanding; a previous one must have been
        cancelled.
        """
        if not self._paused:
            return ()
        waiter = self._drain_waiter
        assert waiter is None or waiter.cancelled()
        waiter = futures.Future(loop=self._loop)
        self._drain_waiter = waiter
        return waiter
150
151
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Adapter that forwards Protocol callbacks into a StreamReader.

    This is a separate helper rather than making StreamReader a Protocol
    subclass: StreamReader has other potential uses, and keeping them
    apart stops users of the StreamReader from accidentally calling
    protocol methods on it.

    When ``client_connected_cb`` is given, connection_made() also builds
    a StreamWriter and calls the callback with (reader, writer); if the
    callback returns a coroutine, it is scheduled as a Task.
    """

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        self._stream_reader = stream_reader
        self._stream_writer = None
        self._client_connected_cb = client_connected_cb

    def connection_made(self, transport):
        reader = self._stream_reader
        reader.set_transport(transport)
        callback = self._client_connected_cb
        if callback is None:
            return
        self._stream_writer = StreamWriter(transport, self, reader,
                                           self._loop)
        outcome = callback(reader, self._stream_writer)
        if tasks.iscoroutine(outcome):
            tasks.Task(outcome, loop=self._loop)

    def connection_lost(self, exc):
        # A clean close is EOF for the reader; anything else is an error.
        # Then let FlowControlMixin wake a writer that is paused in drain().
        if exc is not None:
            self._stream_reader.set_exception(exc)
        else:
            self._stream_reader.feed_eof()
        super().connection_lost(exc)

    def data_received(self, data):
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()
190
Guido van Rossum355491d2013-10-18 15:17:11 -0700191
class StreamWriter:
    """Thin front end over a Transport.

    write(), writelines(), write_eof(), can_write_eof(), get_extra_info()
    and close() are forwarded to the wrapped transport, which is also
    reachable directly through the read-only ``transport`` property.
    drain() adds cooperation with the protocol's flow control and
    optionally yields a Future to wait on.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        self._reader = reader
        self._loop = loop

    @property
    def transport(self):
        """The wrapped Transport object."""
        return self._transport

    def write(self, data):
        """Forward *data* to the transport's write(); returns None."""
        transport = self._transport
        transport.write(data)

    def writelines(self, data):
        """Forward *data* to the transport's writelines(); returns None."""
        transport = self._transport
        transport.writelines(data)

    def write_eof(self):
        """Forward to the transport's write_eof() and return its result."""
        transport = self._transport
        return transport.write_eof()

    def can_write_eof(self):
        """Return the transport's can_write_eof() result."""
        transport = self._transport
        return transport.can_write_eof()

    def close(self):
        """Forward to the transport's close() and return its result."""
        transport = self._transport
        return transport.close()

    def get_extra_info(self, name, default=None):
        """Look up *name* via the transport's get_extra_info()."""
        transport = self._transport
        return transport.get_extra_info(name, default)

    def drain(self):
        """Return something to ``yield from`` for write flow control.

        Typical use::

            w.write(data)
            yield from w.drain()

        If nothing needs waiting for, () is returned and the yield-from
        finishes at once.  While the protocol is paused (transport buffer
        full), a Future is returned instead; it completes when the
        protocol is resumed after the buffer has (partially) drained.

        Raises the reader's stored exception if one is set, or
        ConnectionResetError if the transport has lost its connection.
        """
        reader = self._reader
        if reader is not None:
            pending_exc = reader._exception
            if pending_exc is not None:
                raise pending_exc
        # Peeks at a private attribute of the transport.
        if self._transport._conn_lost:
            raise ConnectionResetError('Connection lost')
        return self._protocol._make_drain_waiter()
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700250
251
class StreamReader:
    """Buffer of incoming bytes, fed by a protocol and read by coroutines.

    A protocol pushes data in with feed_data(), signals end of stream
    with feed_eof() and reports errors with set_exception().  Consumers
    use the read(), readline() and readexactly() coroutines; only one of
    them may be waiting for incoming data at a time.

    Flow control: once more than ``2 * limit`` bytes are buffered, the
    transport (if one was set and it supports pausing) is asked to pause
    reading; it is resumed when the buffer drops to ``limit`` bytes or
    fewer.
    """

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.
        self._limit = limit
        if loop is None:
            loop = events.get_event_loop()
        self._loop = loop
        # TODO: Use a bytearray for a buffer, like the transport.
        self._buffer = collections.deque()  # Deque of bytes objects.
        self._byte_count = 0  # Bytes in buffer.
        self._eof = False  # Whether we're done.
        self._waiter = None  # A future.
        self._exception = None
        self._transport = None
        self._paused = False

    def exception(self):
        """Return the exception passed to set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Store *exc* and fail the pending read waiter with it, if any.

        Later calls to read(), readline() and readexactly() raise *exc*.
        """
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def set_transport(self, transport):
        """Attach the transport used for pause/resume flow control (once)."""
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume a transport that feed_data() paused, once the buffer has
        # drained to at most self._limit bytes.
        if self._paused and self._byte_count <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark the end of the stream and wake the pending read waiter."""
        self._eof = True
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(True)

    def feed_data(self, data):
        """Append *data* to the buffer and wake the pending read waiter.

        Empty *data* is ignored.  If more than twice the limit is now
        buffered, the transport is asked to pause reading.
        """
        if not data:
            return

        self._buffer.append(data)
        self._byte_count += len(data)

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(False)

        if (self._transport is not None and
            not self._paused and
            self._byte_count > 2*self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    def _create_waiter(self, func_name):
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError('%s() called while another coroutine is '
                               'already waiting for incoming data' % func_name)
        return futures.Future(loop=self._loop)

    @tasks.coroutine
    def readline(self):
        """Read and return one line, including the trailing b'\\n'.

        At EOF the remaining partial line (possibly b'') is returned.
        Raises ValueError if the line grows beyond the limit; the data
        consumed so far is discarded in that case.  Raises the stored
        exception if set_exception() was called.
        """
        if self._exception is not None:
            raise self._exception

        parts = []
        parts_size = 0
        not_enough = True

        while not_enough:
            while self._buffer and not_enough:
                data = self._buffer.popleft()
                ichar = data.find(b'\n')
                if ichar < 0:
                    parts.append(data)
                    parts_size += len(data)
                else:
                    ichar += 1
                    head, tail = data[:ichar], data[ichar:]
                    if tail:
                        self._buffer.appendleft(tail)
                    not_enough = False
                    parts.append(head)
                    parts_size += len(head)

                if parts_size > self._limit:
                    self._byte_count -= parts_size
                    self._maybe_resume_transport()
                    raise ValueError('Line is too long')

            if self._eof:
                break

            if not_enough:
                self._waiter = self._create_waiter('readline')
                try:
                    yield from self._waiter
                finally:
                    self._waiter = None

        line = b''.join(parts)
        self._byte_count -= parts_size
        self._maybe_resume_transport()

        return line

    @tasks.coroutine
    def read(self, n=-1):
        """Read up to *n* bytes; with a negative *n*, read until EOF.

        read(0) returns b'' at once.  For n > 0, waits until at least one
        byte is buffered or EOF is reached, then returns at most *n*
        bytes (b'' only at EOF).  Raises the stored exception if
        set_exception() was called.
        """
        if self._exception is not None:
            raise self._exception

        if not n:
            return b''

        if n < 0:
            # This used to wait for EOF while letting everything pile up in
            # self._buffer.  Once more than 2*limit bytes were buffered,
            # feed_data() paused the transport, so EOF could never arrive
            # and the read deadlocked.  Consume the stream in bounded chunks
            # instead: every chunked read() drains the buffer and lets
            # _maybe_resume_transport() restart a paused transport.
            chunk_size = max(self._limit, 1)
            blocks = []
            while True:
                block = yield from self.read(chunk_size)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._byte_count and not self._eof:
            self._waiter = self._create_waiter('read')
            try:
                yield from self._waiter
            finally:
                self._waiter = None

        if self._byte_count <= n:
            data = b''.join(self._buffer)
            self._buffer.clear()
            self._byte_count = 0
            self._maybe_resume_transport()
            return data

        parts = []
        parts_bytes = 0
        while self._buffer and parts_bytes < n:
            data = self._buffer.popleft()
            data_bytes = len(data)
            if n < parts_bytes + data_bytes:
                data_bytes = n - parts_bytes
                data, rest = data[:data_bytes], data[data_bytes:]
                self._buffer.appendleft(rest)

            parts.append(data)
            parts_bytes += data_bytes
            self._byte_count -= data_bytes
            self._maybe_resume_transport()

        return b''.join(parts)

    @tasks.coroutine
    def readexactly(self, n):
        """Read exactly *n* bytes.

        Raises IncompleteReadError (carrying the bytes read so far) if
        EOF is reached first, or the stored exception if
        set_exception() was called.
        """
        if self._exception is not None:
            raise self._exception

        # There used to be "optimized" code here. It created its own
        # Future and waited until self._buffer had at least the n
        # bytes, then called read(n). Unfortunately, this could pause
        # the transport if the argument was larger than the pause
        # limit (which is twice self._limit). So now we just read()
        # into a local buffer.

        blocks = []
        while n > 0:
            block = yield from self.read(n)
            if not block:
                partial = b''.join(blocks)
                raise IncompleteReadError(partial, len(partial) + n)
            blocks.append(block)
            n -= len(block)

        return b''.join(blocks)
Guido van Rossum38455212014-01-06 16:09:18 -0800449 return b''.join(blocks)