Add fancy locking to oauth2client.

Reviewed in http://codereview.appspot.com/4919049/
diff --git a/oauth2client/client.py b/oauth2client/client.py
index f547428..894bfb4 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""An OAuth 2.0 client
+"""An OAuth 2.0 client.
 
-Tools for interacting with OAuth 2.0 protected
-resources.
+Tools for interacting with OAuth 2.0 protected resources.
 """
 
 __author__ = 'jcgregorio@google.com (Joe Gregorio)'
@@ -27,9 +26,9 @@
 import urllib
 import urlparse
 
-try: # pragma: no cover
+try:  # pragma: no cover
   import simplejson
-except ImportError: # pragma: no cover
+except ImportError:  # pragma: no cover
   try:
     # Try to import from django, should work on App Engine
     from django.utils import simplejson
@@ -38,9 +37,11 @@
     import json as simplejson
 
 try:
-    from urlparse import parse_qsl
+  from urlparse import parse_qsl
 except ImportError:
-    from cgi import parse_qsl
+  from cgi import parse_qsl
+
+logger = logging.getLogger(__name__)
 
 
 class Error(Exception):
@@ -92,28 +93,76 @@
 class Storage(object):
   """Base class for all Storage objects.
 
-  Store and retrieve a single credential.
+  Store and retrieve a single credential.  This class supports locking
+  such that multiple processes and threads can operate on a single
+  store.
   """
 
-  def get(self):
+  def acquire_lock(self):
+    """Acquires any lock necessary to access this Storage.
+
+    This lock is not reentrant."""
+    pass
+
+  def release_lock(self):
+    """Release the Storage lock.
+
+    Trying to release a lock that isn't held will result in a
+    RuntimeError.
+    """
+    pass
+
+  def locked_get(self):
     """Retrieve credential.
 
+    The Storage lock must be held when this is called.
+
     Returns:
       oauth2client.client.Credentials
     """
     _abstract()
 
-  def put(self, credentials):
+  def locked_put(self, credentials):
     """Write a credential.
 
+    The Storage lock must be held when this is called.
+
     Args:
       credentials: Credentials, the credentials to store.
     """
     _abstract()
 
+  def get(self):
+    """Retrieve credential.
+
+    The Storage lock must *not* be held when this is called.
+
+    Returns:
+      oauth2client.client.Credentials
+    """
+    self.acquire_lock()
+    try:
+      return self.locked_get()
+    finally:
+      self.release_lock()
+
+  def put(self, credentials):
+    """Write a credential.
+
+    The Storage lock must *not* be held when this is called.
+
+    Args:
+      credentials: Credentials, the credentials to store.
+    """
+    self.acquire_lock()
+    try:
+      self.locked_put(credentials)
+    finally:
+      self.release_lock()
+
 
 class OAuth2Credentials(Credentials):
-  """Credentials object for OAuth 2.0
+  """Credentials object for OAuth 2.0.
 
   Credentials can be applied to an httplib2.Http object using the authorize()
   method, which then signs each request from that object with the OAuth 2.0
