Fix bugs with auth, batch and retries.
Reviewed in http://codereview.appspot.com/5633052/.
diff --git a/tests/test_http.py b/tests/test_http.py
index 67e8aa6..cb7a832 100644
--- a/tests/test_http.py
+++ b/tests/test_http.py
@@ -35,6 +35,52 @@
from apiclient.http import MediaInMemoryUpload
from apiclient.http import set_user_agent
from apiclient.model import JsonModel
+from oauth2client.client import Credentials
+
+
+class MockCredentials(Credentials):
+ """Mock Credentials that counts authorize(), refresh() and apply() calls."""
+ def __init__(self, bearer_token):
+ super(MockCredentials, self).__init__()
+ self._authorized = 0  # Count of authorize() calls.
+ self._refreshed = 0  # Count of refresh() calls; also appended to the header value.
+ self._applied = 0  # Count of apply() calls.
+ self._bearer_token = bearer_token  # Prefix of the fake authorization header value.
+
+ def authorize(self, http):  # Wraps http.request so every call goes through apply().
+ self._authorized += 1
+
+ request_orig = http.request
+
+ # The closure that will replace this http object's 'request' method.
+ def new_request(uri, method='GET', body=None, headers=None,
+ redirections=httplib2.DEFAULT_MAX_REDIRECTS,
+ connection_type=None):
+ # Modify the request headers to add the appropriate
+ # Authorization header.
+ if headers is None:
+ headers = {}
+ self.apply(headers)
+
+ resp, content = request_orig(uri, method, body, headers,
+ redirections, connection_type)
+
+ return resp, content
+
+ # Replace the request method with our own closure.
+ http.request = new_request
+
+ # Expose the credentials as an attribute of the wrapped request method.
+ setattr(http.request, 'credentials', self)
+
+ return http
+
+ def refresh(self, http):  # Only counts the call; no token request is made.
+ self._refreshed += 1
+
+ def apply(self, headers):  # Sets headers['authorization'] to '<token> <refresh count>'.
+ self._applied += 1
+ headers['authorization'] = self._bearer_token + ' ' + str(self._refreshed)
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
@@ -52,7 +98,7 @@
http = set_user_agent(http, "my_app/5.5")
resp, content = http.request("http://example.com")
- self.assertEqual(content['user-agent'], 'my_app/5.5')
+ self.assertEqual('my_app/5.5', content['user-agent'])  # Expected value first.
def test_set_user_agent_nested(self):
http = HttpMockSequence([
@@ -62,25 +108,25 @@
http = set_user_agent(http, "my_app/5.5")
http = set_user_agent(http, "my_library/0.1")
resp, content = http.request("http://example.com")
- self.assertEqual(content['user-agent'], 'my_app/5.5 my_library/0.1')
+ self.assertEqual('my_app/5.5 my_library/0.1', content['user-agent'])
def test_media_file_upload_to_from_json(self):
upload = MediaFileUpload(
datafile('small.png'), chunksize=500, resumable=True)
- self.assertEquals('image/png', upload.mimetype())
- self.assertEquals(190, upload.size())
- self.assertEquals(True, upload.resumable())
- self.assertEquals(500, upload.chunksize())
- self.assertEquals('PNG', upload.getbytes(1, 3))
+ self.assertEqual('image/png', upload.mimetype())  # assertEqual, not the deprecated assertEquals alias.
+ self.assertEqual(190, upload.size())
+ self.assertEqual(True, upload.resumable())
+ self.assertEqual(500, upload.chunksize())
+ self.assertEqual('PNG', upload.getbytes(1, 3))
json = upload.to_json()
new_upload = MediaUpload.new_from_json(json)
- self.assertEquals('image/png', new_upload.mimetype())
- self.assertEquals(190, new_upload.size())
- self.assertEquals(True, new_upload.resumable())
- self.assertEquals(500, new_upload.chunksize())
- self.assertEquals('PNG', new_upload.getbytes(1, 3))
+ self.assertEqual('image/png', new_upload.mimetype())
+ self.assertEqual(190, new_upload.size())
+ self.assertEqual(True, new_upload.resumable())
+ self.assertEqual(500, new_upload.chunksize())
+ self.assertEqual('PNG', new_upload.getbytes(1, 3))
def test_http_request_to_from_json(self):
@@ -103,13 +149,13 @@
json = req.to_json()
new_req = HttpRequest.from_json(json, http, _postproc)
- self.assertEquals(new_req.headers,
- {'content-type':
- 'multipart/related; boundary="---flubber"'})
- self.assertEquals(new_req.uri, 'http://example.com')
- self.assertEquals(new_req.body, '{}')
- self.assertEquals(new_req.http, http)
- self.assertEquals(new_req.resumable.to_json(), media_upload.to_json())
+ self.assertEqual({'content-type':
+ 'multipart/related; boundary="---flubber"'},
+ new_req.headers)
+ self.assertEqual('http://example.com', new_req.uri)
+ self.assertEqual('{}', new_req.body)
+ self.assertEqual(http, new_req.http)
+ self.assertEqual(media_upload.to_json(), new_req.resumable.to_json())
EXPECTED = """POST /someapi/v1/collection/?foo=bar HTTP/1.1
Content-Type: application/json
@@ -153,6 +199,50 @@
--batch_foobarbaz--"""
+BATCH_RESPONSE_WITH_401 = """--batch_foobarbaz
+Content-Type: application/http
+Content-Transfer-Encoding: binary
+Content-ID: <randomness+1>
+
+HTTP/1.1 401 Authorization Required
+Content-Type application/json
+Content-Length: 14
+ETag: "etag/pony"\r\n\r\n{"error": {"message":
+ "Authorization failed."}}
+
+--batch_foobarbaz
+Content-Type: application/http
+Content-Transfer-Encoding: binary
+Content-ID: <randomness+2>
+
+HTTP/1.1 200 OK
+Content-Type application/json
+Content-Length: 14
+ETag: "etag/sheep"\r\n\r\n{"baz": "qux"}
+--batch_foobarbaz--"""
+
+
+BATCH_SINGLE_RESPONSE = """--batch_foobarbaz
+Content-Type: application/http
+Content-Transfer-Encoding: binary
+Content-ID: <randomness+1>
+
+HTTP/1.1 200 OK
+Content-Type application/json
+Content-Length: 14
+ETag: "etag/pony"\r\n\r\n{"foo": 42}
+--batch_foobarbaz--"""
+
+class Callbacks(object):  # Records what each batch callback receives, keyed by request id.
+ def __init__(self):
+ self.responses = {}  # request_id -> deserialized response (None on error).
+ self.exceptions = {}  # request_id -> exception reported for that part (None on success).
+
+ def f(self, request_id, response, exception):  # Callback passed to BatchHttpRequest.
+ self.responses[request_id] = response
+ self.exceptions[request_id] = exception
+
+
class TestBatch(unittest.TestCase):
def setUp(self):
@@ -196,7 +286,7 @@
methodId=None,
resumable=None)
s = batch._serialize_request(request).splitlines()
- self.assertEquals(s, EXPECTED.splitlines())
+ self.assertEqual(EXPECTED.splitlines(), s)  # Expected value first.
def test_serialize_request_media_body(self):
batch = BatchHttpRequest()
@@ -213,9 +303,9 @@
headers={'content-type': 'application/json'},
methodId=None,
resumable=None)
+ # Only checks that serializing a media-body request does not raise.
s = batch._serialize_request(request).splitlines()
-
def test_serialize_request_no_body(self):
batch = BatchHttpRequest()
request = HttpRequest(
@@ -228,30 +318,30 @@
methodId=None,
resumable=None)
s = batch._serialize_request(request).splitlines()
- self.assertEquals(s, NO_BODY_EXPECTED.splitlines())
+ self.assertEqual(NO_BODY_EXPECTED.splitlines(), s)
def test_deserialize_response(self):
batch = BatchHttpRequest()
resp, content = batch._deserialize_response(RESPONSE)
- self.assertEquals(resp.status, 200)
- self.assertEquals(resp.reason, 'OK')
- self.assertEquals(resp.version, 11)
- self.assertEquals(content, '{"answer": 42}')
+ self.assertEqual(200, resp.status)
+ self.assertEqual('OK', resp.reason)
+ self.assertEqual(11, resp.version)
+ self.assertEqual('{"answer": 42}', content)
def test_new_id(self):
batch = BatchHttpRequest()
id_ = batch._new_id()
- self.assertEquals(id_, '1')
+ self.assertEqual('1', id_)
id_ = batch._new_id()
- self.assertEquals(id_, '2')
+ self.assertEqual('2', id_)
batch.add(self.request1, request_id='3')
id_ = batch._new_id()
- self.assertEquals(id_, '4')
+ self.assertEqual('4', id_)
def test_add(self):
batch = BatchHttpRequest()
@@ -267,13 +357,6 @@
self.assertRaises(BatchError, batch.add, self.request1, request_id='1')
def test_execute(self):
- class Callbacks(object):
- def __init__(self):
- self.responses = {}
-
- def f(self, request_id, response):
- self.responses[request_id] = response
-
batch = BatchHttpRequest()
callbacks = Callbacks()
@@ -285,8 +368,10 @@
BATCH_RESPONSE),
])
batch.execute(http)
- self.assertEqual(callbacks.responses['1'], {'foo': 42})
- self.assertEqual(callbacks.responses['2'], {'baz': 'qux'})
+ self.assertEqual({'foo': 42}, callbacks.responses['1'])
+ self.assertEqual(None, callbacks.exceptions['1'])  # Callbacks now also receive an exception (None on success).
+ self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
+ self.assertEqual(None, callbacks.exceptions['2'])
def test_execute_request_body(self):
batch = BatchHttpRequest()
@@ -311,14 +396,82 @@
header = parts[1].splitlines()[1]
self.assertEqual('Content-Type: application/http', header)
+ def test_execute_refresh_and_retry_on_401(self):  # Only the part answered with 401 is refreshed and resent.
+ batch = BatchHttpRequest()
+ callbacks = Callbacks()
+ cred_1 = MockCredentials('Foo')
+ cred_2 = MockCredentials('Bar')
+
+ http = HttpMockSequence([
+ ({'status': '200',
+ 'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
+ BATCH_RESPONSE_WITH_401),
+ ({'status': '200',
+ 'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
+ BATCH_SINGLE_RESPONSE),
+ ])
+
+ creds_http_1 = HttpMockSequence([])
+ cred_1.authorize(creds_http_1)
+
+ creds_http_2 = HttpMockSequence([])
+ cred_2.authorize(creds_http_2)
+
+ self.request1.http = creds_http_1
+ self.request2.http = creds_http_2
+
+ batch.add(self.request1, callback=callbacks.f)
+ batch.add(self.request2, callback=callbacks.f)
+ batch.execute(http)
+
+ self.assertEqual({'foo': 42}, callbacks.responses['1'])
+ self.assertEqual(None, callbacks.exceptions['1'])
+ self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
+ self.assertEqual(None, callbacks.exceptions['2'])
+
+ self.assertEqual(1, cred_1._refreshed)
+ self.assertEqual(0, cred_2._refreshed)
+
+ self.assertEqual(1, cred_1._authorized)
+ self.assertEqual(1, cred_2._authorized)
+
+ self.assertEqual(2, cred_1._applied)  # Initial send plus the retry.
+ self.assertEqual(1, cred_2._applied)
+
+ def test_http_errors_passed_to_callback(self):  # A 401 that persists after the retry reaches the callback.
+ batch = BatchHttpRequest()
+ callbacks = Callbacks()
+ cred_1 = MockCredentials('Foo')
+ cred_2 = MockCredentials('Bar')
+
+ http = HttpMockSequence([
+ ({'status': '200',
+ 'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
+ BATCH_RESPONSE_WITH_401),
+ ({'status': '200',
+ 'content-type': 'multipart/mixed; boundary="batch_foobarbaz"'},
+ BATCH_RESPONSE_WITH_401),
+ ])
+
+ creds_http_1 = HttpMockSequence([])
+ cred_1.authorize(creds_http_1)
+
+ creds_http_2 = HttpMockSequence([])
+ cred_2.authorize(creds_http_2)
+
+ self.request1.http = creds_http_1
+ self.request2.http = creds_http_2
+
+ batch.add(self.request1, callback=callbacks.f)
+ batch.add(self.request2, callback=callbacks.f)
+ batch.execute(http)
+
+ self.assertEqual(None, callbacks.responses['1'])
+ self.assertEqual(401, callbacks.exceptions['1'].resp.status)
+ self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
+ self.assertEqual(None, callbacks.exceptions['2'])
+
def test_execute_global_callback(self):
- class Callbacks(object):
- def __init__(self):
- self.responses = {}
-
- def f(self, request_id, response):
- self.responses[request_id] = response
-
callbacks = Callbacks()
batch = BatchHttpRequest(callback=callbacks.f)
@@ -330,8 +483,8 @@
BATCH_RESPONSE),
])
batch.execute(http)
- self.assertEqual(callbacks.responses['1'], {'foo': 42})
- self.assertEqual(callbacks.responses['2'], {'baz': 'qux'})
+ self.assertEqual({'foo': 42}, callbacks.responses['1'])
+ self.assertEqual({'baz': 'qux'}, callbacks.responses['2'])
def test_media_inmemory_upload(self):
media = MediaInMemoryUpload('abcdef', 'text/plain', chunksize=10,
diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py
index d50d6a0..3e512a1 100644
--- a/tests/test_oauth2client.py
+++ b/tests/test_oauth2client.py
@@ -70,7 +70,7 @@
])
http = self.credentials.authorize(http)
resp, content = http.request("http://example.com")
- self.assertEqual(content['authorization'], 'OAuth 1/3w')
+ self.assertEqual('Bearer 1/3w', content['Authorization'])  # Token now sent as 'Bearer', under 'Authorization'.
def test_token_refresh_failure(self):
http = HttpMockSequence([
@@ -95,11 +95,11 @@
def test_to_from_json(self):
json = self.credentials.to_json()
instance = OAuth2Credentials.from_json(json)
- self.assertEquals(type(instance), OAuth2Credentials)
+ self.assertEqual(OAuth2Credentials, type(instance))
instance.token_expiry = None
self.credentials.token_expiry = None
- self.assertEquals(self.credentials.__dict__, instance.__dict__)
+ self.assertEqual(self.credentials.__dict__, instance.__dict__)  # Expected (original credentials) first.
class AccessTokenCredentialsTests(unittest.TestCase):
@@ -136,7 +136,7 @@
])
http = self.credentials.authorize(http)
resp, content = http.request('http://example.com')
- self.assertEqual(content['authorization'], 'OAuth foo')
+ self.assertEqual('Bearer foo', content['Authorization'])  # Token now sent as 'Bearer', under 'Authorization'.
class TestAssertionCredentials(unittest.TestCase):
@@ -155,8 +155,8 @@
def test_assertion_body(self):
body = urlparse.parse_qs(self.credentials._generate_refresh_request_body())
- self.assertEqual(body['assertion'][0], self.assertion_text)
- self.assertEqual(body['assertion_type'][0], self.assertion_type)
+ self.assertEqual(self.assertion_text, body['assertion'][0])  # Expected value first.
+ self.assertEqual(self.assertion_type, body['assertion_type'][0])
def test_assertion_refresh(self):
http = HttpMockSequence([
@@ -165,7 +165,7 @@
])
http = self.credentials.authorize(http)
resp, content = http.request("http://example.com")
- self.assertEqual(content['authorization'], 'OAuth 1/3w')
+ self.assertEqual('Bearer 1/3w', content['Authorization'])
class ExtractIdTokenText(unittest.TestCase):
@@ -177,7 +177,7 @@
jwt = 'stuff.' + payload + '.signature'
extracted = _extract_id_token(jwt)
- self.assertEqual(body, extracted)
+ self.assertEqual(body, extracted)
def test_extract_failure(self):
body = {'foo': 'bar'}
@@ -201,11 +201,11 @@
parsed = urlparse.urlparse(authorize_url)
q = parse_qs(parsed[4])
- self.assertEqual(q['client_id'][0], 'client_id+1')
- self.assertEqual(q['response_type'][0], 'code')
- self.assertEqual(q['scope'][0], 'foo')
- self.assertEqual(q['redirect_uri'][0], 'OOB_CALLBACK_URN')
- self.assertEqual(q['access_type'][0], 'offline')
+ self.assertEqual('client_id+1', q['client_id'][0])  # Expected value first.
+ self.assertEqual('code', q['response_type'][0])
+ self.assertEqual('foo', q['scope'][0])
+ self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0])
+ self.assertEqual('offline', q['access_type'][0])
def test_override_flow_access_type(self):
"""Passing access_type overrides the default."""
@@ -220,11 +220,11 @@
parsed = urlparse.urlparse(authorize_url)
q = parse_qs(parsed[4])
- self.assertEqual(q['client_id'][0], 'client_id+1')
- self.assertEqual(q['response_type'][0], 'code')
- self.assertEqual(q['scope'][0], 'foo')
- self.assertEqual(q['redirect_uri'][0], 'OOB_CALLBACK_URN')
- self.assertEqual(q['access_type'][0], 'online')
+ self.assertEqual('client_id+1', q['client_id'][0])
+ self.assertEqual('code', q['response_type'][0])
+ self.assertEqual('foo', q['scope'][0])
+ self.assertEqual('OOB_CALLBACK_URN', q['redirect_uri'][0])
+ self.assertEqual('online', q['access_type'][0])
def test_exchange_failure(self):
http = HttpMockSequence([
@@ -246,9 +246,9 @@
])
credentials = self.flow.step2_exchange('some random code', http)
- self.assertEqual(credentials.access_token, 'SlAV32hkKG')
- self.assertNotEqual(credentials.token_expiry, None)
- self.assertEqual(credentials.refresh_token, '8xLOxBtZp8')
+ self.assertEqual('SlAV32hkKG', credentials.access_token)
+ self.assertNotEqual(None, credentials.token_expiry)  # Only checks that some expiry was set.
+ self.assertEqual('8xLOxBtZp8', credentials.refresh_token)
def test_exchange_no_expires_in(self):
http = HttpMockSequence([
@@ -257,7 +257,7 @@
])
credentials = self.flow.step2_exchange('some random code', http)
- self.assertEqual(credentials.token_expiry, None)
+ self.assertEqual(None, credentials.token_expiry)
def test_exchange_id_token_fail(self):
http = HttpMockSequence([
@@ -282,7 +282,7 @@
])
credentials = self.flow.step2_exchange('some random code', http)
- self.assertEquals(body, credentials.id_token)
+ self.assertEqual(body, credentials.id_token)
if __name__ == '__main__':
diff --git a/tests/test_oauth2client_jwt.py b/tests/test_oauth2client_jwt.py
index 83551c6..dcbb33c 100644
--- a/tests/test_oauth2client_jwt.py
+++ b/tests/test_oauth2client_jwt.py
@@ -100,8 +100,8 @@
certs = {'foo': public_key }
audience = 'some_audience_address@testing.gserviceaccount.com'
contents = crypt.verify_signed_jwt_with_certs(jwt, certs, audience)
- self.assertEquals('billy bob', contents['user'])
- self.assertEquals('data', contents['metadata']['meta'])
+ self.assertEqual('billy bob', contents['user'])  # assertEqual, not the deprecated assertEquals alias.
+ self.assertEqual('data', contents['metadata']['meta'])
def test_verify_id_token_with_certs_uri(self):
jwt = self._create_signed_jwt()
@@ -112,8 +112,8 @@
contents = verify_id_token(jwt,
'some_audience_address@testing.gserviceaccount.com', http)
- self.assertEquals('billy bob', contents['user'])
- self.assertEquals('data', contents['metadata']['meta'])
+ self.assertEqual('billy bob', contents['user'])
+ self.assertEqual('data', contents['metadata']['meta'])
def test_verify_id_token_with_certs_uri_fails(self):
jwt = self._create_signed_jwt()
@@ -195,7 +195,7 @@
])
http = credentials.authorize(http)
resp, content = http.request('http://example.org')
- self.assertEquals(content['authorization'], 'OAuth 1/3w')
+ self.assertEqual('Bearer 1/3w', content['Authorization'])  # Token now sent as 'Bearer', under 'Authorization'.
if __name__ == '__main__':
unittest.main()