Add Storage locking to oauth2client for shared credential stores.
Reviewed in http://codereview.appspot.com/4919049/
diff --git a/oauth2client/client.py b/oauth2client/client.py
index f547428..894bfb4 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""An OAuth 2.0 client
+"""An OAuth 2.0 client.
-Tools for interacting with OAuth 2.0 protected
-resources.
+Tools for interacting with OAuth 2.0 protected resources.
"""
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
@@ -27,9 +26,9 @@
import urllib
import urlparse
-try: # pragma: no cover
+try: # pragma: no cover
import simplejson
-except ImportError: # pragma: no cover
+except ImportError: # pragma: no cover
try:
# Try to import from django, should work on App Engine
from django.utils import simplejson
@@ -38,9 +37,11 @@
import json as simplejson
try:
- from urlparse import parse_qsl
+ from urlparse import parse_qsl
except ImportError:
- from cgi import parse_qsl
+ from cgi import parse_qsl
+
+logger = logging.getLogger(__name__)
class Error(Exception):
@@ -92,28 +93,76 @@
class Storage(object):
"""Base class for all Storage objects.
- Store and retrieve a single credential.
+ Store and retrieve a single credential. This class supports locking
+ such that multiple processes and threads can operate on a single
+ store.
"""
- def get(self):
+ def acquire_lock(self):
+ """Acquires any lock necessary to access this Storage.
+
+ This lock is not reentrant."""
+ pass
+
+ def release_lock(self):
+ """Release the Storage lock.
+
+ Trying to release a lock that isn't held will result in a
+ RuntimeError.
+ """
+ pass
+
+ def locked_get(self):
"""Retrieve credential.
+ The Storage lock must be held when this is called.
+
Returns:
oauth2client.client.Credentials
"""
_abstract()
- def put(self, credentials):
+ def locked_put(self, credentials):
"""Write a credential.
+ The Storage lock must be held when this is called.
+
Args:
credentials: Credentials, the credentials to store.
"""
_abstract()
+ def get(self):
+ """Retrieve credential.
+
+ The Storage lock must *not* be held when this is called.
+
+ Returns:
+ oauth2client.client.Credentials
+ """
+ self.acquire_lock()
+ try:
+ return self.locked_get()
+ finally:
+ self.release_lock()
+
+ def put(self, credentials):
+ """Write a credential.
+
+ The Storage lock must *not* be held when this is called.
+
+ Args:
+ credentials: Credentials, the credentials to store.
+ """
+ self.acquire_lock()
+ try:
+ self.locked_put(credentials)
+ finally:
+ self.release_lock()
+
class OAuth2Credentials(Credentials):
- """Credentials object for OAuth 2.0
+ """Credentials object for OAuth 2.0.
Credentials can be applied to an httplib2.Http object using the authorize()
method, which then signs each request from that object with the OAuth 2.0
@@ -123,22 +172,21 @@
"""
def __init__(self, access_token, client_id, client_secret, refresh_token,
- token_expiry, token_uri, user_agent):
- """Create an instance of OAuth2Credentials
+ token_expiry, token_uri, user_agent):
+ """Create an instance of OAuth2Credentials.
This constructor is not usually called by the user, instead
OAuth2Credentials objects are instantiated by the OAuth2WebServerFlow.
Args:
- token_uri: string, URI of token endpoint.
+ access_token: string, access token.
client_id: string, client identifier.
client_secret: string, client secret.
- access_token: string, access token.
- token_expiry: datetime, when the access_token expires.
refresh_token: string, refresh token.
+ token_expiry: datetime, when the access_token expires.
+ token_uri: string, URI of token endpoint.
user_agent: string, The HTTP User-Agent to provide for this application.
-
Notes:
store: callable, a callable that when passed a Credential
will store the credential back to where it came from.
@@ -156,51 +204,66 @@
# True if the credentials have been revoked or expired and can't be
# refreshed.
- self._invalid = False
+ self.invalid = False
@property
- def invalid(self):
- """True if the credentials are invalid, such as being revoked."""
- return getattr(self, '_invalid', False)
+ def access_token_expired(self):
+ """True if the credential is expired or invalid.
+
+ If the token_expiry isn't set, we assume the token doesn't expire.
+ """
+ if self.invalid:
+ return True
+
+ if not self.token_expiry:
+ return False
+
+ now = datetime.datetime.now()
+ if now >= self.token_expiry:
+ logger.info('access_token is expired. Now: %s, token_expiry: %s',
+ now, self.token_expiry)
+ return True
+ return False
def set_store(self, store):
- """Set the storage for the credential.
+ """Set the Storage for the credential.
Args:
- store: callable, a callable that when passed a Credential
- will store the credential back to where it came from.
+ store: Storage, an instance of a Storage implementation.
This is needed to store the latest access_token if it
- has expired and been refreshed.
+ has expired and been refreshed. This implementation uses
+ locking to check for updates before updating the
+ access_token.
"""
self.store = store
+ def _updateFromCredential(self, other):
+ """Update this Credential from another instance."""
+ self.__dict__.update(other.__getstate__())
+
def __getstate__(self):
- """Trim the state down to something that can be pickled.
- """
+ """Trim the state down to something that can be pickled."""
d = copy.copy(self.__dict__)
del d['store']
return d
def __setstate__(self, state):
- """Reconstitute the state of the object from being pickled.
- """
+ """Reconstitute the state of the object from being pickled."""
self.__dict__.update(state)
self.store = None
def _generate_refresh_request_body(self):
- """Generate the body that will be used in the refresh request
- """
+ """Generate the body that will be used in the refresh request."""
body = urllib.urlencode({
- 'grant_type': 'refresh_token',
- 'client_id': self.client_id,
- 'client_secret': self.client_secret,
- 'refresh_token': self.refresh_token,
- })
+ 'grant_type': 'refresh_token',
+ 'client_id': self.client_id,
+ 'client_secret': self.client_secret,
+ 'refresh_token': self.refresh_token,
+ })
return body
def _generate_refresh_request_headers(self):
- """Generate the headers that will be used in the refresh request
- """
+ """Generate the headers that will be used in the refresh request."""
headers = {
'content-type': 'application/x-www-form-urlencoded',
}
@@ -211,16 +274,41 @@
return headers
def _refresh(self, http_request):
+ """Refreshes the access_token.
+
+ This method first checks by reading the Storage object if available.
+ If a refresh is still needed, it holds the Storage lock until the
+ refresh is completed.
+ """
+ if not self.store:
+ self._do_refresh_request(http_request)
+ else:
+ self.store.acquire_lock()
+ try:
+ new_cred = self.store.locked_get()
+ if (new_cred and not new_cred.invalid and
+ new_cred.access_token != self.access_token):
+ logger.info('Updated access_token read from Storage')
+ self._updateFromCredential(new_cred)
+ else:
+ self._do_refresh_request(http_request)
+ finally:
+ self.store.release_lock()
+
+ def _do_refresh_request(self, http_request):
"""Refresh the access_token using the refresh_token.
Args:
http: An instance of httplib2.Http.request
or something that acts like it.
+
+ Raises:
+ AccessTokenRefreshError: When the refresh fails.
"""
body = self._generate_refresh_request_body()
headers = self._generate_refresh_request_headers()
- logging.info("Refresing access_token")
+ logger.info('Refreshing access_token')
resp, content = http_request(
self.token_uri, method='POST', body=body, headers=headers)
if resp.status == 200:
@@ -233,23 +321,20 @@
seconds=int(d['expires_in'])) + datetime.datetime.now()
else:
self.token_expiry = None
- if self.store is not None:
- self.store(self)
+ if self.store:
+ self.store.locked_put(self)
else:
# An {'error':...} response body means the token is expired or revoked,
# so we flag the credentials as such.
- logging.error('Failed to retrieve access token: %s' % content)
+ logger.error('Failed to retrieve access token: %s', content)
error_msg = 'Invalid response %s.' % resp['status']
try:
d = simplejson.loads(content)
if 'error' in d:
error_msg = d['error']
- self._invalid = True
- if self.store is not None:
- self.store(self)
- else:
- logging.warning(
- "Unable to store refreshed credentials, no Storage provided.")
+ self.invalid = True
+ if self.store:
+ self.store.locked_put(self)
except:
pass
raise AccessTokenRefreshError(error_msg)
@@ -269,13 +354,11 @@
h = httplib2.Http()
h = credentials.authorize(h)
- You can't create a new OAuth
- subclass of httplib2.Authenication because
- it never gets passed the absolute URI, which is
- needed for signing. So instead we have to overload
- 'request' with a closure that adds in the
- Authorization header and then calls the original version
- of 'request()'.
+ You can't create a new OAuth subclass of httplib2.Authentication
+ because it never gets passed the absolute URI, which is needed for
+ signing. So instead we have to overload 'request' with a closure
+ that adds in the Authorization header and then calls the original
+ version of 'request()'.
"""
request_orig = http.request
@@ -284,12 +367,12 @@
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
connection_type=None):
if not self.access_token:
- logging.info("Attempting refresh to obtain initial access_token")
+ logger.info('Attempting refresh to obtain initial access_token')
self._refresh(request_orig)
- """Modify the request headers to add the appropriate
- Authorization header."""
- if headers == None:
+ # Modify the request headers to add the appropriate
+ # Authorization header.
+ if headers is None:
headers = {}
headers['authorization'] = 'OAuth ' + self.access_token
@@ -303,7 +386,7 @@
redirections, connection_type)
if resp.status == 401:
- logging.info("Refreshing because we got a 401")
+ logger.info('Refreshing due to a 401')
self._refresh(request_orig)
headers['authorization'] = 'OAuth ' + self.access_token
return request_orig(uri, method, body, headers,
@@ -316,14 +399,15 @@
class AccessTokenCredentials(OAuth2Credentials):
- """Credentials object for OAuth 2.0
+ """Credentials object for OAuth 2.0.
- Credentials can be applied to an httplib2.Http object using the authorize()
- method, which then signs each request from that object with the OAuth 2.0
- access token. This set of credentials is for the use case where you have
- acquired an OAuth 2.0 access_token from another place such as a JavaScript
- client or another web application, and wish to use it from Python. Because
- only the access_token is present it can not be refreshed and will in time
+ Credentials can be applied to an httplib2.Http object using the
+ authorize() method, which then signs each request from that object
+ with the OAuth 2.0 access token. This set of credentials is for the
+ use case where you have acquired an OAuth 2.0 access_token from
+ another place such as a JavaScript client or another web
+ application, and wish to use it from Python. Because only the
+ access_token is present it can not be refreshed and will in time
expire.
AccessTokenCredentials objects may be safely pickled and unpickled.
@@ -368,19 +452,20 @@
class AssertionCredentials(OAuth2Credentials):
- """Abstract Credentials object used for OAuth 2.0 assertion grants
+ """Abstract Credentials object used for OAuth 2.0 assertion grants.
- This credential does not require a flow to instantiate because it represents
- a two legged flow, and therefore has all of the required information to
- generate and refresh its own access tokens. It must be subclassed to
- generate the appropriate assertion string.
+ This credential does not require a flow to instantiate because it
+ represents a two legged flow, and therefore has all of the required
+ information to generate and refresh its own access tokens. It must
+ be subclassed to generate the appropriate assertion string.
AssertionCredentials objects may be safely pickled and unpickled.
"""
def __init__(self, assertion_type, user_agent,
- token_uri='https://accounts.google.com/o/oauth2/token', **kwargs):
- """Constructor for AssertionFlowCredentials
+ token_uri='https://accounts.google.com/o/oauth2/token',
+ **unused_kwargs):
+ """Constructor for AssertionFlowCredentials.
Args:
assertion_type: string, assertion type that will be declared to the auth
@@ -403,10 +488,10 @@
assertion = self._generate_assertion()
body = urllib.urlencode({
- 'assertion_type': self.assertion_type,
- 'assertion': assertion,
- 'grant_type': "assertion",
- })
+ 'assertion_type': self.assertion_type,
+ 'assertion': assertion,
+ 'grant_type': 'assertion',
+ })
return body
@@ -424,10 +509,10 @@
"""
def __init__(self, client_id, client_secret, scope, user_agent,
- auth_uri='https://accounts.google.com/o/oauth2/auth',
- token_uri='https://accounts.google.com/o/oauth2/token',
- **kwargs):
- """Constructor for OAuth2WebServerFlow
+ auth_uri='https://accounts.google.com/o/oauth2/auth',
+ token_uri='https://accounts.google.com/o/oauth2/token',
+ **kwargs):
+ """Constructor for OAuth2WebServerFlow.
Args:
client_id: string, client identifier.
@@ -466,11 +551,11 @@
self.redirect_uri = redirect_uri
query = {
- 'response_type': 'code',
- 'client_id': self.client_id,
- 'redirect_uri': redirect_uri,
- 'scope': self.scope,
- }
+ 'response_type': 'code',
+ 'client_id': self.client_id,
+ 'redirect_uri': redirect_uri,
+ 'scope': self.scope,
+ }
query.update(self.params)
parts = list(urlparse.urlparse(self.auth_uri))
query.update(dict(parse_qsl(parts[4]))) # 4 is the index of the query part
@@ -491,15 +576,16 @@
code = code['code']
body = urllib.urlencode({
- 'grant_type': 'authorization_code',
- 'client_id': self.client_id,
- 'client_secret': self.client_secret,
- 'code': code,
- 'redirect_uri': self.redirect_uri,
- 'scope': self.scope,
- })
+ 'grant_type': 'authorization_code',
+ 'client_id': self.client_id,
+ 'client_secret': self.client_secret,
+ 'code': code,
+ 'redirect_uri': self.redirect_uri,
+ 'scope': self.scope,
+ })
headers = {
- 'content-type': 'application/x-www-form-urlencoded',
+ # No 'user-agent' entry here: self.user_agent may be None.
+ 'content-type': 'application/x-www-form-urlencoded',
}
if self.user_agent is not None:
@@ -519,12 +605,12 @@
token_expiry = datetime.datetime.now() + datetime.timedelta(
seconds=int(d['expires_in']))
- logging.info('Successfully retrieved access token: %s' % content)
+ logger.info('Successfully retrieved access token')
return OAuth2Credentials(access_token, self.client_id,
self.client_secret, refresh_token, token_expiry,
self.token_uri, self.user_agent)
else:
- logging.error('Failed to retrieve access token: %s' % content)
+ logger.error('Failed to retrieve access token: %s', content)
error_msg = 'Invalid response %s.' % resp['status']
try:
d = simplejson.loads(content)