@@ -123,22 +172,21 @@
   """
 
   def __init__(self, access_token, client_id, client_secret, refresh_token,
-      token_expiry, token_uri, user_agent):
-    """Create an instance of OAuth2Credentials
+               token_expiry, token_uri, user_agent):
+    """Create an instance of OAuth2Credentials.
 
     This constructor is not usually called by the user, instead
     OAuth2Credentials objects are instantiated by the OAuth2WebServerFlow.
 
     Args:
-      token_uri: string, URI of token endpoint.
+      access_token: string, access token.
       client_id: string, client identifier.
       client_secret: string, client secret.
-      access_token: string, access token.
-      token_expiry: datetime, when the access_token expires.
       refresh_token: string, refresh token.
+      token_expiry: datetime, when the access_token expires.
+      token_uri: string, URI of token endpoint.
       user_agent: string, The HTTP User-Agent to provide for this application.
 
-
     Notes:
       store: callable, a callable that when passed a Credential
         will store the credential back to where it came from.
@@ -156,51 +204,66 @@
 
     # True if the credentials have been revoked or expired and can't be
     # refreshed.
-    self._invalid = False
+    self.invalid = False
 
   @property
-  def invalid(self):
-    """True if the credentials are invalid, such as being revoked."""
-    return getattr(self, '_invalid', False)
+  def access_token_expired(self):
+    """True if the credential is expired or invalid.
+
+    If the token_expiry isn't set, we assume the token doesn't expire.
+    """
+    if self.invalid:
+      return True
+
+    if not self.token_expiry:
+      return False
+
+    now = datetime.datetime.now()
+    if now >= self.token_expiry:
+      logger.info('access_token is expired. Now: %s, token_expiry: %s',
+                  now, self.token_expiry)
+      return True
+    return False
 
   def set_store(self, store):
-    """Set the storage for the credential.
+    """Set the Storage for the credential.
 
     Args:
-      store: callable, a callable that when passed a Credential
-        will store the credential back to where it came from.
+      store: Storage, an implementation of a Storage object.
         This is needed to store the latest access_token if it
-        has expired and been refreshed.
+        has expired and been refreshed.  This implementation uses
+        locking to check for updates before updating the
+        access_token.
     """
     self.store = store
 
+  def _updateFromCredential(self, other):
+    """Update this Credential from another instance."""
+    self.__dict__.update(other.__getstate__())
+
   def __getstate__(self):
-    """Trim the state down to something that can be pickled.
-    """
+    """Trim the state down to something that can be pickled."""
     d = copy.copy(self.__dict__)
     del d['store']
     return d
 
   def __setstate__(self, state):
-    """Reconstitute the state of the object from being pickled.
-    """
+    """Reconstitute the state of the object from being pickled."""
     self.__dict__.update(state)
     self.store = None
 
   def _generate_refresh_request_body(self):
-    """Generate the body that will be used in the refresh request
-    """
+    """Generate the body that will be used in the refresh request."""
     body = urllib.urlencode({
-      'grant_type': 'refresh_token',
-      'client_id': self.client_id,
-      'client_secret': self.client_secret,
-      'refresh_token': self.refresh_token,
-      })
+        'grant_type': 'refresh_token',
+        'client_id': self.client_id,
+        'client_secret': self.client_secret,
+        'refresh_token': self.refresh_token,
+        })
     return body
 
   def _generate_refresh_request_headers(self):
-    """Generate the headers that will be used in the refresh request
-    """
+    """Generate the headers that will be used in the refresh request."""
     headers = {
         'content-type': 'application/x-www-form-urlencoded',
     }
@@ -211,16 +274,41 @@
     return headers
 
   def _refresh(self, http_request):
+    """Refreshes the access_token.
+
+    This method first checks by reading the Storage object if available.
+    If a refresh is still needed, it holds the Storage lock until the
+    refresh is completed.
+    """
+    if not self.store:
+      self._do_refresh_request(http_request)
+    else:
+      self.store.acquire_lock()
+      try:
+        new_cred = self.store.locked_get()
+        if (new_cred and not new_cred.invalid and
+            new_cred.access_token != self.access_token):
+          logger.info('Updated access_token read from Storage')
+          self._updateFromCredential(new_cred)
+        else:
+          self._do_refresh_request(http_request)
+      finally:
+        self.store.release_lock()
+
+  def _do_refresh_request(self, http_request):
     """Refresh the access_token using the refresh_token.
 
     Args:
        http: An instance of httplib2.Http.request
            or something that acts like it.
+
+    Raises:
+      AccessTokenRefreshError: When the refresh fails.
     """
     body = self._generate_refresh_request_body()
     headers = self._generate_refresh_request_headers()
 
