Various SSL fixes (issues 1251, 3162, 3212); test_ssl gains non-blocking handshake, wrapped accepting socket, and asyncore echo server tests.
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index eb4d00c..d786154 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -3,7 +3,9 @@
 import sys
 import unittest
 from test import test_support
+import asyncore
 import socket
+import select
 import errno
 import subprocess
 import time
@@ -97,8 +99,7 @@
         if (d1 != d2):
             raise test_support.TestFailed("PEM-to-DER or DER-to-PEM translation failed")
 
-
-class NetworkTests(unittest.TestCase):
+class NetworkedTests(unittest.TestCase):
 
     def testConnect(self):
         s = ssl.wrap_socket(socket.socket(socket.AF_INET),
@@ -130,6 +131,39 @@
         finally:
             s.close()
 
+
+    def testNonBlockingHandshake(self):
+        # Complete the SSL handshake by hand on a non-blocking socket:
+        # do_handshake() is retried for as long as it raises SSLError with
+        # SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE, waiting in select()
+        # for the socket to become readable / writable in between.
+        s = socket.socket(socket.AF_INET)
+        try:
+            s.connect(("svn.python.org", 443))
+            s.setblocking(False)
+            s = ssl.wrap_socket(s,
+                                cert_reqs=ssl.CERT_NONE,
+                                do_handshake_on_connect=False)
+            count = 0
+            while True:
+                try:
+                    count += 1
+                    s.do_handshake()
+                    break
+                except ssl.SSLError, err:
+                    if err.args[0] == ssl.SSL_ERROR_WANT_READ:
+                        select.select([s], [], [])
+                    elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
+                        select.select([], [s], [])
+                    else:
+                        raise
+        finally:
+            # Close on every path so that a failed connect or handshake
+            # does not leak the socket.
+            s.close()
+        if test_support.verbose:
+            sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
+
     def testFetchServerCert(self):
 
         pem = ssl.get_server_certificate(("svn.python.org", 443))
@@ -176,6 +202,24 @@
                 threading.Thread.__init__(self)
                 self.setDaemon(True)
 
+            def show_conn_details(self):
+                # Query the peer certificate (only when the server requires
+                # one) and the negotiated cipher; echo them to stdout when
+                # both test_support.verbose and the server's chatty flag are set.
+                conn = self.sslconn
+                noisy = test_support.verbose and self.server.chatty
+                if self.server.certreqs == ssl.CERT_REQUIRED:
+                    # getpeercert() is called in both forms even when quiet.
+                    cert = conn.getpeercert()
+                    if noisy:
+                        sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
+                    cert_binary = conn.getpeercert(True)
+                    if noisy:
+                        sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
+                cipher = conn.cipher()
+                if noisy:
+                    sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
+
             def wrap_conn (self):
                 try:
                     self.sslconn = ssl.wrap_socket(self.sock, server_side=True,
@@ -187,6 +225,7 @@
                     if self.server.chatty:
                         handle_error("\n server:  bad connection attempt from " +
                                      str(self.sock.getpeername()) + ":\n")
+                    self.close()
                     if not self.server.expect_bad_connects:
                         # here, we want to stop the server, because this shouldn't
                         # happen in the context of our test case
@@ -197,16 +236,6 @@
                     return False
 
                 else:
-                    if self.server.certreqs == ssl.CERT_REQUIRED:
-                        cert = self.sslconn.getpeercert()
-                        if test_support.verbose and self.server.chatty:
-                            sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
-                        cert_binary = self.sslconn.getpeercert(True)
-                        if test_support.verbose and self.server.chatty:
-                            sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
-                    cipher = self.sslconn.cipher()
-                    if test_support.verbose and self.server.chatty:
-                        sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
                     return True
 
             def read(self):
@@ -225,13 +254,16 @@
                 if self.sslconn:
                     self.sslconn.close()
                 else:
-                    self.sock.close()
+                    self.sock._sock.close()
 
             def run (self):
                 self.running = True
                 if not self.server.starttls_server:
-                    if not self.wrap_conn():
+                    if isinstance(self.sock, ssl.SSLSocket):
+                        self.sslconn = self.sock
+                    elif not self.wrap_conn():
                         return
+                    self.show_conn_details()
                 while self.running:
                     try:
                         msg = self.read()
@@ -270,7 +302,9 @@
 
         def __init__(self, certificate, ssl_version=None,
                      certreqs=None, cacerts=None, expect_bad_connects=False,
-                     chatty=True, connectionchatty=False, starttls_server=False):
+                     chatty=True, connectionchatty=False, starttls_server=False,
+                     wrap_accepting_socket=False):
+
             if ssl_version is None:
                 ssl_version = ssl.PROTOCOL_TLSv1
             if certreqs is None:
@@ -284,8 +318,16 @@
             self.connectionchatty = connectionchatty
             self.starttls_server = starttls_server
             self.sock = socket.socket()
-            self.port = test_support.bind_port(self.sock)
             self.flag = None
+            if wrap_accepting_socket:
+                self.sock = ssl.wrap_socket(self.sock, server_side=True,
+                                            certfile=self.certificate,
+                                            cert_reqs = self.certreqs,
+                                            ca_certs = self.cacerts,
+                                            ssl_version = self.protocol)
+                if test_support.verbose and self.chatty:
+                    sys.stdout.write(' server:  wrapped server socket as %s\n' % str(self.sock))
+            self.port = test_support.bind_port(self.sock)
             self.active = False
             threading.Thread.__init__(self)
             self.setDaemon(False)
@@ -316,13 +358,111 @@
                 except:
                     if self.chatty:
                         handle_error("Test server failure:\n")
+            self.sock.close()
 
         def stop (self):
             self.active = False
-            self.sock.close()
 
+    class AsyncoreEchoServer(threading.Thread):
 
-    class AsyncoreHTTPSServer(threading.Thread):
+        # Threaded wrapper around an asyncore-driven SSL echo server:
+        # whatever a client sends is sent back lower-cased.
+
+        class EchoServer (asyncore.dispatcher):
+
+            class ConnectionHandler (asyncore.dispatcher_with_send):
+
+                def __init__(self, conn, certfile):
+                    # Register the accepted connection with asyncore, then
+                    # swap in a server-side SSL wrapper around it.
+                    asyncore.dispatcher_with_send.__init__(self, conn)
+                    self.socket = ssl.wrap_socket(conn, server_side=True,
+                                                  certfile=certfile,
+                                                  do_handshake_on_connect=True)
+
+                def readable(self):
+                    # Consume whatever the SSL object already holds
+                    # (pending() > 0); the channel always reports readable.
+                    if isinstance(self.socket, ssl.SSLSocket):
+                        while self.socket.pending() > 0:
+                            self.handle_read_event()
+                    return True
+
+                def handle_read(self):
+                    data = self.recv(1024)
+                    # An empty read means the peer closed the connection;
+                    # there is nothing to echo back in that case.
+                    if data:
+                        self.send(data.lower())
+
+                def handle_close(self):
+                    # Remove the channel from asyncore's map and close the
+                    # socket; otherwise a finished connection stays
+                    # registered and keeps being polled.
+                    self.close()
+                    if test_support.verbose:
+                        sys.stdout.write(" server:  closed connection %s\n" % self.socket)
+
+                def handle_error(self):
+                    # Re-raise the exception being handled so that it
+                    # propagates out of asyncore.loop().
+                    raise
+
+            def __init__(self, certfile):
+                # Bind via test_support.bind_port() and remember the port
+                # it returns, then start listening.
+                self.certfile = certfile
+                asyncore.dispatcher.__init__(self)
+                self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+                self.port = test_support.bind_port(self.socket)
+                self.listen(5)
+
+            def handle_accept(self):
+                # Hand each accepted connection to an SSL ConnectionHandler.
+                sock_obj, addr = self.accept()
+                if test_support.verbose:
+                    sys.stdout.write(" server:  new connection from %s:%s\n" %addr)
+                self.ConnectionHandler(sock_obj, self.certfile)
+
+            def handle_error(self):
+                raise
+
+        def __init__(self, certfile):
+            self.flag = None
+            self.active = False
+            self.server = self.EchoServer(certfile)
+            self.port = self.server.port
+            threading.Thread.__init__(self)
+            self.setDaemon(True)
+
+        def __str__(self):
+            return "<%s %s>" % (self.__class__.__name__, self.server)
+
+        def start (self, flag=None):
+            # flag, if given, is set by run() once the thread is running.
+            self.flag = flag
+            threading.Thread.start(self)
+
+        def run (self):
+            self.active = True
+            if self.flag:
+                self.flag.set()
+            while self.active:
+                try:
+                    # Poll once per pass (count=1) so self.active is
+                    # re-checked regularly; without a count, loop() only
+                    # returns once every channel has been closed.
+                    asyncore.loop(1, count=1)
+                except:
+                    # Deliberately broad: the handle_error() overrides
+                    # re-raise, and that must not end the server thread.
+                    pass
+
+        def stop (self):
+            self.active = False
+            self.server.close()
+
+    class SocketServerHTTPSServer(threading.Thread):
 
         class HTTPSServer(HTTPServer):
 
@@ -335,6 +450,12 @@
                 self.active_lock = threading.Lock()
                 self.allow_reuse_address = True
 
+            def __str__(self):
+                # Rendered as <ClassName server_name:server_port>.
+                fields = (self.__class__.__name__,
+                          self.server_name, self.server_port)
+                return '<%s %s:%s>' % fields
+
             def get_request (self):
                 # override this to wrap socket with SSL
                 sock, addr = self.socket.accept()
@@ -421,8 +542,10 @@
                 # we override this to suppress logging unless "verbose"
 
                 if test_support.verbose:
-                    sys.stdout.write(" server (%s, %d, %s):\n   [%s] %s\n" %
-                                     (self.server.server_name,
+                    # server_address is a (host, port) pair; only the host
+                    # belongs in the %s slot, the port follows as %d.
+                    sys.stdout.write(" server (%s:%d %s):\n   [%s] %s\n" %
+                                     (self.server.server_address[0],
                                       self.server.server_port,
                                       self.request.cipher(),
                                       self.log_date_time_string(),
@@ -440,9 +561,7 @@
             self.setDaemon(True)
 
         def __str__(self):
-            return '<%s %s:%d>' % (self.__class__.__name__,
-                                   self.server.server_name,
-                                   self.server.server_port)
+            return "<%s %s>" % (self.__class__.__name__, self.server)
 
         def start (self, flag=None):
             self.flag = flag
@@ -487,14 +606,16 @@
 
     def serverParamsTest (certfile, protocol, certreqs, cacertsfile,
                           client_certfile, client_protocol=None, indata="FOO\n",
-                          chatty=True, connectionchatty=False):
+                          chatty=True, connectionchatty=False,
+                          wrap_accepting_socket=False):
 
         server = ThreadedEchoServer(certfile,
                                     certreqs=certreqs,
                                     ssl_version=protocol,
                                     cacerts=cacertsfile,
                                     chatty=chatty,
-                                    connectionchatty=connectionchatty)
+                                    connectionchatty=connectionchatty,
+                                    wrap_accepting_socket=wrap_accepting_socket)
         flag = threading.Event()
         server.start(flag)
         # wait for it to start
@@ -572,7 +693,7 @@
                        ssl.get_protocol_name(server_protocol)))
 
 
-    class ConnectedTests(unittest.TestCase):
+    class ThreadedTests(unittest.TestCase):
 
         def testRudeShutdown(self):
 
@@ -600,7 +721,7 @@
                 listener_gone.wait()
                 try:
                     ssl_sock = ssl.wrap_socket(s)
-                except socket.sslerror:
+                except IOError:
                     pass
                 else:
                     raise test_support.TestFailed(
@@ -680,6 +801,10 @@
         def testMalformedCert(self):
             badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
                                      "badcert.pem"))
