blob: 987c158ee73bafb338161fee62538f0e101db343 [file] [log] [blame]
Victor Stinner231b4042015-01-14 00:19:09 +01001import collections
2try:
3 import ssl
4except ImportError: # pragma: no cover
5 ssl = None
6
7from . import protocols
8from . import transports
9from .log import logger
10
11
def _create_transport_context(server_side, server_hostname):
    """Build the default client-side SSL context.

    Used when the caller requested SSL without supplying an SSLContext
    (e.g. ``ssl=True``). A server-side transport cannot fall back to a
    default context, so it is rejected.

    :param server_side: true for a server-side transport.
    :param server_hostname: expected peer hostname. When it is false-y,
        hostname checking is switched off on the Python 3.4+ path.
    :raises ValueError: if *server_side* is true.
    :return: a new ``ssl.SSLContext``.
    """
    if server_side:
        raise ValueError('Server side SSL needs a valid SSLContext')

    if not hasattr(ssl, 'create_default_context'):
        # Fallback for Python 3.3, which lacks create_default_context():
        # negotiate the best protocol version but refuse SSLv2/SSLv3, load
        # the system CA store and insist on a peer certificate.
        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        context.options |= ssl.OP_NO_SSLv2
        context.options |= ssl.OP_NO_SSLv3
        context.set_default_verify_paths()
        context.verify_mode = ssl.CERT_REQUIRED
        return context

    # Python 3.4+: the stdlib default carries up-to-date strong settings.
    context = ssl.create_default_context()
    if not server_hostname:
        # No hostname to match the certificate against.
        context.check_hostname = False
    return context
32
33
def _is_sslproto_available():
    """Return True when the ssl module provides ``MemoryBIO``.

    The SSL protocol implementation drives TLS through in-memory BIOs, so it
    is only usable when ``ssl.MemoryBIO`` exists. This is also False when
    the ssl module could not be imported (``ssl`` is then None).
    """
    has_memory_bio = hasattr(ssl, "MemoryBIO")
    return has_memory_bio
36
37
# States of an _SSLPipe:
#   UNWRAPPED    - no security layer; data passes through untransformed.
#   DO_HANDSHAKE - SSL handshake started but not yet finished.
#   WRAPPED      - handshake done; data is encrypted/decrypted.
#   SHUTDOWN     - SSL shutdown sequence in progress.
(_UNWRAPPED,
 _DO_HANDSHAKE,
 _WRAPPED,
 _SHUTDOWN) = ("UNWRAPPED", "DO_HANDSHAKE", "WRAPPED", "SHUTDOWN")