-    logging.info("Refresing access_token")
+    logger.info('Refresing access_token')
     resp, content = http_request(
         self.token_uri, method='POST', body=body, headers=headers)
     if resp.status == 200:
@@ -233,23 +321,20 @@
             seconds=int(d['expires_in'])) + datetime.datetime.now()
       else:
         self.token_expiry = None
-      if self.store is not None:
-        self.store(self)
+      if self.store:
+        self.store.locked_put(self)
     else:
       # An {'error':...} response body means the token is expired or revoked,
       # so we flag the credentials as such.
-      logging.error('Failed to retrieve access token: %s' % content)
+      logger.error('Failed to retrieve access token: %s' % content)
       error_msg = 'Invalid response %s.' % resp['status']
       try:
         d = simplejson.loads(content)
         if 'error' in d:
           error_msg = d['error']
-          self._invalid = True
-          if self.store is not None:
-            self.store(self)
-          else:
-            logging.warning(
-                "Unable to store refreshed credentials, no Storage provided.")
+          self.invalid = True
+          if self.store:
+            self.store.locked_put(self)
       except:
         pass
       raise AccessTokenRefreshError(error_msg)
@@ -269,13 +354,11 @@
       h = httplib2.Http()
       h = credentials.authorize(h)
 
-    You can't create a new OAuth
-    subclass of httplib2.Authenication because
-    it never gets passed the absolute URI, which is
-    needed for signing. So instead we have to overload
-    'request' with a closure that adds in the
-    Authorization header and then calls the original version
-    of 'request()'.
+    You can't create a new OAuth subclass of httplib2.Authentication
+    because it never gets passed the absolute URI, which is needed for
+    signing. So instead we have to overload 'request' with a closure
+    that adds in the Authorization header and then calls the original
+    version of 'request()'.
     """
     request_orig = http.request
 
@@ -284,12 +367,12 @@
                     redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                     connection_type=None):
       if not self.access_token:
-        logging.info("Attempting refresh to obtain initial access_token")
+        logger.info('Attempting refresh to obtain initial access_token')
         self._refresh(request_orig)
 
-      """Modify the request headers to add the appropriate
-      Authorization header."""
-      if headers == None:
+      # Modify the request headers to add the appropriate
+      # Authorization header.
+      if headers is None:
         headers = {}
       headers['authorization'] = 'OAuth ' + self.access_token
 
@@ -303,7 +386,7 @@
                                    redirections, connection_type)
 
       if resp.status == 401:
-        logging.info("Refreshing because we got a 401")
+        logger.info('Refreshing due to a 401')
         self._refresh(request_orig)
         headers['authorization'] = 'OAuth ' + self.access_token
         return request_orig(uri, method, body, headers,
@@ -316,14 +399,15 @@
 
 
 class AccessTokenCredentials(OAuth2Credentials):
-  """Credentials object for OAuth 2.0
+  """Credentials object for OAuth 2.0.
 
-  Credentials can be applied to an httplib2.Http object using the authorize()
-  method, which then signs each request from that object with the OAuth 2.0
-  access token.  This set of credentials is for the use case where you have
-  acquired an OAuth 2.0 access_token from another place such as a JavaScript
-  client or another web application, and wish to use it from Python. Because
-  only the access_token is present it can not be refreshed and will in time
+  Credentials can be applied to an httplib2.Http object using the
+  authorize() method, which then signs each request from that object
+  with the OAuth 2.0 access token.  This set of credentials is for the
+  use case where you have acquired an OAuth 2.0 access_token from
+  another place such as a JavaScript client or another web
+  application, and wish to use it from Python. Because only the
+  access_token is present it can not be refreshed and will in time
   expire.
 
   AccessTokenCredentials objects may be safely pickled and unpickled.
@@ -368,19 +452,20 @@
 
 
 class AssertionCredentials(OAuth2Credentials):
-  """Abstract Credentials object used for OAuth 2.0 assertion grants
+  """Abstract Credentials object used for OAuth 2.0 assertion grants.
 
-  This credential does not require a flow to instantiate because it represents
-  a two legged flow, and therefore has all of the required information to
-  generate and refresh its own access tokens.  It must be subclassed to
-  generate the appropriate assertion string.
+  This credential does not require a flow to instantiate because it
+  represents a two legged flow, and therefore has all of the required
+  information to generate and refresh its own access tokens.  It must
+  be subclassed to generate the appropriate assertion string.
 
   AssertionCredentials objects may be safely pickled and unpickled.
   """
 
   def __init__(self, assertion_type, user_agent,
-      token_uri='https://accounts.google.com/o/oauth2/token', **kwargs):
-    """Constructor for AssertionFlowCredentials
+               token_uri='https://accounts.google.com/o/oauth2/token',
+               **unused_kwargs):
+    """Constructor for AssertionCredentials.
 
     Args:
       assertion_type: string, assertion type that will be declared to the auth
