Fix refresh token handling and add unit tests for oauth2client.client
diff --git a/oauth2client/client.py b/oauth2client/client.py
index 3d7ded7..c8d0941 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -36,7 +36,12 @@
pass
-class RequestError(Error):
- """Error occurred during request."""
+class FlowExchangeError(Error):
+ """Error trying to exchange an authorization code for credentials."""
+ pass
+
+
+class AccessTokenRefreshError(Error):
+ """Error trying to refresh an access token."""
pass
@@ -144,7 +149,7 @@
@property
def invalid(self):
"""True if the credentials are invalid, such as being revoked."""
- return self._invalid
+ return getattr(self, '_invalid', False)
def set_store(self, store):
"""Set the storage for the credential.
@@ -204,16 +209,18 @@
else:
# An {'error':...} response body means the token is expired or revoked, so
# we flag the credentials as such.
+ logging.error('Failed to retrieve access token: %s' % content)
+ error_msg = 'Invalid response %s.' % resp['status']
try:
d = simplejson.loads(content)
if 'error' in d:
+ error_msg = d['error']
self._invalid = True
if self.store is not None:
self.store(self)
except:
pass
- logging.error('Failed to retrieve access token: %s' % content)
- raise RequestError('Invalid response %s.' % resp['status'])
+ raise AccessTokenRefreshError(error_msg)
def authorize(self, http):
"""Authorize an httplib2.Http instance with these credentials.
@@ -258,6 +265,7 @@
if resp.status == 401:
logging.info("Refreshing because we got a 401")
self._refresh(request_orig)
+ headers['authorization'] = 'OAuth ' + self.access_token
return request_orig(uri, method, body, headers,
redirections, connection_type)
else:
@@ -378,13 +386,14 @@
parts[4] = urllib.urlencode(query)
return urlparse.urlunparse(parts)
- def step2_exchange(self, code):
+ def step2_exchange(self, code, http=None):
"""Exhanges a code for OAuth2Credentials.
Args:
code: string or dict, either the code as a string, or a dictionary
of the query parameters to the redirect_uri, which contains
the code.
+ http: httplib2.Http, optional http instance to use to do the fetch
"""
if not (isinstance(code, str) or isinstance(code, unicode)):
@@ -402,8 +411,9 @@
'user-agent': self.user_agent,
'content-type': 'application/x-www-form-urlencoded'
}
- h = httplib2.Http()
- resp, content = h.request(self.token_uri, method='POST', body=body, headers=headers)
+ if http is None:
+ http = httplib2.Http()
+ resp, content = http.request(self.token_uri, method='POST', body=body, headers=headers)
if resp.status == 200:
# TODO(jcgregorio) Raise an error if simplejson.loads fails?
d = simplejson.loads(content)
@@ -419,4 +429,12 @@
self.user_agent)
else:
logging.error('Failed to retrieve access token: %s' % content)
- raise RequestError('Invalid response %s.' % resp['status'])
+ error_msg = 'Invalid response %s.' % resp['status']
+ try:
+ d = simplejson.loads(content)
+ if 'error' in d:
+ error_msg = d['error']
+ except:
+ pass
+
+ raise FlowExchangeError(error_msg)