Fix unicode strings leaking into httplib2.
Reviewed in: https://codereview.appspot.com/6868054/
diff --git a/apiclient/http.py b/apiclient/http.py
index d590d1b..cd279b1 100644
--- a/apiclient/http.py
+++ b/apiclient/http.py
@@ -671,9 +671,8 @@
self.body = parsed.query
self.headers['content-length'] = str(len(self.body))
- resp, content = http.request(self.uri, method=self.method,
- body=self.body,
- headers=self.headers)
+ resp, content = http.request(str(self.uri), method=str(self.method),
+ body=self.body, headers=self.headers)
if resp.status >= 300:
raise HttpError(resp, content, uri=self.uri)
return self.postproc(resp, content)
@@ -1353,7 +1352,7 @@
class HttpMock(object):
"""Mock of httplib2.Http"""
- def __init__(self, filename, headers=None):
+ def __init__(self, filename=None, headers=None):
"""
Args:
filename: string, absolute filename to read response from
@@ -1361,10 +1360,19 @@
"""
if headers is None:
headers = {'status': '200 OK'}
- f = file(filename, 'r')
- self.data = f.read()
- f.close()
- self.headers = headers
+ if filename:
+ f = file(filename, 'r')
+ self.data = f.read()
+ f.close()
+ else:
+ self.data = None
+ self.response_headers = headers
+ self.headers = None
+ self.uri = None
+ self.method = None
+ self.body = None
+ self.headers = None
+
def request(self, uri,
method='GET',
@@ -1372,7 +1380,11 @@
headers=None,
redirections=1,
connection_type=None):
- return httplib2.Response(self.headers), self.data
+ self.uri = uri
+ self.method = method
+ self.body = body
+ self.headers = headers
+ return httplib2.Response(self.response_headers), self.data
class HttpMockSequence(object):
diff --git a/oauth2client/client.py b/oauth2client/client.py
index cce4ae6..1b0f828 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -89,6 +89,11 @@
pass
+class NonAsciiHeaderError(Error):
+ """Header names and values must be ASCII strings."""
+ pass
+
+
def _abstract():
raise NotImplementedError('You need to override this function')
@@ -319,6 +324,28 @@
self.release_lock()
+def clean_headers(headers):
+ """Forces header keys and values to be strings, i.e not unicode.
+
+ The httplib module just concats the header keys and values in a way that may
+ make the message header a unicode string, which, if it then tries to
+ contatenate to a binary request body may result in a unicode decode error.
+
+ Args:
+ headers: dict, A dictionary of headers.
+
+ Returns:
+ The same dictionary but with all the keys converted to strings.
+ """
+ clean = {}
+ try:
+ for k, v in headers.iteritems():
+ clean[str(k)] = str(v)
+ except UnicodeEncodeError:
+ raise NonAsciiHeaderError(k + ': ' + v)
+ return clean
+
+
class OAuth2Credentials(Credentials):
"""Credentials object for OAuth 2.0.
@@ -416,7 +443,7 @@
else:
headers['user-agent'] = self.user_agent
- resp, content = request_orig(uri, method, body, headers,
+ resp, content = request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type)
# Older API (GData) respond with 403
@@ -424,7 +451,7 @@
logger.info('Refreshing due to a %s' % str(resp.status))
self._refresh(request_orig)
self.apply(headers)
- return request_orig(uri, method, body, headers,
+ return request_orig(uri, method, body, clean_headers(headers),
redirections, connection_type)
else:
return (resp, content)
diff --git a/tests/test_http.py b/tests/test_http.py
index 3609f40..81248eb 100644
--- a/tests/test_http.py
+++ b/tests/test_http.py
@@ -483,6 +483,26 @@
self.exceptions[request_id] = exception
+class TestHttpRequest(unittest.TestCase):
+ def test_unicode(self):
+ http = HttpMock(datafile('zoo.json'), headers={'status': '200'})
+ model = JsonModel()
+ uri = u'https://www.googleapis.com/someapi/v1/collection/?foo=bar'
+ method = u'POST'
+ request = HttpRequest(
+ http,
+ model.response,
+ uri,
+ method=method,
+ body=u'{}',
+ headers={'content-type': 'application/json'})
+ request.execute()
+ self.assertEqual(uri, http.uri)
+ self.assertEqual(str, type(http.uri))
+ self.assertEqual(method, http.method)
+ self.assertEqual(str, type(http.method))
+
+
class TestBatch(unittest.TestCase):
def setUp(self):
diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py
index 8f855d4..2de87a1 100644
--- a/tests/test_oauth2client.py
+++ b/tests/test_oauth2client.py
@@ -34,6 +34,7 @@
except ImportError:
from cgi import parse_qs
+from apiclient.http import HttpMock
from apiclient.http import HttpMockSequence
from oauth2client.anyjson import simplejson
from oauth2client.clientsecrets import _loadfile
@@ -44,13 +45,14 @@
from oauth2client.client import Credentials
from oauth2client.client import FlowExchangeError
from oauth2client.client import MemoryCache
+from oauth2client.client import NonAsciiHeaderError
from oauth2client.client import OAuth2Credentials
from oauth2client.client import OAuth2WebServerFlow
from oauth2client.client import OOB_CALLBACK_URN
from oauth2client.client import VerifyJwtTokenError
from oauth2client.client import _extract_id_token
-from oauth2client.client import credentials_from_code
from oauth2client.client import credentials_from_clientsecrets_and_code
+from oauth2client.client import credentials_from_code
from oauth2client.client import flow_from_clientsecrets
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
@@ -86,7 +88,7 @@
restored = Credentials.new_from_json(json)
-class OAuth2CredentialsTests(unittest.TestCase):
+class BasicCredentialsTests(unittest.TestCase):
def setUp(self):
access_token = "foo"
@@ -146,6 +148,35 @@
self.assertEqual(instance.__dict__, self.credentials.__dict__)
+ def test_no_unicode_in_request_params(self):
+ access_token = u'foo'
+ client_id = u'some_client_id'
+ client_secret = u'cOuDdkfjxxnv+'
+ refresh_token = u'1/0/a.df219fjls0'
+ token_expiry = unicode(datetime.datetime.utcnow())
+ token_uri = u'https://www.google.com/accounts/o8/oauth2/token'
+ user_agent = u'refresh_checker/1.0'
+ credentials = OAuth2Credentials(access_token, client_id, client_secret,
+ refresh_token, token_expiry, token_uri,
+ user_agent)
+
+ http = HttpMock(headers={'status': '200'})
+ http = credentials.authorize(http)
+ http.request(u'http://example.com', method=u'GET', headers={
+ u'foo': u'bar'
+ })
+ for k, v in http.headers.iteritems():
+ self.assertEqual(str, type(k))
+ self.assertEqual(str, type(v))
+
+ # Test again with unicode strings that can't simple be converted to ASCII.
+ try:
+ http.request(
+ u'http://example.com', method=u'GET', headers={u'foo': u'\N{COMET}'})
+ self.fail('Expected exception to be raised.')
+ except NonAsciiHeaderError:
+ pass
+
class AccessTokenCredentialsTests(unittest.TestCase):