@@ -403,10 +488,10 @@
     assertion = self._generate_assertion()
 
     body = urllib.urlencode({
-      'assertion_type': self.assertion_type,
-      'assertion': assertion,
-      'grant_type': "assertion",
-    })
+        'assertion_type': self.assertion_type,
+        'assertion': assertion,
+        'grant_type': 'assertion',
+        })
 
     return body
 
@@ -424,10 +509,10 @@
   """
 
   def __init__(self, client_id, client_secret, scope, user_agent,
-      auth_uri='https://accounts.google.com/o/oauth2/auth',
-      token_uri='https://accounts.google.com/o/oauth2/token',
-      **kwargs):
-    """Constructor for OAuth2WebServerFlow
+               auth_uri='https://accounts.google.com/o/oauth2/auth',
+               token_uri='https://accounts.google.com/o/oauth2/token',
+               **kwargs):
+    """Constructor for OAuth2WebServerFlow.
 
     Args:
       client_id: string, client identifier.
@@ -466,11 +551,11 @@
 
     self.redirect_uri = redirect_uri
     query = {
-      'response_type': 'code',
-      'client_id': self.client_id,
-      'redirect_uri': redirect_uri,
-      'scope': self.scope,
-      }
+        'response_type': 'code',
+        'client_id': self.client_id,
+        'redirect_uri': redirect_uri,
+        'scope': self.scope,
+        }
     query.update(self.params)
     parts = list(urlparse.urlparse(self.auth_uri))
     query.update(dict(parse_qsl(parts[4]))) # 4 is the index of the query part
@@ -491,15 +576,16 @@
       code = code['code']
 
     body = urllib.urlencode({
-      'grant_type': 'authorization_code',
-      'client_id': self.client_id,
-      'client_secret': self.client_secret,
-      'code': code,
-      'redirect_uri': self.redirect_uri,
-      'scope': self.scope,
-      })
+        'grant_type': 'authorization_code',
+        'client_id': self.client_id,
+        'client_secret': self.client_secret,
+        'code': code,
+        'redirect_uri': self.redirect_uri,
+        'scope': self.scope,
+        })
     headers = {
-      'content-type': 'application/x-www-form-urlencoded',
+        'user-agent': self.user_agent,
+        'content-type': 'application/x-www-form-urlencoded',
     }
 
     if self.user_agent is not None:
@@ -519,12 +605,12 @@
         token_expiry = datetime.datetime.now() + datetime.timedelta(
             seconds=int(d['expires_in']))
 
-      logging.info('Successfully retrieved access token: %s' % content)
+      logger.info('Successfully retrieved access token: %s' % content)
       return OAuth2Credentials(access_token, self.client_id,
                                self.client_secret, refresh_token, token_expiry,
                                self.token_uri, self.user_agent)
     else:
-      logging.error('Failed to retrieve access token: %s' % content)
+      logger.error('Failed to retrieve access token: %s' % content)
       error_msg = 'Invalid response %s.' % resp['status']
       try:
         d = simplejson.loads(content)