Add fancy locking to oauth2client.
Reviewed in http://codereview.appspot.com/4919049/
diff --git a/oauth2client/appengine.py b/oauth2client/appengine.py
index 439a579..64fd3ac 100644
--- a/oauth2client/appengine.py
+++ b/oauth2client/appengine.py
@@ -23,7 +23,6 @@
import pickle
import time
import base64
-import logging
try: # pragma: no cover
import simplejson
@@ -222,7 +221,7 @@
entity = self._model.get_or_insert(self._key_name)
credential = getattr(entity, self._property_name)
if credential and hasattr(credential, 'set_store'):
- credential.set_store(self.put)
+ credential.set_store(self)
if self._cache:
self._cache.set(self._key_name, pickle.dumps(credentials))
diff --git a/oauth2client/client.py b/oauth2client/client.py
index f547428..894bfb4 100644
--- a/oauth2client/client.py
+++ b/oauth2client/client.py
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""An OAuth 2.0 client
+"""An OAuth 2.0 client.
-Tools for interacting with OAuth 2.0 protected
-resources.
+Tools for interacting with OAuth 2.0 protected resources.
"""
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
@@ -27,9 +26,9 @@
import urllib
import urlparse
-try: # pragma: no cover
+try: # pragma: no cover
import simplejson
-except ImportError: # pragma: no cover
+except ImportError: # pragma: no cover
try:
# Try to import from django, should work on App Engine
from django.utils import simplejson
@@ -38,9 +37,11 @@
import json as simplejson
try:
- from urlparse import parse_qsl
+ from urlparse import parse_qsl
except ImportError:
- from cgi import parse_qsl
+ from cgi import parse_qsl
+
+logger = logging.getLogger(__name__)
class Error(Exception):
@@ -92,28 +93,76 @@
class Storage(object):
"""Base class for all Storage objects.
- Store and retrieve a single credential.
+ Store and retrieve a single credential. This class supports locking
+ such that multiple processes and threads can operate on a single
+ store.
"""
- def get(self):
+ def acquire_lock(self):
+ """Acquires any lock necessary to access this Storage.
+
+ This lock is not reentrant."""
+ pass
+
+ def release_lock(self):
+ """Release the Storage lock.
+
+ Trying to release a lock that isn't held will result in a
+ RuntimeError.
+ """
+ pass
+
+ def locked_get(self):
"""Retrieve credential.
+ The Storage lock must be held when this is called.
+
Returns:
oauth2client.client.Credentials
"""
_abstract()
- def put(self, credentials):
+ def locked_put(self, credentials):
"""Write a credential.
+ The Storage lock must be held when this is called.
+
Args:
credentials: Credentials, the credentials to store.
"""
_abstract()
+ def get(self):
+ """Retrieve credential.
+
+ The Storage lock must *not* be held when this is called.
+
+ Returns:
+ oauth2client.client.Credentials
+ """
+ self.acquire_lock()
+ try:
+ return self.locked_get()
+ finally:
+ self.release_lock()
+
+ def put(self, credentials):
+ """Write a credential.
+
+ The Storage lock must *not* be held when this is called.
+
+ Args:
+ credentials: Credentials, the credentials to store.
+ """
+ self.acquire_lock()
+ try:
+ self.locked_put(credentials)
+ finally:
+ self.release_lock()
+
class OAuth2Credentials(Credentials):
- """Credentials object for OAuth 2.0
+ """Credentials object for OAuth 2.0.
Credentials can be applied to an httplib2.Http object using the authorize()
method, which then signs each request from that object with the OAuth 2.0
@@ -123,22 +172,21 @@
"""
def __init__(self, access_token, client_id, client_secret, refresh_token,
- token_expiry, token_uri, user_agent):
- """Create an instance of OAuth2Credentials
+ token_expiry, token_uri, user_agent):
+ """Create an instance of OAuth2Credentials.
This constructor is not usually called by the user, instead
OAuth2Credentials objects are instantiated by the OAuth2WebServerFlow.
Args:
- token_uri: string, URI of token endpoint.
+ access_token: string, access token.
client_id: string, client identifier.
client_secret: string, client secret.
- access_token: string, access token.
- token_expiry: datetime, when the access_token expires.
refresh_token: string, refresh token.
+ token_expiry: datetime, when the access_token expires.
+ token_uri: string, URI of token endpoint.
user_agent: string, The HTTP User-Agent to provide for this application.
-
Notes:
store: callable, a callable that when passed a Credential
will store the credential back to where it came from.
@@ -156,51 +204,66 @@
# True if the credentials have been revoked or expired and can't be
# refreshed.
- self._invalid = False
+ self.invalid = False
@property
- def invalid(self):
- """True if the credentials are invalid, such as being revoked."""
- return getattr(self, '_invalid', False)
+ def access_token_expired(self):
+ """True if the credential is expired or invalid.
+
+ If the token_expiry isn't set, we assume the token doesn't expire.
+ """
+ if self.invalid:
+ return True
+
+ if not self.token_expiry:
+ return False
+
+ now = datetime.datetime.now()
+ if now >= self.token_expiry:
+ logger.info('access_token is expired. Now: %s, token_expiry: %s',
+ now, self.token_expiry)
+ return True
+ return False
def set_store(self, store):
- """Set the storage for the credential.
+ """Set the Storage for the credential.
Args:
- store: callable, a callable that when passed a Credential
- will store the credential back to where it came from.
+ store: Storage, an implementation of the Storage object.
This is needed to store the latest access_token if it
- has expired and been refreshed.
+ has expired and been refreshed. This implementation uses
+ locking to check for updates before updating the
+ access_token.
"""
self.store = store
+ def _updateFromCredential(self, other):
+ """Update this Credential from another instance."""
+ self.__dict__.update(other.__getstate__())
+
def __getstate__(self):
- """Trim the state down to something that can be pickled.
- """
+ """Trim the state down to something that can be pickled."""
d = copy.copy(self.__dict__)
del d['store']
return d
def __setstate__(self, state):
- """Reconstitute the state of the object from being pickled.
- """
+ """Reconstitute the state of the object from being pickled."""
self.__dict__.update(state)
self.store = None
def _generate_refresh_request_body(self):
- """Generate the body that will be used in the refresh request
- """
+ """Generate the body that will be used in the refresh request."""
body = urllib.urlencode({
- 'grant_type': 'refresh_token',
- 'client_id': self.client_id,
- 'client_secret': self.client_secret,
- 'refresh_token': self.refresh_token,
- })
+ 'grant_type': 'refresh_token',
+ 'client_id': self.client_id,
+ 'client_secret': self.client_secret,
+ 'refresh_token': self.refresh_token,
+ })
return body
def _generate_refresh_request_headers(self):
- """Generate the headers that will be used in the refresh request
- """
+ """Generate the headers that will be used in the refresh request."""
headers = {
'content-type': 'application/x-www-form-urlencoded',
}
@@ -211,16 +274,41 @@
return headers
def _refresh(self, http_request):
+ """Refreshes the access_token.
+
+ This method first checks by reading the Storage object if available.
+ If a refresh is still needed, it holds the Storage lock until the
+ refresh is completed.
+ """
+ if not self.store:
+ self._do_refresh_request(http_request)
+ else:
+ self.store.acquire_lock()
+ try:
+ new_cred = self.store.locked_get()
+ if (new_cred and not new_cred.invalid and
+ new_cred.access_token != self.access_token):
+ logger.info('Updated access_token read from Storage')
+ self._updateFromCredential(new_cred)
+ else:
+ self._do_refresh_request(http_request)
+ finally:
+ self.store.release_lock()
+
+ def _do_refresh_request(self, http_request):
"""Refresh the access_token using the refresh_token.
Args:
http: An instance of httplib2.Http.request
or something that acts like it.
+
+ Raises:
+ AccessTokenRefreshError: When the refresh fails.
"""
body = self._generate_refresh_request_body()
headers = self._generate_refresh_request_headers()
- logging.info("Refresing access_token")
+ logger.info('Refresing access_token')
resp, content = http_request(
self.token_uri, method='POST', body=body, headers=headers)
if resp.status == 200:
@@ -233,23 +321,20 @@
seconds=int(d['expires_in'])) + datetime.datetime.now()
else:
self.token_expiry = None
- if self.store is not None:
- self.store(self)
+ if self.store:
+ self.store.locked_put(self)
else:
# An {'error':...} response body means the token is expired or revoked,
# so we flag the credentials as such.
- logging.error('Failed to retrieve access token: %s' % content)
+ logger.error('Failed to retrieve access token: %s' % content)
error_msg = 'Invalid response %s.' % resp['status']
try:
d = simplejson.loads(content)
if 'error' in d:
error_msg = d['error']
- self._invalid = True
- if self.store is not None:
- self.store(self)
- else:
- logging.warning(
- "Unable to store refreshed credentials, no Storage provided.")
+ self.invalid = True
+ if self.store:
+ self.store.locked_put(self)
except:
pass
raise AccessTokenRefreshError(error_msg)
@@ -269,13 +354,11 @@
h = httplib2.Http()
h = credentials.authorize(h)
- You can't create a new OAuth
- subclass of httplib2.Authenication because
- it never gets passed the absolute URI, which is
- needed for signing. So instead we have to overload
- 'request' with a closure that adds in the
- Authorization header and then calls the original version
- of 'request()'.
+ You can't create a new OAuth subclass of httplib2.Authentication
+ because it never gets passed the absolute URI, which is needed for
+ signing. So instead we have to overload 'request' with a closure
+ that adds in the Authorization header and then calls the original
+ version of 'request()'.
"""
request_orig = http.request
@@ -284,12 +367,12 @@
redirections=httplib2.DEFAULT_MAX_REDIRECTS,
connection_type=None):
if not self.access_token:
- logging.info("Attempting refresh to obtain initial access_token")
+ logger.info('Attempting refresh to obtain initial access_token')
self._refresh(request_orig)
- """Modify the request headers to add the appropriate
- Authorization header."""
- if headers == None:
+ # Modify the request headers to add the appropriate
+ # Authorization header.
+ if headers is None:
headers = {}
headers['authorization'] = 'OAuth ' + self.access_token
@@ -303,7 +386,7 @@
redirections, connection_type)
if resp.status == 401:
- logging.info("Refreshing because we got a 401")
+ logger.info('Refreshing due to a 401')
self._refresh(request_orig)
headers['authorization'] = 'OAuth ' + self.access_token
return request_orig(uri, method, body, headers,
@@ -316,14 +399,15 @@
class AccessTokenCredentials(OAuth2Credentials):
- """Credentials object for OAuth 2.0
+ """Credentials object for OAuth 2.0.
- Credentials can be applied to an httplib2.Http object using the authorize()
- method, which then signs each request from that object with the OAuth 2.0
- access token. This set of credentials is for the use case where you have
- acquired an OAuth 2.0 access_token from another place such as a JavaScript
- client or another web application, and wish to use it from Python. Because
- only the access_token is present it can not be refreshed and will in time
+ Credentials can be applied to an httplib2.Http object using the
+ authorize() method, which then signs each request from that object
+ with the OAuth 2.0 access token. This set of credentials is for the
+ use case where you have acquired an OAuth 2.0 access_token from
+ another place such as a JavaScript client or another web
+ application, and wish to use it from Python. Because only the
+ access_token is present it can not be refreshed and will in time
expire.
AccessTokenCredentials objects may be safely pickled and unpickled.
@@ -368,19 +452,20 @@
class AssertionCredentials(OAuth2Credentials):
- """Abstract Credentials object used for OAuth 2.0 assertion grants
+ """Abstract Credentials object used for OAuth 2.0 assertion grants.
- This credential does not require a flow to instantiate because it represents
- a two legged flow, and therefore has all of the required information to
- generate and refresh its own access tokens. It must be subclassed to
- generate the appropriate assertion string.
+ This credential does not require a flow to instantiate because it
+ represents a two legged flow, and therefore has all of the required
+ information to generate and refresh its own access tokens. It must
+ be subclassed to generate the appropriate assertion string.
AssertionCredentials objects may be safely pickled and unpickled.
"""
def __init__(self, assertion_type, user_agent,
- token_uri='https://accounts.google.com/o/oauth2/token', **kwargs):
- """Constructor for AssertionFlowCredentials
+ token_uri='https://accounts.google.com/o/oauth2/token',
+ **unused_kwargs):
+ """Constructor for AssertionFlowCredentials.
Args:
assertion_type: string, assertion type that will be declared to the auth
@@ -403,10 +488,10 @@
assertion = self._generate_assertion()
body = urllib.urlencode({
- 'assertion_type': self.assertion_type,
- 'assertion': assertion,
- 'grant_type': "assertion",
- })
+ 'assertion_type': self.assertion_type,
+ 'assertion': assertion,
+ 'grant_type': 'assertion',
+ })
return body
@@ -424,10 +509,10 @@
"""
def __init__(self, client_id, client_secret, scope, user_agent,
- auth_uri='https://accounts.google.com/o/oauth2/auth',
- token_uri='https://accounts.google.com/o/oauth2/token',
- **kwargs):
- """Constructor for OAuth2WebServerFlow
+ auth_uri='https://accounts.google.com/o/oauth2/auth',
+ token_uri='https://accounts.google.com/o/oauth2/token',
+ **kwargs):
+ """Constructor for OAuth2WebServerFlow.
Args:
client_id: string, client identifier.
@@ -466,11 +551,11 @@
self.redirect_uri = redirect_uri
query = {
- 'response_type': 'code',
- 'client_id': self.client_id,
- 'redirect_uri': redirect_uri,
- 'scope': self.scope,
- }
+ 'response_type': 'code',
+ 'client_id': self.client_id,
+ 'redirect_uri': redirect_uri,
+ 'scope': self.scope,
+ }
query.update(self.params)
parts = list(urlparse.urlparse(self.auth_uri))
query.update(dict(parse_qsl(parts[4]))) # 4 is the index of the query part
@@ -491,15 +576,16 @@
code = code['code']
body = urllib.urlencode({
- 'grant_type': 'authorization_code',
- 'client_id': self.client_id,
- 'client_secret': self.client_secret,
- 'code': code,
- 'redirect_uri': self.redirect_uri,
- 'scope': self.scope,
- })
+ 'grant_type': 'authorization_code',
+ 'client_id': self.client_id,
+ 'client_secret': self.client_secret,
+ 'code': code,
+ 'redirect_uri': self.redirect_uri,
+ 'scope': self.scope,
+ })
headers = {
- 'content-type': 'application/x-www-form-urlencoded',
+ 'user-agent': self.user_agent,
+ 'content-type': 'application/x-www-form-urlencoded',
}
if self.user_agent is not None:
@@ -519,12 +605,12 @@
token_expiry = datetime.datetime.now() + datetime.timedelta(
seconds=int(d['expires_in']))
- logging.info('Successfully retrieved access token: %s' % content)
+ logger.info('Successfully retrieved access token: %s' % content)
return OAuth2Credentials(access_token, self.client_id,
self.client_secret, refresh_token, token_expiry,
self.token_uri, self.user_agent)
else:
- logging.error('Failed to retrieve access token: %s' % content)
+ logger.error('Failed to retrieve access token: %s' % content)
error_msg = 'Invalid response %s.' % resp['status']
try:
d = simplejson.loads(content)
diff --git a/oauth2client/django_orm.py b/oauth2client/django_orm.py
index c818ea2..581fe8e 100644
--- a/oauth2client/django_orm.py
+++ b/oauth2client/django_orm.py
@@ -99,7 +99,7 @@
if len(entities) > 0:
credential = getattr(entities[0], self.property_name)
if credential and hasattr(credential, 'set_store'):
- credential.set_store(self.put)
+ credential.set_store(self)
return credential
def put(self, credentials):
diff --git a/oauth2client/file.py b/oauth2client/file.py
index da666c4..b7f9c7d 100644
--- a/oauth2client/file.py
+++ b/oauth2client/file.py
@@ -44,7 +44,7 @@
f = open(self._filename, 'r')
credentials = pickle.loads(f.read())
f.close()
- credentials.set_store(self.put)
+ credentials.set_store(self)
except:
credentials = None
self._lock.release()
diff --git a/oauth2client/multistore_file.py b/oauth2client/multistore_file.py
new file mode 100644
index 0000000..8841194
--- /dev/null
+++ b/oauth2client/multistore_file.py
@@ -0,0 +1,361 @@
+# Copyright 2011 Google Inc. All Rights Reserved.
+
+"""Multi-credential file store with lock support.
+
+This module implements a JSON credential store where multiple
+credentials can be stored in one file. That file supports locking
+both in a single process and across processes.
+
+The credentials themselves are keyed off of:
+* client_id
+* user_agent
+* scope
+
+The format of the stored data is like so:
+{
+ 'file_version': 1,
+ 'data': [
+ {
+ 'key': {
+ 'clientId': '<client id>',
+ 'userAgent': '<user agent>',
+ 'scope': '<scope>'
+ },
+ 'credential': '<base64 encoding of pickled Credential object>'
+ }
+ ]
+}
+"""
+
+__author__ = 'jbeda@google.com (Joe Beda)'
+
+import base64
+import fcntl
+import logging
+import os
+import pickle
+import threading
+
+try: # pragma: no cover
+ import simplejson
+except ImportError: # pragma: no cover
+ try:
+ # Try to import from django, should work on App Engine
+ from django.utils import simplejson
+ except ImportError:
+ # Should work for Python2.6 and higher.
+ import json as simplejson
+
+from client import Storage as BaseStorage
+
+logger = logging.getLogger(__name__)
+
+# A dict from 'filename'->_MultiStore instances
+_multistores = {}
+_multistores_lock = threading.Lock()
+
+
+class Error(Exception):
+ """Base error for this module."""
+ pass
+
+
+class NewerCredentialStoreError(Error):
+ """The credential store is a newer version than supported."""
+ pass
+
+
+def get_credential_storage(filename, client_id, user_agent, scope,
+ warn_on_readonly=True):
+ """Get a Storage instance for a credential.
+
+ Args:
+ filename: The JSON file storing a set of credentials
+ client_id: The client_id for the credential
+ user_agent: The user agent for the credential
+ scope: A string for the scope being requested
+ warn_on_readonly: if True, log a warning if the store is readonly
+
+ Returns:
+ An object derived from client.Storage for getting/setting the
+ credential.
+ """
+ filename = os.path.realpath(os.path.expanduser(filename))
+ _multistores_lock.acquire()
+ try:
+ multistore = _multistores.setdefault(
+ filename, _MultiStore(filename, warn_on_readonly))
+ finally:
+ _multistores_lock.release()
+ return multistore._get_storage(client_id, user_agent, scope)
+
+
+class _MultiStore(object):
+ """A file backed store for multiple credentials."""
+
+ def __init__(self, filename, warn_on_readonly=True):
+ """Initialize the class.
+
+ This will create the file if necessary.
+ """
+ self._filename = filename
+ self._thread_lock = threading.Lock()
+ self._file_handle = None
+ self._read_only = False
+ self._warn_on_readonly = warn_on_readonly
+
+ self._create_file_if_needed()
+
+ # Cache of deserialized store. This is only valid after the
+ # _MultiStore is locked or _refresh_data_cache is called. This is
+ # of the form of:
+ #
+ # (client_id, user_agent, scope) -> OAuth2Credential
+ #
+ # If this is None, then the store hasn't been read yet.
+ self._data = None
+
+ class _Storage(BaseStorage):
+ """A Storage object that knows how to read/write a single credential."""
+
+ def __init__(self, multistore, client_id, user_agent, scope):
+ self._multistore = multistore
+ self._client_id = client_id
+ self._user_agent = user_agent
+ self._scope = scope
+
+ def acquire_lock(self):
+ """Acquires any lock necessary to access this Storage.
+
+ This lock is not reentrant.
+ """
+ self._multistore._lock()
+
+ def release_lock(self):
+ """Release the Storage lock.
+
+ Trying to release a lock that isn't held will result in a
+ RuntimeError.
+ """
+ self._multistore._unlock()
+
+ def locked_get(self):
+ """Retrieve credential.
+
+ The Storage lock must be held when this is called.
+
+ Returns:
+ oauth2client.client.Credentials
+ """
+ credential = self._multistore._get_credential(
+ self._client_id, self._user_agent, self._scope)
+ if credential:
+ credential.set_store(self)
+ return credential
+
+ def locked_put(self, credentials):
+ """Write a credential.
+
+ The Storage lock must be held when this is called.
+
+ Args:
+ credentials: Credentials, the credentials to store.
+ """
+ self._multistore._update_credential(credentials, self._scope)
+
+ def _create_file_if_needed(self):
+ """Create an empty file if necessary.
+
+ This method will not initialize the file. Instead it implements a
+ simple version of "touch" to ensure the file has been created.
+ """
+ if not os.path.exists(self._filename):
+ old_umask = os.umask(0177)
+ try:
+ open(self._filename, 'a+').close()
+ finally:
+ os.umask(old_umask)
+
+ def _lock(self):
+ """Lock the entire multistore."""
+ self._thread_lock.acquire()
+ # Check to see if the file is writeable.
+ if os.access(self._filename, os.W_OK):
+ self._file_handle = open(self._filename, 'r+')
+ fcntl.lockf(self._file_handle.fileno(), fcntl.LOCK_EX)
+ else:
+ # Cannot open in read/write mode. Open only in read mode.
+ self._file_handle = open(self._filename, 'r')
+ self._read_only = True
+ if self._warn_on_readonly:
+ logger.warn('The credentials file (%s) is not writable. Opening in '
+ 'read-only mode. Any refreshed credentials will only be '
+ 'valid for this run.' % self._filename)
+ if os.path.getsize(self._filename) == 0:
+ logger.debug('Initializing empty multistore file')
+ # The multistore is empty so write out an empty file.
+ self._data = {}
+ self._write()
+ elif not self._read_only or self._data is None:
+ # Only refresh the data if we are read/write or we haven't
+ # cached the data yet. If we are readonly, we assume it isn't
+ # changing out from under us and that we only have to read it
+ # once. This prevents us from whacking any new access keys that
+ # we have cached in memory but were unable to write out.
+ self._refresh_data_cache()
+
+ def _unlock(self):
+ """Release the lock on the multistore."""
+ if not self._read_only:
+ fcntl.lockf(self._file_handle.fileno(), fcntl.LOCK_UN)
+ self._file_handle.close()
+ self._thread_lock.release()
+
+ def _locked_json_read(self):
+ """Get the raw content of the multistore file.
+
+ The multistore must be locked when this is called.
+
+ Returns:
+ The contents of the multistore decoded as JSON.
+ """
+ assert self._thread_lock.locked()
+ self._file_handle.seek(0)
+ return simplejson.load(self._file_handle)
+
+ def _locked_json_write(self, data):
+ """Write a JSON serializable data structure to the multistore.
+
+ The multistore must be locked when this is called.
+
+ Args:
+ data: The data to be serialized and written.
+ """
+ assert self._thread_lock.locked()
+ if self._read_only:
+ return
+ self._file_handle.seek(0)
+ simplejson.dump(data, self._file_handle, sort_keys=True, indent=2)
+ self._file_handle.truncate()
+
+ def _refresh_data_cache(self):
+ """Refresh the contents of the multistore.
+
+ The multistore must be locked when this is called.
+
+ Raises:
+ NewerCredentialStoreError: Raised when a newer client has written the
+ store.
+ """
+ self._data = {}
+ try:
+ raw_data = self._locked_json_read()
+ except Exception:
+ logger.warn('Credential data store could not be loaded. '
+ 'Will ignore and overwrite.')
+ return
+
+ version = 0
+ try:
+ version = raw_data['file_version']
+ except Exception:
+ logger.warn('Missing version for credential data store. It may be '
+ 'corrupt or an old version. Overwriting.')
+ if version > 1:
+ raise NewerCredentialStoreError(
+ 'Credential file has file_version of %d. '
+ 'Only file_version of 1 is supported.' % version)
+
+ credentials = []
+ try:
+ credentials = raw_data['data']
+ except (TypeError, KeyError):
+ pass
+
+ for cred_entry in credentials:
+ try:
+ (key, credential) = self._decode_credential_from_json(cred_entry)
+ self._data[key] = credential
+ except:
+ # If something goes wrong loading a credential, just ignore it
+ logger.info('Error decoding credential, skipping', exc_info=True)
+
+ def _decode_credential_from_json(self, cred_entry):
+ """Load a credential from our JSON serialization.
+
+ Args:
+ cred_entry: A dict entry from the data member of our format
+
+ Returns:
+ (key, cred) where the key is the key tuple and the cred is the
+ OAuth2Credential object.
+ """
+ raw_key = cred_entry['key']
+ client_id = raw_key['clientId']
+ user_agent = raw_key['userAgent']
+ scope = raw_key['scope']
+ key = (client_id, user_agent, scope)
+ credential = pickle.loads(base64.b64decode(cred_entry['credential']))
+ return (key, credential)
+
+ def _write(self):
+ """Write the cached data back out.
+
+ The multistore must be locked.
+ """
+ raw_data = {'file_version': 1}
+ raw_creds = []
+ raw_data['data'] = raw_creds
+ for (cred_key, cred) in self._data.items():
+ raw_key = {
+ 'clientId': cred_key[0],
+ 'userAgent': cred_key[1],
+ 'scope': cred_key[2]
+ }
+ raw_cred = base64.b64encode(pickle.dumps(cred))
+ raw_creds.append({'key': raw_key, 'credential': raw_cred})
+ self._locked_json_write(raw_data)
+
+ def _get_credential(self, client_id, user_agent, scope):
+ """Get a credential from the multistore.
+
+ The multistore must be locked.
+
+ Args:
+ client_id: The client_id for the credential
+ user_agent: The user agent for the credential
+ scope: A string for the scope being requested
+
+ Returns:
+ The credential specified or None if not present
+ """
+ key = (client_id, user_agent, scope)
+ return self._data.get(key, None)
+
+ def _update_credential(self, cred, scope):
+ """Update a credential and write the multistore.
+
+ This must be called when the multistore is locked.
+
+ Args:
+ cred: The OAuth2Credential to update/set
+ scope: The scope that this credential covers
+ """
+ key = (cred.client_id, cred.user_agent, scope)
+ self._data[key] = cred
+ self._write()
+
+ def _get_storage(self, client_id, user_agent, scope):
+ """Get a Storage object to get/set a credential.
+
+ This Storage is a 'view' into the multistore.
+
+ Args:
+ client_id: The client_id for the credential
+ user_agent: The user agent for the credential
+ scope: A string for the scope being requested
+
+ Returns:
+ A Storage object that can be used to get/set this cred
+ """
+ return self._Storage(self, client_id, user_agent, scope)
diff --git a/oauth2client/tools.py b/oauth2client/tools.py
index f04d4c8..dc779b4 100644
--- a/oauth2client/tools.py
+++ b/oauth2client/tools.py
@@ -25,31 +25,30 @@
import BaseHTTPServer
import gflags
-import logging
import socket
import sys
from client import FlowExchangeError
try:
- from urlparse import parse_qsl
+ from urlparse import parse_qsl
except ImportError:
- from cgi import parse_qsl
+ from cgi import parse_qsl
FLAGS = gflags.FLAGS
gflags.DEFINE_boolean('auth_local_webserver', True,
- ('Run a local web server to handle redirects during '
+ ('Run a local web server to handle redirects during '
'OAuth authorization.'))
gflags.DEFINE_string('auth_host_name', 'localhost',
('Host name to use when running a local web server to '
- 'handle redirects during OAuth authorization.'))
+ 'handle redirects during OAuth authorization.'))
gflags.DEFINE_multi_int('auth_host_port', [8080, 8090],
- ('Port to use when running a local web server to '
- 'handle redirects during OAuth authorization.'))
+ ('Port to use when running a local web server to '
+ 'handle redirects during OAuth authorization.'))
class ClientRedirectServer(BaseHTTPServer.HTTPServer):
@@ -69,7 +68,7 @@
"""
def do_GET(s):
- """Handle a GET request
+ """Handle a GET request.
Parses the query parameters and prints a message
if the flow has completed. Note that we can't detect
@@ -106,8 +105,8 @@
for port in FLAGS.auth_host_port:
port_number = port
try:
- httpd = BaseHTTPServer.HTTPServer((FLAGS.auth_host_name, port),
- ClientRedirectHandler)
+ httpd = ClientRedirectServer((FLAGS.auth_host_name, port),
+ ClientRedirectHandler)
except socket.error, e:
pass
else:
@@ -126,10 +125,10 @@
print
if FLAGS.auth_local_webserver:
print 'If your browser is on a different machine then exit and re-run this'
- print 'application with the command-line parameter --noauth_local_webserver.'
+ print 'application with the command-line parameter '
+ print '--noauth_local_webserver.'
print
-
if FLAGS.auth_local_webserver:
httpd.handle_request()
if 'error' in httpd.query_params:
@@ -137,18 +136,15 @@
if 'code' in httpd.query_params:
code = httpd.query_params['code']
else:
- accepted = 'n'
- while accepted.lower() == 'n':
- accepted = raw_input('Have you authorized me? (y/n) ')
- code = raw_input('What is the verification code? ').strip()
+ code = raw_input('Enter verification code: ').strip()
try:
- credentials = flow.step2_exchange(code)
- except FlowExchangeError:
- sys.exit('The authentication has failed.')
+ credential = flow.step2_exchange(code)
+ except FlowExchangeError, e:
+ sys.exit('Authentication has failed: %s' % e)
- storage.put(credentials)
- credentials.set_store(storage.put)
- print "You have successfully authenticated."
+ storage.put(credential)
+ credential.set_store(storage)
+ print 'Authentication successful.'
- return credentials
+ return credential