43
44
class _SSLPipe(object):
    """An SSL "Pipe".

    An SSL pipe allows you to communicate with an SSL/TLS protocol instance
    through memory buffers. It can be used to implement a security layer for an
    existing connection where you don't have access to the connection's file
    descriptor, or for some reason you don't want to use it.

    An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode,
    data is passed through untransformed. In wrapped mode, application level
    data is encrypted to SSL record level data and vice versa. The SSL record
    level is the lowest level in the SSL protocol suite and is what travels
    as-is over the wire.

    An _SSLPipe initially is in "unwrapped" mode. To start SSL, call
    do_handshake(). To shut SSL down again, call shutdown().
    """

    max_size = 256 * 1024   # Buffer size passed to read()

    def __init__(self, context, server_side, server_hostname=None):
        """
        The *context* argument specifies the ssl.SSLContext to use.

        The *server_side* argument indicates whether this is a server side or
        client side transport.

        The optional *server_hostname* argument can be used to specify the
        hostname you are connecting to. You may only specify this parameter if
        the _ssl module supports Server Name Indication (SNI).
        """
        self._context = context
        self._server_side = server_side
        self._server_hostname = server_hostname
        self._state = _UNWRAPPED
        # Record level data received from the peer is written to _incoming;
        # record level data to send to the peer accumulates in _outgoing.
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()
        self._sslobj = None
        self._need_ssldata = False
        self._handshake_cb = None
        self._shutdown_cb = None

    @property
    def context(self):
        """The SSL context passed to the constructor."""
        return self._context

    @property
    def ssl_object(self):
        """The internal ssl.SSLObject instance.

        None before do_handshake() has been called and after a shutdown
        has completed.
        """
        return self._sslobj

    @property
    def need_ssldata(self):
        """Whether more record level data is needed to complete a handshake
        that is currently in progress."""
        return self._need_ssldata

    @property
    def wrapped(self):
        """
        Whether a security layer is currently in effect.

        Return False during handshake.
        """
        return self._state == _WRAPPED

    def do_handshake(self, callback=None):
        """Start the SSL handshake.

        Return a list of ssldata: buffers of record level data that must be
        sent to the remote SSL instance.

        The optional *callback* argument can be used to install a callback that
        will be called when the handshake is complete. The callback will be
        called with None if successful, else an exception instance.

        Raise RuntimeError if the pipe is not in unwrapped mode.
        """
        if self._state != _UNWRAPPED:
            raise RuntimeError('handshake in progress or completed')
        self._sslobj = self._context.wrap_bio(
            self._incoming, self._outgoing,
            server_side=self._server_side,
            server_hostname=self._server_hostname)
        self._state = _DO_HANDSHAKE
        self._handshake_cb = callback
        ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
        assert len(appdata) == 0
        return ssldata

    def shutdown(self, callback=None):
        """Start the SSL shutdown sequence.

        Return a list of ssldata: buffers of record level data that must be
        sent to the remote SSL instance.

        The optional *callback* argument can be used to install a callback that
        will be called when the shutdown is complete. The callback will be
        called without arguments.

        Raise RuntimeError if there is no security layer or a shutdown is
        already in progress.
        """
        if self._state == _UNWRAPPED:
            raise RuntimeError('no security layer present')
        if self._state == _SHUTDOWN:
            raise RuntimeError('shutdown in progress')
        assert self._state in (_WRAPPED, _DO_HANDSHAKE)
        self._state = _SHUTDOWN
        self._shutdown_cb = callback
        ssldata, appdata = self.feed_ssldata(b'')
        assert appdata == [] or appdata == [b'']
        return ssldata

    def feed_eof(self):
        """Send a potentially "ragged" EOF into the pipe.

        If the SSL object treats the EOF as unexpected, the resulting
        ssl.SSLError (SSL_ERROR_EOF) is propagated to the caller. Any record
        level data produced while processing the EOF is discarded.
        """
        self._incoming.write_eof()
        ssldata, appdata = self.feed_ssldata(b'')
        assert appdata == [] or appdata == [b'']

    def feed_ssldata(self, data, only_handshake=False):
        """Feed SSL record level data into the pipe.

        The data must be a bytes instance. It is OK to send an empty bytes
        instance. This can be used to get ssldata for a handshake initiated by
        this endpoint.

        If *only_handshake* is true and this call completes the handshake,
        no application data is read; pending record level data is still
        returned.

        Return a (ssldata, appdata) tuple. The ssldata element is a list of
        buffers containing SSL data that needs to be sent to the remote SSL.

        The appdata element is a list of buffers containing plaintext data that
        needs to be forwarded to the application. The appdata list may contain
        an empty buffer indicating an SSL "close_notify" alert. This alert must
        be acknowledged by calling shutdown().

        SSL errors other than WANT_READ, WANT_WRITE and SYSCALL are re-raised.
        During a handshake, they are first passed to the handshake callback.
        """
        if self._state == _UNWRAPPED:
            # If unwrapped, pass plaintext data straight through.
            if data:
                appdata = [data]
            else:
                appdata = []
            return ([], appdata)

        self._need_ssldata = False
        if data:
            self._incoming.write(data)

        ssldata = []
        appdata = []
        try:
            if self._state == _DO_HANDSHAKE:
                # do_handshake() raises (WANT_READ) until enough records from
                # the peer are available; it is retried on every call.
                self._sslobj.do_handshake()
                self._state = _WRAPPED
                if self._handshake_cb:
                    self._handshake_cb(None)
                if only_handshake:
                    # Skip reading application data, but do not strand the
                    # records produced while finishing the handshake.
                    if self._outgoing.pending:
                        ssldata.append(self._outgoing.read())
                    return (ssldata, appdata)
                # Handshake done: execute the wrapped block

            if self._state == _WRAPPED:
                # Main state: read data from SSL until close_notify
                while True:
                    chunk = self._sslobj.read(self.max_size)
                    appdata.append(chunk)
                    if not chunk:  # close_notify
                        break

            elif self._state == _SHUTDOWN:
                # Call shutdown() until it doesn't raise anymore.
                self._sslobj.unwrap()
                self._sslobj = None
                self._state = _UNWRAPPED
                if self._shutdown_cb:
                    self._shutdown_cb()

            elif self._state == _UNWRAPPED:
                # Drain possible plaintext data after close_notify.
                appdata.append(self._incoming.read())
        except (ssl.SSLError, ssl.CertificateError) as exc:
            # CertificateError has no errno, so it is always re-raised.
            exc_errno = getattr(exc, 'errno', None)
            if exc_errno not in (
                    ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
                    ssl.SSL_ERROR_SYSCALL):
                if self._state == _DO_HANDSHAKE and self._handshake_cb:
                    self._handshake_cb(exc)
                raise
            self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)

        # Check for record level data that needs to be sent back.
        # Happens for the initial handshake and renegotiations.
        if self._outgoing.pending:
            ssldata.append(self._outgoing.read())
        return (ssldata, appdata)

    def feed_appdata(self, data, offset=0):
        """Feed plaintext data into the pipe.

        Return an (ssldata, offset) tuple. The ssldata element is a list of
        buffers containing record level data that needs to be sent to the
        remote SSL instance. The offset is the number of plaintext bytes that
        were processed, which may be less than the length of data.

        NOTE: In case of short writes, this call MUST be retried with the SAME
        buffer passed into the *data* argument (i.e. the id() must be the
        same). This is an OpenSSL requirement. A further particularity is that
        a short write will always have offset == 0, because the _ssl module
        does not enable partial writes. And even though the offset is zero,
        there will still be encrypted data in ssldata.
        """
        assert 0 <= offset <= len(data)
        if self._state == _UNWRAPPED:
            # pass through data in unwrapped mode
            if offset < len(data):
                ssldata = [data[offset:]]
            else:
                ssldata = []
            return (ssldata, len(data))

        ssldata = []
        # memoryview slices avoid copying the remaining plaintext each pass.
        view = memoryview(data)
        while True:
            self._need_ssldata = False
            try:
                if offset < len(view):
                    offset += self._sslobj.write(view[offset:])
            except ssl.SSLError as exc:
                # It is not allowed to call write() after unwrap() until the
                # close_notify is acknowledged. We return the condition to the
                # caller as a short write.
                if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
                    exc.errno = ssl.SSL_ERROR_WANT_READ
                if exc.errno not in (ssl.SSL_ERROR_WANT_READ,
                                     ssl.SSL_ERROR_WANT_WRITE,
                                     ssl.SSL_ERROR_SYSCALL):
                    raise
                self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ)

            # See if there's any record level data back for us.
            if self._outgoing.pending:
                ssldata.append(self._outgoing.read())
            if offset == len(view) or self._need_ssldata:
                break
        return (ssldata, offset)
