blob: b53080ef5c854f66f6c6b71a8c4bb871083f49e3 [file] [log] [blame]
"""Stream-related things.

Coroutine-friendly stream wrappers around the event loop's
transport/protocol API: StreamReader, StreamWriter and the
StreamReaderProtocol adapter that connects a protocol to a reader,
plus the open_connection() and start_server() convenience helpers.
"""

__all__ = [
    'StreamReader',
    'StreamWriter',
    'StreamReaderProtocol',
    'open_connection',
    'start_server',
]
Guido van Rossum27b7c7e2013-10-17 13:40:50 -07006
7import collections
8
9from . import events
10from . import futures
11from . import protocols
12from . import tasks
13
14
# Default StreamReader limit (64 KiB): the longest line readline() accepts,
# and half of the buffered-byte count at which reading is paused.
_DEFAULT_LIMIT = 64 * 1024
16
17
@tasks.coroutine
def open_connection(host=None, port=None, *,
                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Connect via create_connection() and return a (reader, writer) pair.

    The reader is a StreamReader and the writer is a StreamWriter, both
    bound to the same StreamReaderProtocol.

    All the usual create_connection() arguments are accepted except
    protocol_factory; host and port are normally given positionally,
    followed by optional keyword arguments.

    Extra keyword-only arguments:
      loop  -- event loop to use (defaults to events.get_event_loop()).
      limit -- buffer limit handed to the StreamReader.

    (To customize the StreamReader and/or StreamReaderProtocol classes,
    just copy this code -- it is only a convenience wrapper.)
    """
    event_loop = events.get_event_loop() if loop is None else loop
    stream_reader = StreamReader(limit=limit, loop=event_loop)
    reader_protocol = StreamReaderProtocol(stream_reader, loop=event_loop)
    transport, _unused = yield from event_loop.create_connection(
        lambda: reader_protocol, host, port, **kwds)
    stream_writer = StreamWriter(transport, reader_protocol,
                                 stream_reader, event_loop)
    return stream_reader, stream_writer
Guido van Rossum27b7c7e2013-10-17 13:40:50 -070046
47
@tasks.coroutine
def start_server(client_connected_cb, host=None, port=None, *,
                 loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server that calls back for each connected client.

    client_connected_cb is called with two arguments, (client_reader,
    client_writer): a StreamReader and a StreamWriter for the new
    connection.  It may be a plain function or a coroutine function;
    if calling it returns a coroutine, that coroutine is wrapped in a
    Task.

    The remaining arguments are the usual loop.create_server()
    arguments except protocol_factory; host and port are normally
    given positionally, followed by optional keyword arguments.

    Extra keyword-only arguments:
      loop  -- event loop to use (defaults to events.get_event_loop()).
      limit -- buffer limit handed to each client's StreamReader.

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    if loop is None:
        loop = events.get_event_loop()

    def protocol_factory():
        # One fresh reader/protocol pair per accepted connection.
        return StreamReaderProtocol(StreamReader(limit=limit, loop=loop),
                                    client_connected_cb,
                                    loop=loop)

    server = yield from loop.create_server(protocol_factory, host, port,
                                           **kwds)
    return server
82
83
class StreamReaderProtocol(protocols.Protocol):
    """Adapter that forwards Protocol callbacks into a StreamReader.

    This is a separate helper class rather than making StreamReader a
    Protocol subclass, because a StreamReader has other potential uses,
    and keeping them apart prevents users of the StreamReader from
    accidentally calling protocol methods on it.

    It also records write-side flow control (pause_writing /
    resume_writing); StreamWriter.drain() reads the _paused and
    _drain_waiter attributes directly.
    """

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        self._stream_reader = stream_reader
        self._stream_writer = None
        self._client_connected_cb = client_connected_cb
        self._loop = loop  # May be None; only used on the callback path.
        # Write-side flow-control state, shared with StreamWriter.drain().
        self._paused = False
        self._drain_waiter = None

    def connection_made(self, transport):
        reader = self._stream_reader
        reader.set_transport(transport)
        callback = self._client_connected_cb
        if callback is None:
            return
        self._stream_writer = StreamWriter(transport, self, reader,
                                           self._loop)
        result = callback(reader, self._stream_writer)
        # A coroutine result is scheduled on the loop as a Task.
        if tasks.iscoroutine(result):
            tasks.Task(result, loop=self._loop)

    def connection_lost(self, exc):
        if exc is not None:
            self._stream_reader.set_exception(exc)
        else:
            self._stream_reader.feed_eof()
        # A writer blocked in drain() has to be woken up as well.
        if self._paused:
            self._wake_drain_waiter(exc)

    def data_received(self, data):
        self._stream_reader.feed_data(data)

    def eof_received(self):
        self._stream_reader.feed_eof()

    def pause_writing(self):
        assert not self._paused
        self._paused = True

    def resume_writing(self):
        assert self._paused
        self._paused = False
        self._wake_drain_waiter()

    def _wake_drain_waiter(self, exc=None):
        # Detach the pending drain() future (if any) and complete it:
        # with exc when one is given, otherwise with a None result.
        waiter, self._drain_waiter = self._drain_waiter, None
        if waiter is None or waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)
146
147
class StreamWriter:
    """Thin wrapper around a Transport.

    Forwards write(), writelines(), write_eof(), can_write_eof(),
    get_extra_info() and close() to the transport, exposes the
    transport itself via the transport property, and adds drain() for
    cooperating with write-side flow control.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        self._reader = reader
        self._loop = loop

    @property
    def transport(self):
        """The underlying Transport object."""
        return self._transport

    def write(self, data):
        self._transport.write(data)

    def writelines(self, data):
        self._transport.writelines(data)

    def write_eof(self):
        return self._transport.write_eof()

    def can_write_eof(self):
        return self._transport.can_write_eof()

    def close(self):
        return self._transport.close()

    def get_extra_info(self, name, default=None):
        return self._transport.get_extra_info(name, default)

    def drain(self):
        """Return something to ``yield from`` for write flow control.

        Intended use:

            w.write(data)
            yield from w.drain()

        If the protocol is not paused, the empty tuple () is returned
        and the yield-from finishes immediately.  If it is paused (the
        transport buffer is full), a new Future is stored on the
        protocol and returned; it completes when the protocol is
        resumed, or when the connection is lost.

        Raises the reader's stored exception if there is one, or
        ConnectionResetError if the transport reports a lost connection.
        """
        reader_exc = self._reader._exception
        if reader_exc is not None:
            raise reader_exc
        if self._transport._conn_lost:  # Peeks at a private attribute.
            raise ConnectionResetError('Connection lost')
        protocol = self._protocol
        if protocol._paused:
            assert (protocol._drain_waiter is None or
                    protocol._drain_waiter.cancelled())
            protocol._drain_waiter = futures.Future(loop=self._loop)
            return protocol._drain_waiter
        return ()
Guido van Rossum27b7c7e2013-10-17 13:40:50 -0700212
213
class StreamReader:
    """Buffer fed by a protocol and consumed by reader coroutines.

    A protocol pushes bytes in with feed_data() and ends the stream
    with feed_eof() or set_exception().  Coroutines pull data out with
    read(), readline() and readexactly().  Only one coroutine may wait
    for incoming data at a time (see _create_waiter()).

    The limit has two roles.  It caps the length of a line returned by
    readline(), and it drives read-side flow control: the transport is
    paused once more than 2 * limit bytes are buffered, and resumed
    once the buffer drops back to limit bytes or fewer.
    """

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.
        self._limit = limit
        if loop is None:
            loop = events.get_event_loop()
        self._loop = loop
        # TODO: Use a bytearray for a buffer, like the transport.
        self._buffer = collections.deque()  # Deque of bytes objects.
        self._byte_count = 0  # Bytes in buffer.
        self._eof = False  # Whether we're done.
        self._waiter = None  # A future.
        self._exception = None
        self._transport = None
        self._paused = False  # Whether we paused the transport's reading.

    def exception(self):
        """Return the exception passed to set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Store exc and raise it in the waiting reader coroutine, if any."""
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def set_transport(self, transport):
        """Attach the transport used for read-side flow control (once)."""
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once the buffer has drained to self._limit bytes.
        if self._paused and self._byte_count <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark the end of the stream and wake the waiting reader, if any."""
        self._eof = True
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(True)

    def feed_data(self, data):
        """Append data to the buffer, waking a reader and pausing if needed."""
        if not data:
            return

        self._buffer.append(data)
        self._byte_count += len(data)

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(False)

        if (self._transport is not None and
                not self._paused and
                self._byte_count > 2*self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    def _create_waiter(self, func_name):
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not be possible to
        # know which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError('%s() called while another coroutine is '
                               'already waiting for incoming data' % func_name)
        return futures.Future(loop=self._loop)

    @tasks.coroutine
    def readline(self):
        """Read one line, including the trailing b'\\n' if present.

        At EOF the remaining bytes are returned even without a newline
        (b'' if nothing is left).  Raises ValueError if more than
        self._limit bytes accumulate before a newline is found; those
        bytes are discarded.
        """
        if self._exception is not None:
            raise self._exception

        parts = []
        parts_size = 0
        not_enough = True

        while not_enough:
            while self._buffer and not_enough:
                data = self._buffer.popleft()
                ichar = data.find(b'\n')
                if ichar < 0:
                    parts.append(data)
                    parts_size += len(data)
                else:
                    ichar += 1
                    head, tail = data[:ichar], data[ichar:]
                    if tail:
                        self._buffer.appendleft(tail)
                    not_enough = False
                    parts.append(head)
                    parts_size += len(head)

                if parts_size > self._limit:
                    self._byte_count -= parts_size
                    self._maybe_resume_transport()
                    raise ValueError('Line is too long')

            if self._eof:
                break

            if not_enough:
                self._waiter = self._create_waiter('readline')
                try:
                    yield from self._waiter
                finally:
                    self._waiter = None

        line = b''.join(parts)
        self._byte_count -= parts_size
        self._maybe_resume_transport()

        return line

    @tasks.coroutine
    def read(self, n=-1):
        """Read up to n bytes, or everything until EOF when n < 0.

        With n > 0, waits until at least one byte is available (or EOF)
        and returns at most n bytes.  Returns b'' for n == 0 or at EOF.
        """
        if self._exception is not None:
            raise self._exception

        if not n:
            return b''

        if n < 0:
            # Waiting for EOF while leaving everything in self._buffer
            # deadlocks once more than 2 * self._limit bytes arrive:
            # feed_data() pauses the transport, and nothing resumes it
            # because no data is consumed, so EOF never arrives.
            # Reading in self._limit-sized chunks lets
            # _maybe_resume_transport() restart the transport.
            # NOTE: assumes self._limit > 0 (readline() and the pause
            # thresholds are not usable with a non-positive limit either).
            blocks = []
            while True:
                block = yield from self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._byte_count and not self._eof:
            self._waiter = self._create_waiter('read')
            try:
                yield from self._waiter
            finally:
                self._waiter = None

        if self._byte_count <= n:
            data = b''.join(self._buffer)
            self._buffer.clear()
            self._byte_count = 0
            self._maybe_resume_transport()
            return data

        parts = []
        parts_bytes = 0
        while self._buffer and parts_bytes < n:
            data = self._buffer.popleft()
            data_bytes = len(data)
            if n < parts_bytes + data_bytes:
                data_bytes = n - parts_bytes
                data, rest = data[:data_bytes], data[data_bytes:]
                self._buffer.appendleft(rest)

            parts.append(data)
            parts_bytes += data_bytes
            self._byte_count -= data_bytes
            self._maybe_resume_transport()

        return b''.join(parts)

    @tasks.coroutine
    def readexactly(self, n):
        """Read n bytes, returning fewer only if EOF is reached first."""
        if self._exception is not None:
            raise self._exception

        # There used to be "optimized" code here. It created its own
        # Future and waited until self._buffer had at least the n
        # bytes, then called read(n). Unfortunately, this could pause
        # the transport if the argument was larger than the pause
        # limit (which is twice self._limit). So now we just read()
        # into a local buffer.

        blocks = []
        while n > 0:
            block = yield from self.read(n)
            if not block:
                break
            blocks.append(block)
            n -= len(block)

        # TODO: Raise EOFError if we break before n == 0? (That would
        # be a change in specification, but I've always had to add an
        # explicit size check to the caller.)

        return b''.join(blocks)