+        def testWrongCert(self):
+            # Same badCertTest() check as its neighbours, using wrongcert.pem.
+            certdir = os.path.dirname(__file__) or os.curdir
+            badCertTest(os.path.join(certdir, "wrongcert.pem"))
         def testMalformedKey(self):
             badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
                                      "badkey.pem"))
@@ -796,9 +920,9 @@
                 server.stop()
                 server.join()
 
-        def testAsyncore(self):
+        def testSocketServer(self):
 
-            server = AsyncoreHTTPSServer(CERTFILE)
+            server = SocketServerHTTPSServer(CERTFILE)
             flag = threading.Event()
             server.start(flag)
             # wait for it to start
@@ -810,8 +934,8 @@
                 d1 = open(CERTFILE, 'rb').read()
                 d2 = ''
                 # now fetch the same data from the HTTPS server
-                url = 'https://%s:%d/%s' % (
-                    HOST, server.port, os.path.split(CERTFILE)[1])
+                url = 'https://127.0.0.1:%d/%s' % (
+                    server.port, os.path.split(CERTFILE)[1])
                 f = urllib.urlopen(url)
                 dlen = f.info().getheader("content-length")
                 if dlen and (int(dlen) > 0):
@@ -834,6 +958,68 @@
                 server.stop()
                 server.join()
 
