bpo-34271: Add ssl debugging helpers (GH-10031)

The ssl module now can dump key material to a keylog file and trace TLS
protocol messages with a tracing callback. The default and stdlib
contexts also support SSLKEYLOGFILE env var.

The msg_callback and related enums are private members. The feature
is designed for internal debugging and not for end users.

Signed-off-by: Christian Heimes <christian@python.org>
diff --git a/Lib/ssl.py b/Lib/ssl.py
index 793ed49..f5fa6ae 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -165,6 +165,90 @@
     MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED
 
 
+class _TLSContentType(_IntEnum):
+    """Content types (record layer)
+
+    See RFC 8446, section B.1
+    """
+    CHANGE_CIPHER_SPEC = 20
+    ALERT = 21
+    HANDSHAKE = 22
+    APPLICATION_DATA = 23
+    # pseudo content types
+    HEADER = 0x100
+    INNER_CONTENT_TYPE = 0x101
+
+
+class _TLSAlertType(_IntEnum):
+    """Alert types for TLSContentType.ALERT messages
+
+    See RFC 8446, section B.2
+    """
+    CLOSE_NOTIFY = 0
+    UNEXPECTED_MESSAGE = 10
+    BAD_RECORD_MAC = 20
+    DECRYPTION_FAILED = 21
+    RECORD_OVERFLOW = 22
+    DECOMPRESSION_FAILURE = 30
+    HANDSHAKE_FAILURE = 40
+    NO_CERTIFICATE = 41
+    BAD_CERTIFICATE = 42
+    UNSUPPORTED_CERTIFICATE = 43
+    CERTIFICATE_REVOKED = 44
+    CERTIFICATE_EXPIRED = 45
+    CERTIFICATE_UNKNOWN = 46
+    ILLEGAL_PARAMETER = 47
+    UNKNOWN_CA = 48
+    ACCESS_DENIED = 49
+    DECODE_ERROR = 50
+    DECRYPT_ERROR = 51
+    EXPORT_RESTRICTION = 60
+    PROTOCOL_VERSION = 70
+    INSUFFICIENT_SECURITY = 71
+    INTERNAL_ERROR = 80
+    INAPPROPRIATE_FALLBACK = 86
+    USER_CANCELED = 90
+    NO_RENEGOTIATION = 100
+    MISSING_EXTENSION = 109
+    UNSUPPORTED_EXTENSION = 110
+    CERTIFICATE_UNOBTAINABLE = 111
+    UNRECOGNIZED_NAME = 112
+    BAD_CERTIFICATE_STATUS_RESPONSE = 113
+    BAD_CERTIFICATE_HASH_VALUE = 114
+    UNKNOWN_PSK_IDENTITY = 115
+    CERTIFICATE_REQUIRED = 116
+    NO_APPLICATION_PROTOCOL = 120
+
+
+class _TLSMessageType(_IntEnum):
+    """Message types (handshake protocol)
+
+    See RFC 8446, section B.3
+    """
+    HELLO_REQUEST = 0
+    CLIENT_HELLO = 1
+    SERVER_HELLO = 2
+    HELLO_VERIFY_REQUEST = 3
+    NEWSESSION_TICKET = 4
+    END_OF_EARLY_DATA = 5
+    HELLO_RETRY_REQUEST = 6
+    ENCRYPTED_EXTENSIONS = 8
+    CERTIFICATE = 11
+    SERVER_KEY_EXCHANGE = 12
+    CERTIFICATE_REQUEST = 13
+    SERVER_DONE = 14
+    CERTIFICATE_VERIFY = 15
+    CLIENT_KEY_EXCHANGE = 16
+    FINISHED = 20
+    CERTIFICATE_URL = 21
+    CERTIFICATE_STATUS = 22
+    SUPPLEMENTAL_DATA = 23
+    KEY_UPDATE = 24
+    NEXT_PROTO = 67
+    MESSAGE_HASH = 254
+    CHANGE_CIPHER_SPEC = 0x0101
+
+
 if sys.platform == "win32":
     from _ssl import enum_certificates, enum_crls
 
@@ -524,6 +608,83 @@
             return True
 
     @property
+    def _msg_callback(self):
+        """TLS message callback
+
+        The message callback provides a debugging hook to analyze TLS
+        connections. The callback is called for any TLS protocol message
+        (header, handshake, alert, and more), but not for application data.
+        Due to technical limitations, the callback can't be used to filter
+        traffic or to abort a connection. Any exception raised in the
+        callback is delayed until the handshake, read, or write operation
+        has been performed.
+
+        def msg_cb(conn, direction, version, content_type, msg_type, data):
+            pass
+
+        conn
+            :class:`SSLSocket` or :class:`SSLObject` instance
+        direction
+            ``read`` or ``write``
+        version
+            :class:`TLSVersion` enum member or int for unknown version. For a
+            frame header, it's the header version.
+        content_type
+            :class:`_TLSContentType` enum member or int for unsupported
+            content type.
+        msg_type
+            Either a :class:`_TLSContentType` enum member for a header
+            message, a :class:`_TLSAlertType` enum member for an alert
+            message, a :class:`_TLSMessageType` enum member for other
+            messages, or int for unsupported message types.
+        data
+            Raw, decrypted message content as bytes
+        """
+        inner = super()._msg_callback
+        if inner is not None:
+            return inner.user_function
+        else:
+            return None
+
+    @_msg_callback.setter
+    def _msg_callback(self, callback):
+        if callback is None:
+            super(SSLContext, SSLContext)._msg_callback.__set__(self, None)
+            return
+
+        if not hasattr(callback, '__call__'):
+            raise TypeError(f"{callback} is not callable.")
+
+        def inner(conn, direction, version, content_type, msg_type, data):
+            try:
+                version = TLSVersion(version)
+            except ValueError:  # unknown version: pass the raw int through
+                pass
+
+            try:
+                content_type = _TLSContentType(content_type)
+            except ValueError:  # unsupported content type: keep the raw int
+                pass
+
+            if content_type == _TLSContentType.HEADER:
+                msg_enum = _TLSContentType
+            elif content_type == _TLSContentType.ALERT:
+                msg_enum = _TLSAlertType
+            else:
+                msg_enum = _TLSMessageType
+            try:
+                msg_type = msg_enum(msg_type)
+            except ValueError:  # unsupported message type: keep the raw int
+                pass
+
+            return callback(conn, direction, version,
+                            content_type, msg_type, data)
+
+        inner.user_function = callback  # the getter returns this attribute
+
+        super(SSLContext, SSLContext)._msg_callback.__set__(self, inner)
+
+    @property
     def protocol(self):
         return _SSLMethod(super().protocol)
 
@@ -576,6 +737,11 @@
         # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
         # root CA certificates for the given purpose. This may fail silently.
         context.load_default_certs(purpose)
+    # OpenSSL 1.1.1 keylog file
+    if hasattr(context, 'keylog_filename'):
+        keylogfile = os.environ.get('SSLKEYLOGFILE')
+        if keylogfile and not sys.flags.ignore_environment:
+            context.keylog_filename = keylogfile
     return context
 
 def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=CERT_NONE,
@@ -617,7 +783,11 @@
         # CERT_OPTIONAL or CERT_REQUIRED. Let's try to load default system
         # root CA certificates for the given purpose. This may fail silently.
         context.load_default_certs(purpose)
-
+    # OpenSSL 1.1.1 keylog file
+    if hasattr(context, 'keylog_filename'):
+        keylogfile = os.environ.get('SSLKEYLOGFILE')
+        if keylogfile and not sys.flags.ignore_environment:
+            context.keylog_filename = keylogfile
     return context
 
 # Used by http.client if no context is explicitly passed.
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index d48d6e5..f368906 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -2,6 +2,7 @@
 
 import sys
 import unittest
+import unittest.mock
 from test import support
 import socket
 import select
@@ -25,6 +26,7 @@
 
 ssl = support.import_module("ssl")
 
+from ssl import TLSVersion, _TLSContentType, _TLSMessageType, _TLSAlertType
 
 PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
 HOST = support.HOST
@@ -4405,6 +4407,170 @@
                 self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))
 
 
+HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
+requires_keylog = unittest.skipUnless(
+    HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback')
+
+class TestSSLDebug(unittest.TestCase):
+
+    def keylog_lines(self, fname=support.TESTFN):
+        with open(fname) as f:
+            return len(list(f))
+
+    @requires_keylog
+    def test_keylog_defaults(self):
+        self.addCleanup(support.unlink, support.TESTFN)
+        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+        self.assertEqual(ctx.keylog_filename, None)
+
+        self.assertFalse(os.path.isfile(support.TESTFN))
+        ctx.keylog_filename = support.TESTFN
+        self.assertEqual(ctx.keylog_filename, support.TESTFN)
+        self.assertTrue(os.path.isfile(support.TESTFN))
+        self.assertEqual(self.keylog_lines(), 1)
+
+        ctx.keylog_filename = None
+        self.assertEqual(ctx.keylog_filename, None)
+
+        with self.assertRaises((IsADirectoryError, PermissionError)):
+            # Windows raises PermissionError
+            ctx.keylog_filename = os.path.dirname(
+                os.path.abspath(support.TESTFN))
+
+        with self.assertRaises(TypeError):
+            ctx.keylog_filename = 1
+
+    @requires_keylog
+    def test_keylog_filename(self):
+        self.addCleanup(support.unlink, support.TESTFN)
+        client_context, server_context, hostname = testing_context()
+
+        client_context.keylog_filename = support.TESTFN
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+        # header, 5 lines for TLS 1.3
+        self.assertEqual(self.keylog_lines(), 6)
+
+        client_context.keylog_filename = None
+        server_context.keylog_filename = support.TESTFN
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+        self.assertGreaterEqual(self.keylog_lines(), 11)
+
+        client_context.keylog_filename = support.TESTFN
+        server_context.keylog_filename = support.TESTFN
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+        self.assertGreaterEqual(self.keylog_lines(), 21)
+
+        client_context.keylog_filename = None
+        server_context.keylog_filename = None
+
+    @requires_keylog
+    @unittest.skipIf(sys.flags.ignore_environment,
+                     "test is not compatible with ignore_environment")
+    def test_keylog_env(self):
+        self.addCleanup(support.unlink, support.TESTFN)
+        with unittest.mock.patch.dict(os.environ):
+            os.environ['SSLKEYLOGFILE'] = support.TESTFN
+            self.assertEqual(os.environ['SSLKEYLOGFILE'], support.TESTFN)
+
+            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+            self.assertEqual(ctx.keylog_filename, None)
+
+            ctx = ssl.create_default_context()
+            self.assertEqual(ctx.keylog_filename, support.TESTFN)
+
+            ctx = ssl._create_stdlib_context()
+            self.assertEqual(ctx.keylog_filename, support.TESTFN)
+
+    def test_msg_callback(self):
+        client_context, server_context, hostname = testing_context()
+
+        def msg_cb(conn, direction, version, content_type, msg_type, data):
+            pass
+
+        self.assertIs(client_context._msg_callback, None)
+        client_context._msg_callback = msg_cb
+        self.assertIs(client_context._msg_callback, msg_cb)
+        with self.assertRaises(TypeError):
+            client_context._msg_callback = object()
+
+    def test_msg_callback_tls12(self):
+        client_context, server_context, hostname = testing_context()
+        client_context.options |= ssl.OP_NO_TLSv1_3
+
+        msg = []
+
+        def msg_cb(conn, direction, version, content_type, msg_type, data):
+            self.assertIsInstance(conn, ssl.SSLSocket)
+            self.assertIsInstance(data, bytes)
+            self.assertIn(direction, {'read', 'write'})
+            msg.append((direction, version, content_type, msg_type))
+
+        client_context._msg_callback = msg_cb
+
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+
+        self.assertEqual(msg, [
+            ("write", TLSVersion.TLSv1, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.CLIENT_HELLO),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.SERVER_HELLO),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.CERTIFICATE),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.SERVER_KEY_EXCHANGE),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.SERVER_DONE),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.CLIENT_KEY_EXCHANGE),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.FINISHED),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.CHANGE_CIPHER_SPEC,
+             _TLSMessageType.CHANGE_CIPHER_SPEC),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.FINISHED),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.NEWSESSION_TICKET),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.FINISHED),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+             _TLSMessageType.CERTIFICATE_STATUS),
+            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+             _TLSMessageType.FINISHED),
+        ])
+
+
 def test_main(verbose=False):
     if support.verbose:
         import warnings
@@ -4440,7 +4606,7 @@
     tests = [
         ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
         SSLObjectTests, SimpleBackgroundTests, ThreadedTests,
-        TestPostHandshakeAuth
+        TestPostHandshakeAuth, TestSSLDebug
     ]
 
     if support.is_resource_enabled('network'):