Add certificate validation. Work initially started by Christoph Kern.
diff --git a/python3/httplib2/__init__.py b/python3/httplib2/__init__.py
index 90ec4d9..af2c3ee 100644
--- a/python3/httplib2/__init__.py
+++ b/python3/httplib2/__init__.py
@@ -92,6 +92,7 @@
 class MalformedHeader(HttpLib2Error): pass
 class RelativeURIError(HttpLib2Error): pass
 class ServerNotFoundError(HttpLib2Error): pass
+class CertificateValidationUnsupportedInPython31(HttpLib2Error): pass
 
 # Open Items:
 # -----------
@@ -118,6 +119,10 @@
 # Which headers are hop-by-hop headers by default
 HOP_BY_HOP = ['connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization', 'te', 'trailers', 'transfer-encoding', 'upgrade']
 
+# Default CA certificates file bundled with httplib2.
+CA_CERTS = os.path.join(
+        os.path.dirname(os.path.abspath(__file__ )), "cacerts.txt")
+
 def _get_end2end_headers(response):
     hopbyhop = list(HOP_BY_HOP)
     hopbyhop.extend([x.strip() for x in response.get('connection', '').split(',')])
@@ -219,10 +224,10 @@
           while authenticate:
               # Break off the scheme at the beginning of the line
               if headername == 'authentication-info':
-                  (auth_scheme, the_rest) = ('digest', authenticate)                
+                  (auth_scheme, the_rest) = ('digest', authenticate)
               else:
                   (auth_scheme, the_rest) = authenticate.split(" ", 1)
-              # Now loop over all the key value pairs that come after the scheme, 
+              # Now loop over all the key value pairs that come after the scheme,
               # being careful not to roll into the next scheme
               match = www_auth.search(the_rest)
               auth_params = {}
@@ -712,43 +717,11 @@
     http://docs.python.org/library/socket.html#socket.setdefaulttimeout
     """
 
-    def __init__(self, host, port=None, strict=None, timeout=None, proxy_info=None):
-        http.client.HTTPConnection.__init__(self, host, port, strict, timeout)
+    def __init__(self, host, port=None, timeout=None, proxy_info=None):
+        http.client.HTTPConnection.__init__(self, host, port=port,
+                                            timeout=timeout)
         self.proxy_info = proxy_info
 
-    def connect(self):
-        """Connect to the host and port specified in __init__."""
-        self.sock = socket.create_connection((self.host,self.port),
-                                             self.timeout)
-        # Mostly verbatim from httplib.py.
-        msg = "getaddrinfo returns an empty list"
-        for res in socket.getaddrinfo(self.host, self.port, 0,
-                socket.SOCK_STREAM):
-            af, socktype, proto, canonname, sa = res
-            try:
-                if self.proxy_info and self.proxy_info.isgood():
-                    self.sock = socks.socksocket(af, socktype, proto)
-                    self.sock.setproxy(*self.proxy_info.astuple())
-                else:
-                    self.sock = socket.socket(af, socktype, proto)
-                    self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-                # Different from httplib: support timeouts.
-                if has_timeout(self.timeout):
-                    self.sock.settimeout(self.timeout)
-                    # End of difference from httplib.
-                if self.debuglevel > 0:
-                    print("connect: (%s, %s)" % (self.host, self.port))
-                self.sock.connect(sa)
-            except socket.error as msg:
-                if self.debuglevel > 0:
-                    print('connect fail:', (self.host, self.port))
-                if self.sock:
-                    self.sock.close()
-                self.sock = None
-                continue
-            break
-        if not self.sock:
-            raise socket.error(msg)
 
 class HTTPSConnectionWithTimeout(http.client.HTTPSConnection):
     """
