Added a timeout parameter to smtplib's SMTP and SMTP_SSL classes. Also
created the test_smtplib.py file, with a basic connection test and tests
for the timeout behaviour. The docs are updated too.
diff --git a/Lib/smtplib.py b/Lib/smtplib.py
index 4618671..fc1df51 100755
--- a/Lib/smtplib.py
+++ b/Lib/smtplib.py
@@ -230,7 +230,7 @@
ehlo_resp = None
does_esmtp = 0
- def __init__(self, host = '', port = 0, local_hostname = None):
+ def __init__(self, host='', port=0, local_hostname=None, timeout=None):
"""Initialize a new instance.
If specified, `host' is the name of the remote host to which to
@@ -241,6 +241,7 @@
the local hostname is found using socket.getfqdn().
"""
+ self.timeout = timeout
self.esmtp_features = {}
self.default_port = SMTP_PORT
if host:
@@ -274,12 +275,11 @@
"""
self.debuglevel = debuglevel
- def _get_socket(self,af, socktype, proto,sa):
+ def _get_socket(self, host, port, timeout):
# This makes it simpler for SMTP_SSL to use the SMTP connect code
# and just alter the socket connection bit.
- self.sock = socket.socket(af, socktype, proto)
if self.debuglevel > 0: print>>stderr, 'connect:', (host, port)
- self.sock.connect(sa)
+ return socket.create_connection((host, port), timeout)
def connect(self, host='localhost', port = 0):
"""Connect to a host on a given port.
@@ -301,21 +301,7 @@
raise socket.error, "nonnumeric port"
if not port: port = self.default_port
if self.debuglevel > 0: print>>stderr, 'connect:', (host, port)
- msg = "getaddrinfo returns an empty list"
- self.sock = None
- for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
- af, socktype, proto, canonname, sa = res
- try:
- self._get_socket(af,socktype,proto,sa)
- except socket.error, msg:
- if self.debuglevel > 0: print>>stderr, 'connect fail:', msg
- if self.sock:
- self.sock.close()
- self.sock = None
- continue
- break
- if not self.sock:
- raise socket.error, msg
+ self.sock = self._get_socket(host, port, self.timeout)
(code, msg) = self.getreply()
if self.debuglevel > 0: print>>stderr, "connect:", msg
return (code, msg)
@@ -732,17 +718,22 @@
are also optional - they can contain a PEM formatted private key and
certificate chain file for the SSL connection.
"""
- def __init__(self, host = '', port = 0, local_hostname = None,
- keyfile = None, certfile = None):
+ def __init__(self, host='', port=0, local_hostname=None,
+ keyfile=None, certfile=None, timeout=None):
self.keyfile = keyfile
self.certfile = certfile
- SMTP.__init__(self,host,port,local_hostname)
+ SMTP.__init__(self, host, port, local_hostname, timeout)
+ # NOTE(review): when host is given, SMTP.__init__ presumably connects
+ # before the next line runs, i.e. with SMTP's default_port rather than
+ # SMTP_SSL_PORT -- confirm, and set the SSL default before connecting.
self.default_port = SMTP_SSL_PORT
- def _get_socket(self,af, socktype, proto,sa):
- self.sock = socket.socket(af, socktype, proto)
+ def _get_socket(self, host, port, timeout):
if self.debuglevel > 0: print>>stderr, 'connect:', (host, port)
- self.sock.connect(sa)
+ self.sock = socket.create_connection((host, port), timeout)
sslobj = socket.ssl(self.sock, self.keyfile, self.certfile)
self.sock = SSLFakeSocket(self.sock, sslobj)
self.file = SSLFakeFile(sslobj)
+ # connect() assigns this method's result to self.sock, so the wrapped
+ # socket must be returned; falling off the end would store None.
+ return self.sock
diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py
new file mode 100644
index 0000000..09c1d96
--- /dev/null
+++ b/Lib/test/test_smtplib.py
@@ -0,0 +1,85 @@
+import socket
+import threading
+import smtplib
+
+from unittest import TestCase
+from test import test_support
+
+
+def server(evt, serv):
+    """Accept one connection on the listening socket `serv` and greet it.
+
+    Gives up quietly if nobody connects within serv's timeout.  `evt` is
+    set in every case, so a tearDown() waiting on it cannot hang.
+    """
+    try:
+        conn, addr = serv.accept()
+    except socket.timeout:
+        pass
+    else:
+        conn.send("220 Hola mundo\n")
+        conn.close()
+    finally:
+        serv.close()
+        evt.set()
+
+
+class GeneralTests(TestCase):
+    """Connection and timeout tests against a one-shot local server.
+
+    Clients are closed with smtp.close() rather than smtp.sock.close(),
+    so the file object smtplib reads replies through is released along
+    with the socket.
+    """
+
+    def setUp(self):
+        self.evt = threading.Event()
+        # Bind and listen here, in the main thread, on a port picked by the
+        # OS: the server is then accepting before any test connects (no
+        # sleep-based race), and no hard-coded port can collide.
+        serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        serv.settimeout(3)
+        serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        serv.bind(("localhost", 0))
+        serv.listen(5)
+        self.port = serv.getsockname()[1]
+        threading.Thread(target=server, args=(self.evt, serv)).start()
+
+    def tearDown(self):
+        self.evt.wait()
+
+    def testBasic(self):
+        # connects
+        smtp = smtplib.SMTP("localhost", self.port)
+        smtp.close()
+
+    def testTimeoutDefault(self):
+        # default
+        smtp = smtplib.SMTP("localhost", self.port)
+        self.assertTrue(smtp.sock.gettimeout() is None)
+        smtp.close()
+
+    def testTimeoutValue(self):
+        # a value
+        smtp = smtplib.SMTP("localhost", self.port, timeout=30)
+        self.assertEqual(smtp.sock.gettimeout(), 30)
+        smtp.close()
+
+    def testTimeoutNone(self):
+        # None, having other default
+        previous = socket.getdefaulttimeout()
+        socket.setdefaulttimeout(30)
+        try:
+            smtp = smtplib.SMTP("localhost", self.port, timeout=None)
+        finally:
+            socket.setdefaulttimeout(previous)
+        self.assertEqual(smtp.sock.gettimeout(), 30)
+        smtp.close()
+
+
+def test_main(verbose=None):
+    test_support.run_unittest(GeneralTests)
+
+
+if __name__ == '__main__':
+    test_main()