289
290
class _SSLProtocolTransport(transports._FlowControlMixin,
                            transports.Transport):
    """Transport given to the application protocol of an SSLProtocol.

    Application writes are routed through the owning SSLProtocol to be
    encrypted. Reading and write-buffer flow control are delegated to that
    SSLProtocol's low-level transport.
    """

    def __init__(self, loop, ssl_protocol, app_protocol):
        self._loop = loop
        self._ssl_protocol = ssl_protocol
        self._app_protocol = app_protocol

    def _lower_transport(self):
        # The low-level transport currently held by the SSL protocol.
        return self._ssl_protocol._transport

    def get_extra_info(self, name, default=None):
        """Get optional transport information."""
        return self._ssl_protocol._get_extra_info(name, default)

    def close(self):
        """Close the transport.

        Buffered data is flushed asynchronously and no more data will be
        received. Once all buffered data has been flushed, the protocol's
        connection_lost() method will eventually be called with None as its
        argument.
        """
        self._ssl_protocol._start_shutdown()

    def pause_reading(self):
        """Pause the receiving end.

        The protocol's data_received() method will not be called again
        until resume_reading() is called.
        """
        self._lower_transport().pause_reading()

    def resume_reading(self):
        """Resume the receiving end.

        Incoming data is once again passed to the protocol's
        data_received() method.
        """
        self._lower_transport().resume_reading()

    def set_write_buffer_limits(self, high=None, low=None):
        """Set the high- and low-water limits for write flow control.

        The two limits decide when the protocol's pause_writing() and
        resume_writing() methods are called. When both are given, low must
        not exceed high, and neither may be negative.

        Defaults are implementation-specific. Passing only *high* makes
        *low* default to an implementation-specific value no greater than
        *high*. A *high* of zero forces *low* to zero too, so pause_writing()
        fires whenever the buffer becomes non-empty. A *low* of zero means
        resume_writing() fires only once the buffer is empty. Zero for
        either limit is usually sub-optimal, since it leaves less room to
        overlap I/O with computation.
        """
        self._lower_transport().set_write_buffer_limits(high, low)

    def get_write_buffer_size(self):
        """Return the current size of the write buffer."""
        return self._lower_transport().get_write_buffer_size()

    def write(self, data):
        """Write some data bytes to the transport.

        Non-blocking: the data is buffered and sent out asynchronously.
        Empty data is ignored; non bytes-like data raises TypeError.
        """
        if isinstance(data, (bytes, bytearray, memoryview)):
            if data:
                self._ssl_protocol._write_appdata(data)
            return
        raise TypeError("data: expecting a bytes-like instance, got {!r}"
                        .format(type(data).__name__))

    def can_write_eof(self):
        """Return True if this transport supports write_eof(), False if not."""
        return False

    def abort(self):
        """Close the transport immediately.

        Buffered data is lost and no more data will be received. The
        protocol's connection_lost() method will eventually be called with
        None as its argument.
        """
        self._ssl_protocol._abort()