@@ -761,43 +734,25 @@
     """
 
     def __init__(self, host, port=None, key_file=None, cert_file=None,
-                 strict=None, timeout=None, proxy_info=None):
+                 timeout=None, proxy_info=None,
+                 ca_certs=None, disable_ssl_certificate_validation=False):
         self.proxy_info = proxy_info
+        context = None
+        if ca_certs is None:
+          ca_certs = CA_CERTS
+        if (cert_file or ca_certs) and not disable_ssl_certificate_validation:
+          if not hasattr(ssl, 'SSLContext'):
+            raise CertificateValidationUnsupportedInPython31()
+          context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+          context.verify_mode = ssl.CERT_REQUIRED
+          if cert_file:
+            context.load_cert_chain(cert_file, key_file)
+          if ca_certs:
+            context.load_verify_locations(ca_certs)
         http.client.HTTPSConnection.__init__(self, host, port=port, key_file=key_file,
-                cert_file=cert_file, strict=strict, timeout=timeout)
+                cert_file=cert_file, timeout=timeout, context=context,
+                                             check_hostname=True)
 
-    def connect(self):
-        "Connect to a host on a given (SSL) port."
-
-        msg = "getaddrinfo returns an empty list"
-        self.sock = None
-        for family, socktype, proto, canonname, sockaddr in socket.getaddrinfo(
-            self.host, self.port, 0, socket.SOCK_STREAM):
-            try:
-              if self.proxy_info and self.proxy_info.isgood():
-                  sock = socks.socksocket(family, socktype, proto)
-                  sock.setproxy(*self.proxy_info.astuple())
-              else:
-                  sock = socket.socket(family, socktype, proto)
-                  sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-
-              if has_timeout(self.timeout):
-                  sock.settimeout(self.timeout)
-              sock.connect((self.host, self.port))
-              self.sock =_ssl_wrap_socket(sock, self.key_file, self.cert_file)
-              if self.debuglevel > 0:
-                  print("connect: (%s, %s)" % (self.host, self.port))
-            except socket.error as err:
-              if self.debuglevel > 0:
-                  print('connect fail:', (self.host, self.port))
-              if self.sock:
-                  self.sock.close()
-              self.sock = None
-              msg = err
-              continue
-            break
-        if self.sock is None:
-          raise socket.error(msg)
 
 
 class Http(object):
@@ -813,13 +768,17 @@
 
 and more.
     """
-    def __init__(self, cache=None, timeout=None, proxy_info=None):
+    def __init__(self, cache=None, timeout=None, proxy_info=None,
+        ca_certs=None, disable_ssl_certificate_validation=False):
         """The value of proxy_info is a ProxyInfo instance.
 
 If 'cache' is a string then it is used as a directory name
 for a disk cache. Otherwise it must be an object that supports
 the same interface as FileCache."""
         self.proxy_info = proxy_info
+        self.ca_certs = ca_certs
+        self.disable_ssl_certificate_validation = \
+                disable_ssl_certificate_validation
         # Map domain name to an httplib connection
         self.connections = {}
         # The location of the cache, for now a directory
@@ -884,8 +843,11 @@
     def _conn_request(self, conn, request_uri, method, body, headers):
         for i in range(2):
             try:
+                if conn.sock is None:
+                  conn.connect()
                 conn.request(method, request_uri, body, headers)
             except socket.timeout:
+                conn.close()
                 raise
             except socket.gaierror:
                 conn.close()
@@ -913,6 +875,7 @@
             try:
                 response = conn.getresponse()
             except (socket.error, http.client.HTTPException):
+                conn.close()
                 if i == 0:
                     conn.close()
                     conn.connect()
@@ -1054,11 +1017,26 @@
                 if not connection_type:
                     connection_type = (scheme == 'https') and HTTPSConnectionWithTimeout or HTTPConnectionWithTimeout
                 certs = list(self.certificates.iter(authority))
-                if scheme == 'https' and certs:
-                    conn = self.connections[conn_key] = connection_type(authority, key_file=certs[0][0],
-                        cert_file=certs[0][1], timeout=self.timeout, proxy_info=self.proxy_info)
+                if issubclass(connection_type, HTTPSConnectionWithTimeout):
+                    if certs:
+                        conn = self.connections[conn_key] = connection_type(
+                                authority, key_file=certs[0][0],
+                                cert_file=certs[0][1], timeout=self.timeout,
+                                proxy_info=self.proxy_info,
+                                ca_certs=self.ca_certs,
+                                disable_ssl_certificate_validation=
+                                        self.disable_ssl_certificate_validation)
+                    else:
+                        conn = self.connections[conn_key] = connection_type(
+                                authority, timeout=self.timeout,
+                                proxy_info=self.proxy_info,
+                                ca_certs=self.ca_certs,
+                                disable_ssl_certificate_validation=
+                                        self.disable_ssl_certificate_validation)
                 else:
-                    conn = self.connections[conn_key] = connection_type(authority, timeout=self.timeout, proxy_info=self.proxy_info)
+                    conn = self.connections[conn_key] = connection_type(
+                            authority, timeout=self.timeout,
+                            proxy_info=self.proxy_info)
                 conn.set_debuglevel(debuglevel)
 
             if 'range' not in headers and 'accept-encoding' not in headers:
diff --git a/python3/httplib2test.py b/python3/httplib2test.py
index 4ef3a78..40e087c 100755
--- a/python3/httplib2test.py
+++ b/python3/httplib2test.py
@@ -20,6 +20,7 @@
 import io

 import os

 import socket

+import ssl

 import sys

 import time

 import unittest

@@ -117,6 +118,7 @@
         self.port = port

         self.timeout = timeout

         self.log = ""

+        self.sock = None

 

     def set_debuglevel(self, level):

         pass

@@ -473,8 +475,26 @@
           # Skip on 3.2

           pass

 

+    def testSslCertValidation(self):

+          # Test that we get an IOError when specifying a non-existent CA

+          # certs file.

+          http = httplib2.Http(ca_certs='/nosuchfile')

+          self.assertRaises(IOError,

+                  http.request, "https://www.google.com/", "GET")

 

+          # Test that we get an ssl.SSLError if we try to access

+          # https://www.google.com, using a CA cert file that doesn't contain

+          # the CA Google uses (i.e., simulating a cert that's not signed by a

+          # trusted CA).

+          other_ca_certs = os.path.join(

+                  os.path.dirname(os.path.abspath(httplib2.__file__ )),

+                  "test", "other_cacerts.txt")

+          http = httplib2.Http(ca_certs=other_ca_certs)

+          self.assertRaises(ssl.SSLError,

+            http.request,"https://www.google.com/", "GET")

 

+    def testSniHostnameValidation(self):

+        self.http.request("https://google.com/", method="GET")

 

     def testGet303(self):

         # Do a follow-up GET on a Location: header

@@ -736,20 +756,6 @@
         self.assertEqual(response.status, 500)

         self.assertTrue(response.reason.startswith("Content purported"))

 

-    def testTimeout(self):

-        self.http.force_exception_to_status_code = True 

-        uri = urllib.parse.urljoin(base, "timeout/timeout.cgi")

-        try:

-            import socket

-            socket.setdefaulttimeout(1) 

-        except:

-            # Don't run the test if we can't set the timeout

-            return 

-        (response, content) = self.http.request(uri)

-        self.assertEqual(response.status, 408)

-        self.assertTrue(response.reason.startswith("Request Timeout"))

-        self.assertTrue(content.startswith(b"Request Timeout"))

-

     def testIndividualTimeout(self):

         uri = urllib.parse.urljoin(base, "timeout/timeout.cgi")

         http = httplib2.Http(timeout=1)

@@ -1469,11 +1475,11 @@
         # Degenerate case of no headers

         response = {}

         end2end = httplib2._get_end2end_headers(response)

-        self.assertEquals(0, len(end2end))

+        self.assertEqual(0, len(end2end))

 

         # Degenerate case of connection referrring to a header not passed in 

         response = {'connection': 'content-type'}

         end2end = httplib2._get_end2end_headers(response)

-        self.assertEquals(0, len(end2end))

+        self.assertEqual(0, len(end2end))

 

 unittest.main()