Merged revisions 76309 via svnmerge from
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r76309 | antoine.pitrou | 2009-11-15 18:22:09 +0100 (dim., 15 nov. 2009) | 4 lines

  Issue #2054: ftplib now provides an FTP_TLS class to do secure FTP using
  TLS or SSL.  Patch by Giampaolo Rodola'.
........
diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py
index b83cba0..708ac41 100644
--- a/Lib/test/test_ftplib.py
+++ b/Lib/test/test_ftplib.py
@@ -1,6 +1,7 @@
 """Test script for ftplib module."""
 
-# Modified by Giampaolo Rodola' to test FTP class and IPv6 environment
+# Modified by Giampaolo Rodola' to test the FTP class and the IPv6 and TLS
+# environments
 
 import ftplib
 import threading
@@ -8,6 +9,12 @@
 import asynchat
 import socket
 import io
+import errno
+import os
+try:
+    import ssl
+except ImportError:
+    ssl = None
 
 from unittest import TestCase
 from test import support
@@ -40,6 +47,8 @@
 
 class DummyFTPHandler(asynchat.async_chat):
 
+    dtp_handler = DummyDTPHandler
+
     def __init__(self, conn):
         asynchat.async_chat.__init__(self, conn)
         self.set_terminator(b"\r\n")
@@ -83,7 +92,7 @@
         ip = '%d.%d.%d.%d' %tuple(addr[:4])
         port = (addr[4] * 256) + addr[5]
         s = socket.create_connection((ip, port), timeout=2)
-        self.dtp = DummyDTPHandler(s, baseclass=self)
+        self.dtp = self.dtp_handler(s, baseclass=self)
         self.push('200 active data connection established')
 
     def cmd_pasv(self, arg):
@@ -95,13 +104,13 @@
         ip = ip.replace('.', ','); p1 = port / 256; p2 = port % 256
         self.push('227 entering passive mode (%s,%d,%d)' %(ip, p1, p2))
         conn, addr = sock.accept()
-        self.dtp = DummyDTPHandler(conn, baseclass=self)
+        self.dtp = self.dtp_handler(conn, baseclass=self)
 
     def cmd_eprt(self, arg):
         af, ip, port = arg.split(arg[0])[1:-1]
         port = int(port)
         s = socket.create_connection((ip, port), timeout=2)
-        self.dtp = DummyDTPHandler(s, baseclass=self)
+        self.dtp = self.dtp_handler(s, baseclass=self)
         self.push('200 active data connection established')
 
     def cmd_epsv(self, arg):
@@ -112,7 +121,7 @@
         port = sock.getsockname()[1]
         self.push('229 entering extended passive mode (|||%d|)' %port)
         conn, addr = sock.accept()
-        self.dtp = DummyDTPHandler(conn, baseclass=self)
+        self.dtp = self.dtp_handler(conn, baseclass=self)
 
     def cmd_echo(self, arg):
         # sends back the received string (used by the test suite)
@@ -227,6 +236,143 @@
         raise
 
 
+if ssl is not None:
+
+    CERTFILE = os.path.join(os.path.dirname(__file__), "keycert.pem")
+
+    class SSLConnection(asyncore.dispatcher):
+        """An asyncore.dispatcher subclass supporting TLS/SSL.
+
+        secure_connection() wraps the socket on the server side; the
+        handshake is then completed step by step from the read/write
+        event handlers while _ssl_accepting is set.
+        """
+
+        _ssl_accepting = False
+
+        def secure_connection(self):
+            """Replace the plain socket with a server-side SSL socket."""
+            self.del_channel()
+            # named ssl_sock so as not to shadow the socket module
+            ssl_sock = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
+                                       certfile=CERTFILE, server_side=True,
+                                       do_handshake_on_connect=False,
+                                       ssl_version=ssl.PROTOCOL_SSLv23)
+            self.set_socket(ssl_sock)
+            self._ssl_accepting = True
+
+        def _do_ssl_handshake(self):
+            """Try to make progress on the non-blocking handshake."""
+            try:
+                self.socket.do_handshake()
+            except ssl.SSLError as err:
+                if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
+                                   ssl.SSL_ERROR_WANT_WRITE):
+                    # not finished yet; retried on the next I/O event
+                    return
+                elif err.args[0] == ssl.SSL_ERROR_EOF:
+                    return self.handle_close()
+                raise
+            except socket.error as err:
+                if err.args[0] == errno.ECONNABORTED:
+                    return self.handle_close()
+                # Re-raise anything else, like the SSLError branch does;
+                # swallowing it would leave _ssl_accepting set and retry
+                # the failed handshake on every following event.
+                raise
+            else:
+                self._ssl_accepting = False
+
+        def handle_read_event(self):
+            if self._ssl_accepting:
+                self._do_ssl_handshake()
+            else:
+                super(SSLConnection, self).handle_read_event()
+
+        def handle_write_event(self):
+            if self._ssl_accepting:
+                self._do_ssl_handshake()
+            else:
+                super(SSLConnection, self).handle_write_event()
+
+        def send(self, data):
+            try:
+                return super(SSLConnection, self).send(data)
+            except ssl.SSLError as err:
+                if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN):
+                    return 0
+                raise
+
+        def recv(self, buffer_size):
+            try:
+                return super(SSLConnection, self).recv(buffer_size)
+            except ssl.SSLError as err:
+                if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN):
+                    # connection closed: close and report it as an empty read
+                    self.handle_close()
+                    return b''
+                raise
+
+        def handle_error(self):
+            # re-raise instead of asyncore's default log-and-close
+            raise
+
+        def close(self):
+            try:
+                if isinstance(self.socket, ssl.SSLSocket):
+                    if self.socket._sslobj is not None:
+                        self.socket.unwrap()
+            finally:
+                super(SSLConnection, self).close()
+
+
+    class DummyTLS_DTPHandler(SSLConnection, DummyDTPHandler):
+        """A DummyDTPHandler subclass supporting TLS/SSL."""
+
+        def __init__(self, conn, baseclass):
+            DummyDTPHandler.__init__(self, conn, baseclass)
+            if self.baseclass.secure_data_channel:
+                self.secure_connection()
+
+
+    class DummyTLS_FTPHandler(SSLConnection, DummyFTPHandler):
+        """A DummyFTPHandler subclass supporting TLS/SSL."""
+
+        dtp_handler = DummyTLS_DTPHandler
+
+        def __init__(self, conn):
+            DummyFTPHandler.__init__(self, conn)
+            self.secure_data_channel = False
+
+        def cmd_auth(self, line):
+            """Set up secure control channel."""
+            self.push('234 AUTH TLS successful')
+            self.secure_connection()
+
+        def cmd_pbsz(self, line):
+            """Negotiate size of buffer for secure data transfer.
+            For TLS/SSL the only valid value for the parameter is '0'.
+            Any other value is accepted but ignored.
+            """
+            self.push('200 PBSZ=0 successful.')
+
+        def cmd_prot(self, line):
+            """Setup un/secure data channel."""
+            arg = line.upper()
+            if arg == 'C':
+                self.push('200 Protection set to Clear')
+                self.secure_data_channel = False
+            elif arg == 'P':
+                self.push('200 Protection set to Private')
+                self.secure_data_channel = True
+            else:
+                self.push("502 Unrecognized PROT type (use C or P).")
+
+
+    class DummyTLS_FTPServer(DummyFTPServer):
+        handler = DummyTLS_FTPHandler
+
+
 class TestFTPClass(TestCase):
 
     def setUp(self):