379
380
class SSLProtocol(protocols.Protocol):
    """SSL protocol.

    Implementation of SSL on top of a socket using incoming and outgoing
    buffers which are ssl.MemoryBIO objects.

    This object is the protocol of the low-level transport. The application
    protocol is given an _SSLProtocolTransport instead. Application writes,
    and the handshake and shutdown steps, are queued in a write backlog that
    is pushed through an _SSLPipe.
    """

    def __init__(self, loop, app_protocol, sslcontext, waiter,
                 server_side=False, server_hostname=None):
        if ssl is None:
            raise RuntimeError('stdlib ssl module not available')

        if not sslcontext:
            sslcontext = _create_transport_context(server_side, server_hostname)

        self._server_side = server_side
        if server_hostname and not server_side:
            self._server_hostname = server_hostname
        else:
            self._server_hostname = None
        self._sslcontext = sslcontext
        # SSL-specific extra info. More info is set when the handshake
        # completes.
        self._extra = dict(sslcontext=sslcontext)

        # App data write buffering: a deque of (data, offset) entries.
        # The special entries (b'', 1) and (b'', 0) request the handshake
        # and the shutdown respectively; see _process_write_backlog().
        self._write_backlog = collections.deque()
        self._write_buffer_size = 0

        # Resolved once the handshake succeeds or fails. It is only resolved
        # through _wakeup_waiter(), which also drops this reference.
        self._waiter = waiter
        self._closing = False
        self._loop = loop
        self._app_protocol = app_protocol
        self._app_transport = _SSLProtocolTransport(self._loop,
                                                    self, self._app_protocol)
        # Low-level transport: set by connection_made() and cleared by
        # connection_lost(). Defined up front so the "is None" checks below
        # also work before the connection is made.
        self._transport = None
        self._sslpipe = None
        self._session_established = False
        self._in_handshake = False
        self._in_shutdown = False

    def connection_made(self, transport):
        """Called when the low-level connection is made.

        Start the SSL handshake.
        """
        self._transport = transport
        self._sslpipe = _SSLPipe(self._sslcontext,
                                 self._server_side,
                                 self._server_hostname)
        self._start_handshake()

    def connection_lost(self, exc):
        """Called when the low-level connection is lost or closed.

        The argument is an exception object or None (the latter
        meaning a regular EOF is received or the connection was
        aborted or closed).

        If the SSL session was established, the application protocol's
        connection_lost() is scheduled with the same argument. If the
        handshake never completed, the waiter is failed with *exc*, or with
        ConnectionResetError when *exc* is None.
        """
        if self._session_established:
            self._session_established = False
            self._loop.call_soon(self._app_protocol.connection_lost, exc)
        self._transport = None
        self._app_transport = None
        if self._waiter is not None:
            # The handshake never finished: fail whoever awaits it instead
            # of leaving the waiter pending forever.
            self._wakeup_waiter(
                exc if exc is not None else
                ConnectionResetError('Connection lost during SSL handshake'))

    def pause_writing(self):
        """Called when the low-level transport's buffer goes over
        the high-water mark.
        """
        self._app_protocol.pause_writing()

    def resume_writing(self):
        """Called when the low-level transport's buffer drains below
        the low-water mark.
        """
        self._app_protocol.resume_writing()

    def data_received(self, data):
        """Called when some SSL data is received.

        The argument is a bytes object. Decrypted data is passed to the
        application protocol's data_received(). An empty chunk
        (close_notify) starts the SSL shutdown. An SSL or certificate error
        aborts the connection.
        """
        try:
            ssldata, appdata = self._sslpipe.feed_ssldata(data)
        except (ssl.SSLError, ssl.CertificateError) as e:
            # feed_ssldata() re-raises CertificateError too, and on older
            # Pythons it is not an SSLError subclass.
            if self._loop.get_debug():
                logger.warning('%r: SSL error %s (reason %s)',
                               self, getattr(e, 'errno', None),
                               getattr(e, 'reason', None))
            self._abort()
            return

        for chunk in ssldata:
            self._transport.write(chunk)

        for chunk in appdata:
            if chunk:
                self._app_protocol.data_received(chunk)
            else:
                self._start_shutdown()
                break

    def eof_received(self):
        """Called when the other end of the low-level stream
        is half-closed.

        If this returns a false value (including None), the transport
        will close itself. If it returns a true value, closing the
        transport is up to the protocol.

        Here the low-level transport is always closed. An EOF that arrives
        before the handshake finished fails the waiter with
        ConnectionResetError.
        """
        try:
            if self._loop.get_debug():
                logger.debug("%r received EOF", self)

            # No-op unless the handshake is still pending.
            self._wakeup_waiter(ConnectionResetError(
                'Connection closed by peer during SSL handshake'))

            if not self._in_handshake:
                keep_open = self._app_protocol.eof_received()
                if keep_open:
                    logger.warning('returning true from eof_received() '
                                   'has no effect when using ssl')
        finally:
            self._transport.close()

    def _get_extra_info(self, name, default=None):
        # SSL-specific keys first, then the low-level transport's info.
        # Once the connection is gone, unknown keys fall back to *default*.
        if name in self._extra:
            return self._extra[name]
        if self._transport is not None:
            return self._transport.get_extra_info(name, default)
        return default

    def _start_shutdown(self):
        # Queue the shutdown marker once; later calls are no-ops.
        if self._in_shutdown:
            return
        self._in_shutdown = True
        self._write_appdata(b'')

    def _write_appdata(self, data):
        self._write_backlog.append((data, 0))
        self._write_buffer_size += len(data)
        self._process_write_backlog()

    def _start_handshake(self):
        if self._loop.get_debug():
            logger.debug("%r starts SSL handshake", self)
            self._handshake_start_time = self._loop.time()
        else:
            self._handshake_start_time = None
        self._in_handshake = True
        # (b'', 1) is a special value in _process_write_backlog() to do
        # the SSL handshake
        self._write_backlog.append((b'', 1))
        self._loop.call_soon(self._process_write_backlog)

    def _wakeup_waiter(self, exc=None):
        """Resolve the handshake waiter at most once.

        Set *exc* as the waiter's exception, or None as its result when
        *exc* is None. A waiter that is already done (e.g. cancelled by a
        timeout) is left alone rather than raising InvalidStateError.
        """
        waiter = self._waiter
        if waiter is None:
            return
        self._waiter = None
        if waiter.done():
            return
        if exc is not None:
            waiter.set_exception(exc)
        else:
            waiter.set_result(None)

    def _on_handshake_complete(self, handshake_exc):
        # Handshake callback installed via _SSLPipe.do_handshake(); also
        # called directly by _process_write_backlog() on errors.
        self._in_handshake = False

        sslobj = self._sslpipe.ssl_object
        peercert = None if handshake_exc else sslobj.getpeercert()
        try:
            if handshake_exc is not None:
                raise handshake_exc
            if not hasattr(self._sslcontext, 'check_hostname'):
                # Verify hostname if requested, Python 3.4+ uses check_hostname
                # and checks the hostname in do_handshake()
                if (self._server_hostname
                        and self._sslcontext.verify_mode != ssl.CERT_NONE):
                    ssl.match_hostname(peercert, self._server_hostname)
        except BaseException as exc:
            if self._loop.get_debug():
                if isinstance(exc, ssl.CertificateError):
                    logger.warning("%r: SSL handshake failed "
                                   "on verifying the certificate",
                                   self, exc_info=True)
                else:
                    logger.warning("%r: SSL handshake failed",
                                   self, exc_info=True)
            self._transport.close()
            if isinstance(exc, Exception):
                # Safe even if the waiter was cancelled meanwhile.
                self._wakeup_waiter(exc)
                return
            else:
                raise

        if self._loop.get_debug() and self._handshake_start_time is not None:
            # The start time is None when debug mode was off at the start.
            dt = self._loop.time() - self._handshake_start_time
            logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3)

        # Add extra info that becomes available after handshake.
        self._extra.update(peercert=peercert,
                           cipher=sslobj.cipher(),
                           compression=sslobj.compression(),
                           )
        self._app_protocol.connection_made(self._app_transport)
        # Wake the waiter only after protocol.connection_made() has run.
        self._wakeup_waiter()
        self._session_established = True
        # In case transport.write() was already called
        self._process_write_backlog()

    def _process_write_backlog(self):
        # Try to make progress on the write backlog.
        if self._transport is None:
            return

        try:
            # Visit at most the entries present now. Each pass handles the
            # head entry and either deletes it or stops on a short write.
            for i in range(len(self._write_backlog)):
                data, offset = self._write_backlog[0]
                if data:
                    ssldata, offset = self._sslpipe.feed_appdata(data, offset)
                elif offset:
                    ssldata = self._sslpipe.do_handshake(self._on_handshake_complete)
                    offset = 1
                else:
                    ssldata = self._sslpipe.shutdown(self._finalize)
                    offset = 1

                for chunk in ssldata:
                    self._transport.write(chunk)

                if offset < len(data):
                    self._write_backlog[0] = (data, offset)
                    # A short write means that a write is blocked on a read
                    # We need to enable reading if it is paused!
                    assert self._sslpipe.need_ssldata
                    if self._transport._paused:
                        self._transport.resume_reading()
                    break

                # An entire chunk from the backlog was processed. We can
                # delete it and reduce the outstanding buffer size.
                del self._write_backlog[0]
                self._write_buffer_size -= len(data)
        except BaseException as exc:
            if self._in_handshake:
                # Re-raises BaseExceptions (non-Exception) by itself.
                self._on_handshake_complete(exc)
            else:
                self._fatal_error(exc, 'Fatal error on SSL transport')
                if not isinstance(exc, Exception):
                    # Never swallow KeyboardInterrupt / SystemExit.
                    raise

    def _fatal_error(self, exc, message='Fatal error on transport'):
        # Should be called from exception handler only.
        if isinstance(exc, (BrokenPipeError, ConnectionResetError)):
            if self._loop.get_debug():
                logger.debug("%r: %s", self, message, exc_info=True)
        else:
            self._loop.call_exception_handler({
                'message': message,
                'exception': exc,
                'transport': self._transport,
                'protocol': self,
            })
        if self._transport:
            self._transport._force_close(exc)

    def _finalize(self):
        # Shutdown-complete callback: close the low-level transport.
        if self._transport is not None:
            self._transport.close()

    def _abort(self):
        if self._transport is not None:
            try:
                self._transport.abort()
            finally:
                self._finalize()