blob: d0f12e81df9ab291a867c7633b720fdd34866504 [file] [log] [blame]
"""Stream-related things.

Provides StreamReader (a byte buffer consumed by coroutines), the
StreamReaderProtocol adapter that feeds it from a transport, and the
open_connection() convenience wrapper that wires the two together.
"""

# Public API of this module.
__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection']
4
5import collections
6
7from . import events
8from . import futures
9from . import protocols
10from . import tasks
11
12
13_DEFAULT_LIMIT = 2**16
14
15
@tasks.coroutine
def open_connection(host=None, port=None, *,
                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
    """Open a connection and return a (reader, writer) pair.

    The reader is a StreamReader fed by a StreamReaderProtocol; the
    writer is the Transport returned by loop.create_connection().

    host, port and any extra keyword arguments are forwarded unchanged
    to create_connection() (everything except protocol_factory, which
    is supplied here).  Two extra keyword-only arguments are accepted:

    loop  -- event loop to use; defaults to events.get_event_loop().
    limit -- buffer limit handed to the StreamReader.

    (To customize the StreamReader and/or StreamReaderProtocol classes,
    copy this function -- it is only a convenience wrapper.)
    """
    event_loop = events.get_event_loop() if loop is None else loop
    stream_reader = StreamReader(limit=limit, loop=event_loop)
    adapter = StreamReaderProtocol(stream_reader)

    def protocol_factory():
        # create_connection() expects a factory; always hand back the
        # single adapter bound to our reader.
        return adapter

    transport, _protocol = yield from event_loop.create_connection(
        protocol_factory, host, port, **kwds)
    # The transport itself serves as the writer half of the pair.
    return stream_reader, transport
43
44
class StreamReaderProtocol(protocols.Protocol):
    """Adapter that forwards Protocol callbacks to a StreamReader.

    This is a separate helper rather than making StreamReader a Protocol
    subclass: StreamReader has other potential uses, and keeping them
    apart stops users of the reader from accidentally calling protocol
    methods on it.
    """

    def __init__(self, stream_reader):
        self.stream_reader = stream_reader

    def connection_made(self, transport):
        """Give the new transport to the reader (it uses it for pause/resume)."""
        self.stream_reader.set_transport(transport)

    def connection_lost(self, exc):
        """Report a clean close as EOF; otherwise propagate the error."""
        reader = self.stream_reader
        if exc is not None:
            reader.set_exception(exc)
            return
        reader.feed_eof()

    def data_received(self, data):
        """Push received bytes into the reader's buffer."""
        self.stream_reader.feed_data(data)

    def eof_received(self):
        """Tell the reader no more data will arrive."""
        self.stream_reader.feed_eof()
71
72
class StreamReader:
    """Byte buffer filled by a protocol and drained by coroutines.

    Producers call feed_data(), feed_eof() and set_exception() (normally
    via StreamReaderProtocol).  Consumers use the readline(), read() and
    readexactly() coroutines.  At most one coroutine may be waiting for
    data at any time.

    Flow control: once a transport has been attached with set_transport(),
    it is paused when more than 2*limit bytes are buffered, and resumed
    when the buffer drops back to limit bytes or fewer.
    """

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.
        self.limit = limit
        if loop is None:
            loop = events.get_event_loop()
        self.loop = loop
        self.buffer = collections.deque()  # Deque of bytes objects.
        self.byte_count = 0  # Bytes in buffer.
        self.eof = False  # Whether we're done.
        self.waiter = None  # A future.
        self._exception = None
        self._transport = None
        self._paused = False

    def exception(self):
        """Return the exception passed to set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Record exc and raise it in the coroutine waiting for data, if any."""
        self._exception = exc

        waiter = self.waiter
        if waiter is not None:
            self.waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def set_transport(self, transport):
        """Attach the transport used for flow control (only once)."""
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once the buffer has drained to `limit` or below.
        if self._paused and self.byte_count <= self.limit:
            self._paused = False
            self._transport.resume()

    def feed_eof(self):
        """Mark the stream as finished and wake the waiting coroutine."""
        self.eof = True
        waiter = self.waiter
        if waiter is not None:
            self.waiter = None
            if not waiter.cancelled():
                waiter.set_result(True)

    def feed_data(self, data):
        """Buffer data, wake the waiting coroutine, apply flow control.

        Empty data is ignored.  If more than 2*limit bytes are buffered,
        the transport is paused; a transport whose pause() raises
        NotImplementedError is forgotten and data is buffered without bound.
        """
        if not data:
            return

        self.buffer.append(data)
        self.byte_count += len(data)

        waiter = self.waiter
        if waiter is not None:
            self.waiter = None
            if not waiter.cancelled():
                waiter.set_result(False)

        if (self._transport is not None and
                not self._paused and
                self.byte_count > 2*self.limit):
            try:
                self._transport.pause()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    @tasks.coroutine
    def _wait_for_data(self):
        # Suspend until feed_data(), feed_eof() or set_exception() wakes us;
        # set_exception() makes the `yield from` raise the stored exception.
        assert self.waiter is None
        self.waiter = futures.Future(loop=self.loop)
        try:
            yield from self.waiter
        finally:
            self.waiter = None

    @tasks.coroutine
    def readline(self):
        """Read one line including its trailing b'\\n'.

        At EOF, returns whatever is buffered (possibly without a newline,
        possibly b'').  Raises ValueError once more than `limit` bytes
        (newline included) have been collected; those bytes are discarded.
        Raises the stored exception if set_exception() was called.
        """
        if self._exception is not None:
            raise self._exception

        parts = []
        parts_size = 0
        not_enough = True

        while not_enough:
            while self.buffer and not_enough:
                data = self.buffer.popleft()
                ichar = data.find(b'\n')
                if ichar < 0:
                    parts.append(data)
                    parts_size += len(data)
                else:
                    ichar += 1
                    head, tail = data[:ichar], data[ichar:]
                    if tail:
                        self.buffer.appendleft(tail)
                    not_enough = False
                    parts.append(head)
                    parts_size += len(head)

                if parts_size > self.limit:
                    self.byte_count -= parts_size
                    self._maybe_resume_transport()
                    raise ValueError('Line is too long')

            if self.eof:
                break

            if not_enough:
                yield from self._wait_for_data()

        line = b''.join(parts)
        self.byte_count -= parts_size
        self._maybe_resume_transport()

        return line

    @tasks.coroutine
    def read(self, n=-1):
        """Read up to n bytes; with n < 0, read everything until EOF.

        For n > 0, waits only if the buffer is empty and EOF has not been
        seen, then returns at most n bytes (b'' at EOF).  n == 0 returns b''.
        Raises the stored exception if set_exception() was called.
        """
        if self._exception is not None:
            raise self._exception

        if not n:
            return b''

        if n < 0:
            # Drain in bounded chunks rather than waiting for EOF with all
            # data still buffered: once more than 2*limit bytes arrive,
            # feed_data() pauses the transport, and only consuming data
            # resumes it -- waiting for EOF alone would deadlock.
            block_size = self.limit if self.limit > 0 else _DEFAULT_LIMIT
            blocks = []
            while True:
                block = yield from self.read(block_size)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self.byte_count and not self.eof:
            yield from self._wait_for_data()

        if self.byte_count <= n:
            data = b''.join(self.buffer)
            self.buffer.clear()
            self.byte_count = 0
            self._maybe_resume_transport()
            return data

        parts = []
        parts_bytes = 0
        while self.buffer and parts_bytes < n:
            data = self.buffer.popleft()
            data_bytes = len(data)
            if n < parts_bytes + data_bytes:
                # Take only what is needed; push the remainder back.
                data_bytes = n - parts_bytes
                data, rest = data[:data_bytes], data[data_bytes:]
                self.buffer.appendleft(rest)

            parts.append(data)
            parts_bytes += data_bytes
            self.byte_count -= data_bytes
            self._maybe_resume_transport()

        return b''.join(parts)

    @tasks.coroutine
    def readexactly(self, n):
        """Read exactly n bytes, or fewer if EOF arrives first.

        n <= 0 returns b''.  Raises the stored exception if
        set_exception() was called.
        """
        if self._exception is not None:
            raise self._exception

        if n <= 0:
            return b''

        # Consume incrementally via read() so the transport gets resumed as
        # the buffer drains; waiting for n bytes to accumulate first would
        # deadlock for n > 2*limit because the transport stays paused.
        blocks = []
        while n > 0:
            block = yield from self.read(n)
            if not block:
                # EOF before n bytes: return the partial data.
                break
            blocks.append(block)
            n -= len(block)
        return b''.join(blocks)
257 return (yield from self.read(n))