@@ -404,6 +535,83 @@
         retr()
 
 
+class TestTLS_FTPClassMixin(TestFTPClass):
+    """Repeat TestFTPClass tests starting the TLS layer for both control
+    and data connections first.
+    """
+
+    def setUp(self):
+        self.server = DummyTLS_FTPServer((HOST, 0))
+        self.server.start()
+        self.client = ftplib.FTP_TLS(timeout=2)
+        self.client.connect(self.server.host, self.server.port)
+        # enable TLS
+        self.client.auth()
+        self.client.prot_p()
+
+
+class TestTLS_FTPClass(TestCase):
+    """Specific TLS_FTP class tests."""
+
+    def setUp(self):
+        self.server = DummyTLS_FTPServer((HOST, 0))
+        self.server.start()
+        self.client = ftplib.FTP_TLS(timeout=2)
+        self.client.connect(self.server.host, self.server.port)
+
+    def tearDown(self):
+        self.client.close()
+        self.server.stop()
+
+    def test_control_connection(self):
+        self.assertNotIsInstance(self.client.sock, ssl.SSLSocket)
+        self.client.auth()
+        self.assertIsInstance(self.client.sock, ssl.SSLSocket)
+
+    def test_data_connection(self):
+        # clear text
+        sock = self.client.transfercmd('list')
+        self.assertNotIsInstance(sock, ssl.SSLSocket)
+        sock.close()
+        self.client.voidresp()
+
+        # secured, after PROT P
+        self.client.prot_p()
+        sock = self.client.transfercmd('list')
+        self.assertIsInstance(sock, ssl.SSLSocket)
+        sock.close()
+        self.client.voidresp()
+
+        # PROT C is issued, the connection must be in cleartext again
+        self.client.prot_c()
+        sock = self.client.transfercmd('list')
+        self.assertNotIsInstance(sock, ssl.SSLSocket)
+        sock.close()
+        self.client.voidresp()
+
+    def test_login(self):
+        # login() is supposed to implicitly secure the control connection
+        self.assertNotIsInstance(self.client.sock, ssl.SSLSocket)
+        self.client.login()
+        self.assertIsInstance(self.client.sock, ssl.SSLSocket)
+        # make sure that AUTH TLS doesn't get issued again
+        self.client.login()
+
+    def test_auth_issued_twice(self):
+        self.client.auth()
+        self.assertRaises(ValueError, self.client.auth)
+
+    def test_auth_ssl(self):
+        # restore the version the client actually started with
+        orig_version = self.client.ssl_version
+        try:
+            self.client.ssl_version = ssl.PROTOCOL_SSLv3
+            self.client.auth()
+            self.assertRaises(ValueError, self.client.auth)
+        finally:
+            self.client.ssl_version = orig_version
+
+
 class TestTimeouts(TestCase):
 
     def setUp(self):
@@ -505,6 +711,10 @@
             pass
         else:
             tests.append(TestIPv6Environment)
+
+    if ssl is not None:
+        tests.extend([TestTLS_FTPClassMixin, TestTLS_FTPClass])
+
     thread_info = support.threading_setup()
     try:
         support.run_unittest(*tests)