Make decorators thread-safe.
Reviewed in https://codereview.appspot.com/9363044/.
diff --git a/oauth2client/appengine.py b/oauth2client/appengine.py
index 5bf5626..2234cf8 100644
--- a/oauth2client/appengine.py
+++ b/oauth2client/appengine.py
@@ -25,6 +25,7 @@
import logging
import os
import pickle
+import threading
import time
from google.appengine.api import app_identity
@@ -570,6 +571,22 @@
"""
+ def set_credentials(self, credentials):
+ self._tls.credentials = credentials
+
+ def get_credentials(self):
+ return self._tls.credentials
+
+ def set_flow(self, flow):
+ self._tls.flow = flow
+
+ def get_flow(self):
+ return self._tls.flow
+
+ flow = property(get_flow, set_flow)
+ credentials = property(get_credentials, set_credentials)
+
+
@util.positional(4)
def __init__(self, client_id, client_secret, scope,
auth_uri=GOOGLE_AUTH_URI,
@@ -621,6 +638,7 @@
**kwargs: dict, Keyword arguments are be passed along as kwargs to the
OAuth2WebServerFlow constructor.
"""
+ self._tls = threading.local()
self.flow = None
self.credentials = None
self._client_id = client_id
@@ -678,9 +696,12 @@
if not self.has_credentials():
return request_handler.redirect(self.authorize_url())
try:
- return method(request_handler, *args, **kwargs)
+ resp = method(request_handler, *args, **kwargs)
except AccessTokenRefreshError:
return request_handler.redirect(self.authorize_url())
+ finally:
+ self.credentials = None
+ return resp
return check_oauth
@@ -737,9 +758,14 @@
self.credentials = self._storage_class(
self._credentials_class, None,
self._credentials_property_name, user=user).get()
- return method(request_handler, *args, **kwargs)
+ try:
+ resp = method(request_handler, *args, **kwargs)
+ finally:
+ self.credentials = None
+ return resp
return setup_oauth
+
def has_credentials(self):
"""True if for the logged in user there are valid access Credentials.
diff --git a/tests/test_oauth2client_appengine.py b/tests/test_oauth2client_appengine.py
index b6541ba..b99bd8c 100644
--- a/tests/test_oauth2client_appengine.py
+++ b/tests/test_oauth2client_appengine.py
@@ -441,20 +441,31 @@
def _finish_setup(self, decorator, user_mock):
self.decorator = decorator
+ self.had_credentials = False
+ self.found_credentials = None
+ self.should_raise = False
+ parent = self
class TestRequiredHandler(webapp2.RequestHandler):
-
@decorator.oauth_required
def get(self):
- pass
+ if decorator.has_credentials():
+ parent.had_credentials = True
+ parent.found_credentials = decorator.credentials
+ if parent.should_raise:
+ raise Exception('')
class TestAwareHandler(webapp2.RequestHandler):
-
@decorator.oauth_aware
def get(self, *args, **kwargs):
self.response.out.write('Hello World!')
assert(kwargs['year'] == '2012')
assert(kwargs['month'] == '01')
+ if decorator.has_credentials():
+ parent.had_credentials = True
+ parent.found_credentials = decorator.credentials
+ if parent.should_raise:
+ raise Exception('')
application = webapp2.WSGIApplication([
@@ -507,6 +518,9 @@
response = parse_qs(parts[1])[self.decorator._token_response_param][0]
self.assertEqual(Http2Mock.content,
simplejson.loads(urllib.unquote(response)))
+ self.assertEqual(self.decorator.flow, self.decorator._tls.flow)
+ self.assertEqual(self.decorator.credentials,
+ self.decorator._tls.credentials)
m.UnsetStubs()
m.VerifyAll()
@@ -514,15 +528,26 @@
# Now requesting the decorated path should work.
response = self.app.get('/foo_path')
self.assertEqual('200 OK', response.status)
- self.assertEqual(True, self.decorator.has_credentials())
+ self.assertEqual(True, self.had_credentials)
self.assertEqual('foo_refresh_token',
- self.decorator.credentials.refresh_token)
+ self.found_credentials.refresh_token)
self.assertEqual('foo_access_token',
- self.decorator.credentials.access_token)
+ self.found_credentials.access_token)
+ self.assertEqual(None, self.decorator.credentials)
+
+ # Raising an exception still clears the Credentials.
+ self.should_raise = True
+ try:
+ response = self.app.get('/foo_path')
+ self.fail('Should have raised an exception.')
+ except Exception:
+ pass
+ self.assertEqual(None, self.decorator.credentials)
+ self.should_raise = False
# Invalidate the stored Credentials.
- self.decorator.credentials.invalid = True
- self.decorator.credentials.store.put(self.decorator.credentials)
+ self.found_credentials.invalid = True
+ self.found_credentials.store.put(self.found_credentials)
# Invalid Credentials should start the OAuth dance again.
response = self.app.get('/foo_path')
@@ -553,8 +578,13 @@
# Now requesting the decorated path should work.
response = self.app.get('/foo_path')
+ self.assertTrue(self.had_credentials)
+
+ # Credentials should be cleared after each call.
+ self.assertEqual(None, self.decorator.credentials)
+
# Invalidate the stored Credentials.
- self.decorator.credentials.store.delete()
+ self.found_credentials.store.delete()
# Invalid Credentials should start the OAuth dance again.
response = self.app.get('/foo_path')
@@ -600,11 +630,25 @@
response = self.app.get('/bar_path/2012/01')
self.assertEqual('200 OK', response.status)
self.assertEqual('Hello World!', response.body)
- self.assertEqual(True, self.decorator.has_credentials())
+ self.assertEqual(True, self.had_credentials)
self.assertEqual('foo_refresh_token',
- self.decorator.credentials.refresh_token)
+ self.found_credentials.refresh_token)
self.assertEqual('foo_access_token',
- self.decorator.credentials.access_token)
+ self.found_credentials.access_token)
+
+ # Credentials should be cleared after each call.
+ self.assertEqual(None, self.decorator.credentials)
+
+ # Raising an exception still clears the Credentials.
+ self.should_raise = True
+ try:
+ response = self.app.get('/bar_path/2012/01')
+ self.fail('Should have raised an exception.')
+ except Exception:
+ pass
+ self.assertEqual(None, self.decorator.credentials)
+ self.should_raise = False
+
def test_error_in_step2(self):
# An initial request to an oauth_aware decorated path should not redirect.
@@ -634,6 +678,7 @@
self.assertEqual('foo_user_agent', decorator.flow.user_agent)
self.assertEqual('dummy_revoke_uri', decorator.flow.revoke_uri)
self.assertEqual(None, decorator.flow.params.get('user_agent', None))
+ self.assertEqual(decorator.flow, decorator._tls.flow)
def test_token_response_param(self):
self.decorator._token_response_param = 'foobar'