+        def testWrappedAccept (self):
+
+            # Run serverParamsTest() against a server whose listening
+            # socket is itself SSL-wrapped (wrap_accepting_socket=True).
+            if test_support.verbose:
+                sys.stdout.write("\n")
+            serverParamsTest(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED,
+                             CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23,
+                             chatty=True, connectionchatty=True,
+                             wrap_accepting_socket=True)
+
+
+        def testAsyncoreServer (self):
+
+            # Send one mixed-case line to an AsyncoreEchoServer and check
+            # that it comes back lower-cased.
+            indata = "TEST MESSAGE of mixed case\n"
+
+            if test_support.verbose:
+                sys.stdout.write("\n")
+            server = AsyncoreEchoServer(CERTFILE)
+            flag = threading.Event()
+            server.start(flag)
+            # wait for it to start
+            flag.wait()
+            # try to connect
+            s = None
+            try:
+                try:
+                    s = ssl.wrap_socket(socket.socket())
+                    s.connect(('127.0.0.1', server.port))
+                except ssl.SSLError, x:
+                    raise test_support.TestFailed("Unexpected SSL error:  " + str(x))
+                except Exception, x:
+                    raise test_support.TestFailed("Unexpected exception:  " + str(x))
+                else:
+                    if test_support.verbose:
+                        sys.stdout.write(
+                            " client:  sending %s...\n" % (repr(indata)))
+                    s.write(indata)
+                    outdata = s.read()
+                    if test_support.verbose:
+                        sys.stdout.write(" client:  read %s\n" % repr(outdata))
+                    if outdata != indata.lower():
+                        raise test_support.TestFailed(
+                            "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
+                            % (outdata[:min(len(outdata),20)], len(outdata),
+                               indata[:min(len(indata),20)].lower(), len(indata)))
+                    s.write("over\n")
+                    if test_support.verbose:
+                        sys.stdout.write(" client:  closing connection.\n")
+            finally:
+                try:
+                    # Close the client on every path, including a failed
+                    # data check, so the socket is never left open.
+                    if s is not None:
+                        s.close()
+                finally:
+                    server.stop()
+                    # wait for server thread to end
+                    server.join()
+
 
 def test_main(verbose=False):
     if skip_expected:
@@ -850,15 +1026,19 @@
         not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)):
         raise test_support.TestFailed("Can't read certificate files!")
 
+    TESTPORT = test_support.find_unused_port()
+    if not TESTPORT:
+        raise test_support.TestFailed("Can't find open port to test servers on!")
+
     tests = [BasicTests]
 
     if test_support.is_resource_enabled('network'):
-        tests.append(NetworkTests)
+        tests.append(NetworkedTests)
 
     if _have_threads:
         thread_info = test_support.threading_setup()
         if thread_info and test_support.is_resource_enabled('network'):
-            tests.append(ConnectedTests)
+            tests.append(ThreadedTests)
 
     test_support.run_unittest(*tests)