Fix refresh token handling and add unit tests for oauth2client.client
diff --git a/apiclient/http.py b/apiclient/http.py
index cecf3f2..f713ec8 100644
--- a/apiclient/http.py
+++ b/apiclient/http.py
@@ -177,3 +177,45 @@
redirections=1,
connection_type=None):
return httplib2.Response(self.headers), self.data
+
+
+class HttpMockSequence(object):
+ """Mock of httplib2.Http
+
+ Mocks a sequence of calls to request returning different responses for each
+ call. Create an instance initialized with the desired response headers
+ and content and then use as if an httplib2.Http instance.
+
+ http = HttpMockSequence([
+ ({'status': '401'}, ''),
+ ({'status': '200'}, '{"access_token":"1/3w","expires_in":3600}'),
+ ({'status': '200'}, 'echo_request_headers'),
+ ])
+ resp, content = http.request("http://examples.com")
+
+ There are special values you can pass in for content to trigger
+ behavours that are helpful in testing.
+
+ 'echo_request_headers' means return the request headers in the response body
+ 'echo_request_body' means return the request body in the response body
+ """
+
+ def __init__(self, iterable):
+ """
+ Args:
+ iterable: iterable, a sequence of pairs of (headers, body)
+ """
+ self._iterable = iterable
+
+ def request(self, uri,
+ method='GET',
+ body=None,
+ headers=None,
+ redirections=1,
+ connection_type=None):
+ resp, content = self._iterable.pop(0)
+ if content == 'echo_request_headers':
+ content = headers
+ elif content == 'echo_request_body':
+ content = body
+ return httplib2.Response(resp), content
diff --git a/oauth2client/client.py b/oauth2client/client.py
index 3d7ded7..c8d0941 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -36,7 +36,12 @@
pass
-class RequestError(Error):
+class FlowExchangeError(Error):
+ """Error occurred during request."""
+ pass
+
+
+class AccessTokenRefreshError(Error):
"""Error occurred during request."""
pass
@@ -144,7 +149,7 @@
@property
def invalid(self):
"""True if the credentials are invalid, such as being revoked."""
- return self._invalid
+ return getattr(self, '_invalid', False)
def set_store(self, store):
"""Set the storage for the credential.
@@ -204,16 +209,18 @@
else:
# An {'error':...} response body means the token is expired or revoked, so
# we flag the credentials as such.
+ logging.error('Failed to retrieve access token: %s' % content)
+ error_msg = 'Invalid response %s.' % resp['status']
try:
d = simplejson.loads(content)
if 'error' in d:
+ error_msg = d['error']
self._invalid = True
if self.store is not None:
self.store(self)
except:
pass
- logging.error('Failed to retrieve access token: %s' % content)
- raise RequestError('Invalid response %s.' % resp['status'])
+ raise AccessTokenRefreshError(error_msg)
def authorize(self, http):
"""Authorize an httplib2.Http instance with these credentials.
@@ -258,6 +265,7 @@
if resp.status == 401:
logging.info("Refreshing because we got a 401")
self._refresh(request_orig)
+ headers['authorization'] = 'OAuth ' + self.access_token
return request_orig(uri, method, body, headers,
redirections, connection_type)
else:
@@ -378,13 +386,14 @@
parts[4] = urllib.urlencode(query)
return urlparse.urlunparse(parts)
- def step2_exchange(self, code):
+ def step2_exchange(self, code, http=None):
"""Exhanges a code for OAuth2Credentials.
Args:
code: string or dict, either the code as a string, or a dictionary
of the query parameters to the redirect_uri, which contains
the code.
+ http: httplib2.Http, optional http instance to use to do the fetch
"""
if not (isinstance(code, str) or isinstance(code, unicode)):
@@ -402,8 +411,9 @@
'user-agent': self.user_agent,
'content-type': 'application/x-www-form-urlencoded'
}
- h = httplib2.Http()
- resp, content = h.request(self.token_uri, method='POST', body=body, headers=headers)
+ if http is None:
+ http = httplib2.Http()
+ resp, content = http.request(self.token_uri, method='POST', body=body, headers=headers)
if resp.status == 200:
# TODO(jcgregorio) Raise an error if simplejson.loads fails?
d = simplejson.loads(content)
@@ -419,4 +429,12 @@
self.user_agent)
else:
logging.error('Failed to retrieve access token: %s' % content)
- raise RequestError('Invalid response %s.' % resp['status'])
+ error_msg = 'Invalid response %s.' % resp['status']
+ try:
+ d = simplejson.loads(content)
+ if 'error' in d:
+ error_msg = d['error']
+ except:
+ pass
+
+ raise FlowExchangeError(error_msg)
diff --git a/samples/moderator/moderator.py b/samples/moderator/moderator.py
index c22f86e..93e3ff6 100644
--- a/samples/moderator/moderator.py
+++ b/samples/moderator/moderator.py
@@ -33,7 +33,7 @@
flow = FlowThreeLegged(moderator_discovery,
consumer_key='anonymous',
consumer_secret='anonymous',
- user_agent='google-api-client-python-mdrtr-cmdline/1.0',
+ user_agent='python-moderator-sample/1.0',
domain='anonymous',
scope='https://www.googleapis.com/auth/moderator',
#scope='tag:google.com,2010:auth/moderator',
diff --git a/samples/threadqueue/main.py b/samples/threadqueue/main.py
index d184f6d..63d092d 100644
--- a/samples/threadqueue/main.py
+++ b/samples/threadqueue/main.py
@@ -88,7 +88,7 @@
flow = FlowThreeLegged(moderator_discovery,
consumer_key='anonymous',
consumer_secret='anonymous',
- user_agent='google-api-client-python-thread-sample/1.0',
+ user_agent='python-threading-sample/1.0',
domain='anonymous',
scope='https://www.googleapis.com/auth/moderator',
xoauth_displayname='Google API Client Example App')
diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py
new file mode 100644
index 0000000..9f30eff
--- /dev/null
+++ b/tests/test_oauth2client.py
@@ -0,0 +1,172 @@
+#!/usr/bin/python2.4
+#
+# Copyright 2010 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Discovery document tests
+
+Unit tests for objects created from discovery documents.
+"""
+
+__author__ = 'jcgregorio@google.com (Joe Gregorio)'
+
+import unittest
+import urlparse
+try:
+ from urlparse import parse_qs
+except ImportError:
+ from cgi import parse_qs
+
+from apiclient.http import HttpMockSequence
+from oauth2client.client import AccessTokenCredentials
+from oauth2client.client import AccessTokenCredentialsError
+from oauth2client.client import AccessTokenRefreshError
+from oauth2client.client import FlowExchangeError
+from oauth2client.client import OAuth2Credentials
+from oauth2client.client import OAuth2WebServerFlow
+
+
+class OAuth2CredentialsTests(unittest.TestCase):
+
+ def setUp(self):
+ access_token = "foo"
+ client_id = "some_client_id"
+ client_secret = "cOuDdkfjxxnv+"
+ refresh_token = "1/0/a.df219fjls0"
+ token_expiry = "ignored"
+ token_uri = "https://www.google.com/accounts/o8/oauth2/token"
+ user_agent = "refresh_checker/1.0"
+ self.credentials = OAuth2Credentials(
+ access_token, client_id, client_secret,
+ refresh_token, token_expiry, token_uri,
+ user_agent)
+
+ def test_token_refresh_success(self):
+ http = HttpMockSequence([
+ ({'status': '401'}, ''),
+ ({'status': '200'}, '{"access_token":"1/3w","expires_in":3600}'),
+ ({'status': '200'}, 'echo_request_headers'),
+ ])
+ http = self.credentials.authorize(http)
+ resp, content = http.request("http://example.com")
+ self.assertEqual(content['authorization'], 'OAuth 1/3w')
+
+ def test_token_refresh_failure(self):
+ http = HttpMockSequence([
+ ({'status': '401'}, ''),
+ ({'status': '400'}, '{"error":"access_denied"}'),
+ ])
+ http = self.credentials.authorize(http)
+ try:
+ http.request("http://example.com")
+ self.fail("should raise AccessTokenRefreshError exception")
+ except AccessTokenRefreshError:
+ pass
+
+ def test_non_401_error_response(self):
+ http = HttpMockSequence([
+ ({'status': '400'}, ''),
+ ])
+ http = self.credentials.authorize(http)
+ resp, content = http.request("http://example.com")
+ self.assertEqual(400, resp.status)
+
+
+class AccessTokenCredentialsTests(unittest.TestCase):
+
+ def setUp(self):
+ access_token = "foo"
+ user_agent = "refresh_checker/1.0"
+ self.credentials = AccessTokenCredentials(access_token, user_agent)
+
+ def test_token_refresh_success(self):
+ http = HttpMockSequence([
+ ({'status': '401'}, ''),
+ ])
+ http = self.credentials.authorize(http)
+ try:
+ resp, content = http.request("http://example.com")
+ self.fail("should throw exception if token expires")
+ except AccessTokenCredentialsError:
+ pass
+ except Exception:
+ self.fail("should only throw AccessTokenCredentialsError")
+
+ def test_non_401_error_response(self):
+ http = HttpMockSequence([
+ ({'status': '400'}, ''),
+ ])
+ http = self.credentials.authorize(http)
+ resp, content = http.request("http://example.com")
+ self.assertEqual(400, resp.status)
+
+
+class OAuth2WebServerFlowTest(unittest.TestCase):
+
+ def setUp(self):
+ self.flow = OAuth2WebServerFlow(
+ client_id='client_id+1',
+ client_secret='secret+1',
+ scope='foo',
+ user_agent='unittest-sample/1.0',
+ )
+
+ def test_construct_authorize_url(self):
+ authorize_url = self.flow.step1_get_authorize_url('oob')
+
+ parsed = urlparse.urlparse(authorize_url)
+ q = parse_qs(parsed[4])
+ self.assertEqual(q['client_id'][0], 'client_id+1')
+ self.assertEqual(q['response_type'][0], 'code')
+ self.assertEqual(q['scope'][0], 'foo')
+ self.assertEqual(q['redirect_uri'][0], 'oob')
+
+ def test_exchange_failure(self):
+ http = HttpMockSequence([
+ ({'status': '400'}, '{"error":"invalid_request"}')
+ ])
+
+ try:
+ credentials = self.flow.step2_exchange('some random code', http)
+ self.fail("should raise exception if exchange doesn't get 200")
+ except FlowExchangeError:
+ pass
+
+ def test_exchange_success(self):
+ http = HttpMockSequence([
+ ({'status': '200'},
+ """{ "access_token":"SlAV32hkKG",
+ "expires_in":3600,
+ "refresh_token":"8xLOxBtZp8" }"""),
+ ])
+
+ credentials = self.flow.step2_exchange('some random code', http)
+ self.assertEqual(credentials.access_token, 'SlAV32hkKG')
+ self.assertNotEqual(credentials.token_expiry, None)
+ self.assertEqual(credentials.refresh_token, '8xLOxBtZp8')
+
+
+ def test_exchange_no_expires_in(self):
+ http = HttpMockSequence([
+ ({'status': '200'}, """{ "access_token":"SlAV32hkKG",
+ "refresh_token":"8xLOxBtZp8" }"""),
+ ])
+
+ credentials = self.flow.step2_exchange('some random code', http)
+ self.assertEqual(credentials.token_expiry, None)
+
+
+if __name__ == '__main__':
+ unittest.main()