Merged revisions 80151 via svnmerge from
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r80151 | antoine.pitrou | 2010-04-17 19:10:38 +0200 (sam., 17 avril 2010) | 4 lines

  Issue #8322: Add a *ciphers* argument to SSL sockets, so as to change the
  available cipher list.  Helps fix test_ssl with OpenSSL 1.0.0.
........
diff --git a/Lib/ssl.py b/Lib/ssl.py
index 75f542d..a42643f 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -94,7 +94,7 @@
                  ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                  do_handshake_on_connect=True,
                  family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
-                 suppress_ragged_eofs=True):
+                 suppress_ragged_eofs=True, ciphers=None):
 
         if sock is not None:
             socket.__init__(self,
@@ -123,7 +123,8 @@
             try:
                 self._sslobj = _ssl.sslwrap(self, server_side,
                                             keyfile, certfile,
-                                            cert_reqs, ssl_version, ca_certs)
+                                            cert_reqs, ssl_version, ca_certs,
+                                            ciphers)
                 if do_handshake_on_connect:
                     timeout = self.gettimeout()
                     if timeout == 0.0:
@@ -140,6 +141,7 @@
         self.cert_reqs = cert_reqs
         self.ssl_version = ssl_version
         self.ca_certs = ca_certs
+        self.ciphers = ciphers
         self.do_handshake_on_connect = do_handshake_on_connect
         self.suppress_ragged_eofs = suppress_ragged_eofs
 
@@ -325,7 +327,7 @@
         socket.connect(self, addr)
         self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
                                     self.cert_reqs, self.ssl_version,
-                                    self.ca_certs)
+                                    self.ca_certs, self.ciphers)
         try:
             if self.do_handshake_on_connect:
                 self.do_handshake()
@@ -345,6 +347,7 @@
                           cert_reqs=self.cert_reqs,
                           ssl_version=self.ssl_version,
                           ca_certs=self.ca_certs,
+                          ciphers=self.ciphers,
                           do_handshake_on_connect=
                               self.do_handshake_on_connect),
                 addr)
@@ -358,13 +361,14 @@
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                 do_handshake_on_connect=True,
-                suppress_ragged_eofs=True):
+                suppress_ragged_eofs=True, ciphers=None):
 
     return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
                      server_side=server_side, cert_reqs=cert_reqs,
                      ssl_version=ssl_version, ca_certs=ca_certs,
                      do_handshake_on_connect=do_handshake_on_connect,
-                     suppress_ragged_eofs=suppress_ragged_eofs)
+                     suppress_ragged_eofs=suppress_ragged_eofs,
+                     ciphers=ciphers)
 
 # some utility functions
 
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 1804fcd..c1c59b5 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -121,6 +121,23 @@
         self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
                         (s, t))
 
+    def test_ciphers(self):
+        if not support.is_resource_enabled('network'):
+            return
+        remote = ("svn.python.org", 443)
+        s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+                            cert_reqs=ssl.CERT_NONE, ciphers="ALL")
+        s.connect(remote)
+        s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+                            cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT")
+        s.connect(remote)
+        # Error checking occurs when connecting, because the SSL context
+        # isn't created until then.
+        s = ssl.wrap_socket(socket.socket(socket.AF_INET),
+                            cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
+        with self.assertRaisesRegexp(ssl.SSLError, "No cipher can be selected"):
+            s.connect(remote)
+
 
 class NetworkedTests(unittest.TestCase):
 
@@ -234,7 +251,8 @@
                                                    certfile=self.server.certificate,
                                                    ssl_version=self.server.protocol,
                                                    ca_certs=self.server.cacerts,
-                                                   cert_reqs=self.server.certreqs)
+                                                   cert_reqs=self.server.certreqs,
+                                                   ciphers=self.server.ciphers)
                 except:
                     if self.server.chatty:
                         handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
@@ -333,7 +351,8 @@
 
         def __init__(self, certificate, ssl_version=None,
                      certreqs=None, cacerts=None, expect_bad_connects=False,
-                     chatty=True, connectionchatty=False, starttls_server=False):
+                     chatty=True, connectionchatty=False, starttls_server=False,
+                     ciphers=None):
             if ssl_version is None:
                 ssl_version = ssl.PROTOCOL_TLSv1
             if certreqs is None:
@@ -342,6 +361,7 @@
             self.protocol = ssl_version
             self.certreqs = certreqs
             self.cacerts = cacerts
+            self.ciphers = ciphers
             self.expect_bad_connects = expect_bad_connects
             self.chatty = chatty
             self.connectionchatty = connectionchatty
@@ -648,12 +668,13 @@
     def serverParamsTest (certfile, protocol, certreqs, cacertsfile,
                           client_certfile, client_protocol=None,
                           indata="FOO\n",
-                          chatty=False, connectionchatty=False):
+                          ciphers=None, chatty=False, connectionchatty=False):
 
         server = ThreadedEchoServer(certfile,
                                     certreqs=certreqs,
                                     ssl_version=protocol,
                                     cacerts=cacertsfile,
+                                    ciphers=ciphers,
                                     chatty=chatty,
                                     connectionchatty=False)
         flag = threading.Event()
@@ -669,6 +690,7 @@
                                 certfile=client_certfile,
                                 ca_certs=cacertsfile,
                                 cert_reqs=certreqs,
+                                ciphers=ciphers,
                                 ssl_version=client_protocol)
             s.connect((HOST, server.port))
         except ssl.SSLError as x:
@@ -723,8 +745,12 @@
                               ssl.get_protocol_name(server_protocol),
                               certtype))
         try:
+            # NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client
+            # will send an SSLv3 hello (rather than SSLv2) starting from
+            # OpenSSL 1.0.0 (see issue #8322).
             serverParamsTest(CERTFILE, server_protocol, certsreqs,
                              CERTFILE, CERTFILE, client_protocol,
+                             ciphers="ALL",
                              chatty=False, connectionchatty=False)
         except support.TestFailed:
             if expectedToWork: