Fixed 1529840 Add domain with add_credentials, and now allow adding SLL credentials
diff --git a/httplib2/__init__.py b/httplib2/__init__.py
index e74c476..4cb19f3 100644
--- a/httplib2/__init__.py
+++ b/httplib2/__init__.py
@@ -616,6 +616,26 @@
if os.path.exists(cacheFullPath):
os.remove(cacheFullPath)
+class Credentials:
+ def __init__(self):
+ self.credentials = []
+
+ def add(self, name, password, domain=""):
+ self.credentials.append((domain.lower(), name, password))
+
+ def clear(self):
+ self.credentials = []
+
+ def iter(self, domain):
+ for (cdomain, name, password) in self.credentials:
+ if cdomain == "" or domain == cdomain:
+ yield (name, password)
+
+class KeyCerts(Credentials):
+ """Identical to Credentials except that
+ name/password are mapped to key/cert."""
+ pass
+
class Http:
"""An HTTP client that handles all
methods, caching, ETags, compression,
@@ -631,8 +651,11 @@
else:
self.cache = cache
- # tuples of name, password
- self.credentials = []
+ # Name/password
+ self.credentials = Credentials()
+
+ # Key/cert
+ self.certificates = KeyCerts()
# authorization objects
self.authorizations = []
@@ -646,20 +669,25 @@
that can be applied to requests.
"""
challenges = _parse_www_authenticate(response, 'www-authenticate')
- for cred in self.credentials:
+ for cred in self.credentials.iter(host):
for scheme in AUTH_SCHEME_ORDER:
if challenges.has_key(scheme):
yield AUTH_SCHEME_CLASSES[scheme](cred, host, request_uri, headers, response, content, self)
- def add_credentials(self, name, password):
+ def add_credentials(self, name, password, domain=""):
"""Add a name and password that will be used
any time a request requires authentication."""
- self.credentials.append((name, password))
+ self.credentials.add(name, password, domain)
+
+ def add_certificate(self, key, cert, domain):
+ """Add a key and cert that will be used
+ any time a request requires authentication."""
+ self.certificates.add(key, cert, domain)
def clear_credentials(self):
"""Remove all the names and passwords
that are used for authentication"""
- self.credentials = []
+ self.credentials.clear()
self.authorizations = []
def _conn_request(self, conn, request_uri, method, body, headers):
@@ -784,12 +812,17 @@
(scheme, authority, request_uri, defrag_uri) = urlnorm(uri)
- if not self.connections.has_key(scheme+":"+authority):
- connection_type = (scheme == 'https') and httplib.HTTPSConnection or httplib.HTTPConnection
- conn = self.connections[scheme+":"+authority] = connection_type(authority)
- conn.set_debuglevel(debuglevel)
+ conn_key = scheme+":"+authority
+ if conn_key in self.connections:
+ conn = self.connections[conn_key]
else:
- conn = self.connections[scheme+":"+authority]
+ connection_type = (scheme == 'https') and httplib.HTTPSConnection or httplib.HTTPConnection
+ certs = list(self.certificates.iter(authority))
+ if scheme == 'https' and certs:
+ conn = self.connections[conn_key] = connection_type(authority, key_file=certs[0][0], cert_file=certs[0][1])
+ else:
+ conn = self.connections[conn_key] = connection_type(authority)
+ conn.set_debuglevel(debuglevel)
if method in ["GET", "HEAD"] and 'range' not in headers:
headers['accept-encoding'] = 'compress, gzip'
diff --git a/httplib2test.py b/httplib2test.py
index 4839559..d8c1859 100755
--- a/httplib2test.py
+++ b/httplib2test.py
@@ -15,7 +15,14 @@
__version__ = "0.1 ($Rev: 118 $)"
-import sys, unittest, httplib2, os, urlparse, time, base64
+import sys
+import unittest
+import httplib
+import httplib2
+import os
+import urlparse
+import time
+import base64
# Python 2.3 support
@@ -28,6 +35,26 @@
#base = 'http://localhost/projects/httplib2/test/'
cacheDirName = ".cache"
+
+class CredentialsTest(unittest.TestCase):
+ def test(self):
+ c = httplib2.Credentials()
+ c.add("joe", "password")
+ self.assertEqual(("joe", "password"), list(c.iter("bitworking.org"))[0])
+ self.assertEqual(("joe", "password"), list(c.iter(""))[0])
+ c.add("fred", "password2", "wellformedweb.org")
+ self.assertEqual(("joe", "password"), list(c.iter("bitworking.org"))[0])
+ self.assertEqual(1, len(list(c.iter("bitworking.org"))))
+ self.assertEqual(2, len(list(c.iter("wellformedweb.org"))))
+ self.assertTrue(("fred", "password2") in list(c.iter("wellformedweb.org")))
+ c.clear()
+ self.assertEqual(0, len(list(c.iter("bitworking.org"))))
+ c.add("fred", "password2", "wellformedweb.org")
+ self.assertTrue(("fred", "password2") in list(c.iter("wellformedweb.org")))
+ self.assertEqual(0, len(list(c.iter("bitworking.org"))))
+ self.assertEqual(0, len(list(c.iter(""))))
+
+
class ParserTest(unittest.TestCase):
def testFromStd66(self):
self.assertEqual( ('http', 'example.com', '', None, None ), httplib2.parse_uri("http://example.com"))
@@ -275,6 +302,31 @@
self.assertEqual(200, response.status)
self.assertNotEqual(None, response.previous)
+
+ def testGetViaHttpsKeyCert(self):
+ """At this point I can only test
+ that the key and cert files are passed in
+ correctly to httplib. It would be nice to have
+ a real https endpoint to test against.
+ """
+ http = httplib2.Http()
+ try:
+ (response, content) = http.request("https://example.org", "GET")
+ except:
+ pass
+ self.assertEqual(http.connections["https:example.org"].key_file, None)
+ self.assertEqual(http.connections["https:example.org"].cert_file, None)
+
+
+ http.add_certificate("akeyfile", "acertfile", "bitworking.org")
+ try:
+ (response, content) = http.request("https://bitworking.org", "GET")
+ except:
+ pass
+ self.assertEqual(http.connections["https:bitworking.org"].key_file, "akeyfile")
+ self.assertEqual(http.connections["https:bitworking.org"].cert_file, "acertfile")
+
+
def testGet303(self):
# Do a follow-up GET on a Location: header
# returned from a POST that gave a 303.
@@ -573,6 +625,38 @@
(response, content) = self.http.request(uri, "GET")
self.assertEqual(response.status, 200)
+ def testBasicAuthWithDomain(self):
+ # Test Basic Authentication
+ uri = urlparse.urljoin(base, "basic/file.txt")
+ (response, content) = self.http.request(uri, "GET")
+ self.assertEqual(response.status, 401)
+
+ uri = urlparse.urljoin(base, "basic/")
+ (response, content) = self.http.request(uri, "GET")
+ self.assertEqual(response.status, 401)
+
+ self.http.add_credentials('joe', 'password', "example.org")
+ (response, content) = self.http.request(uri, "GET")
+ self.assertEqual(response.status, 401)
+
+ uri = urlparse.urljoin(base, "basic/file.txt")
+ (response, content) = self.http.request(uri, "GET")
+ self.assertEqual(response.status, 401)
+
+ domain = urlparse.urlparse(base)[1]
+ self.http.add_credentials('joe', 'password', domain)
+ (response, content) = self.http.request(uri, "GET")
+ self.assertEqual(response.status, 200)
+
+ uri = urlparse.urljoin(base, "basic/file.txt")
+ (response, content) = self.http.request(uri, "GET")
+ self.assertEqual(response.status, 200)
+
+
+
+
+
+
def testBasicAuthTwoDifferentCredentials(self):
# Test Basic Authentication with multiple sets of credentials
uri = urlparse.urljoin(base, "basic2/file.txt")