Add jwt.OnDemandCredentials (#142)
diff --git a/google/auth/jwt.py b/google/auth/jwt.py
index 412f122..b1eb5fb 100644
--- a/google/auth/jwt.py
+++ b/google/auth/jwt.py
@@ -46,13 +46,17 @@
import datetime
import json
+import cachetools
+from six.moves import urllib
+
from google.auth import _helpers
from google.auth import _service_account_info
from google.auth import crypt
+from google.auth import exceptions
import google.auth.credentials
-
_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
+_DEFAULT_MAX_CACHE_SIZE = 10
def encode(signer, payload, header=None, key_id=None):
@@ -316,10 +320,10 @@
self._audience = audience
self._token_lifetime = token_lifetime
- if additional_claims is not None:
- self._additional_claims = additional_claims
- else:
- self._additional_claims = {}
+ if additional_claims is None:
+ additional_claims = {}
+
+ self._additional_claims = additional_claims
@classmethod
def _from_signer_and_info(cls, signer, info, **kwargs):
@@ -343,8 +347,7 @@
@classmethod
def from_service_account_info(cls, info, **kwargs):
- """Creates a Credentials instance from a dictionary containing service
- account info in Google format.
+ """Creates an Credentials instance from a dictionary.
Args:
info (Mapping[str, str]): The service account info in Google
@@ -487,3 +490,266 @@
@_helpers.copy_docstring(google.auth.credentials.Signing)
def signer(self):
return self._signer
+
+
+class OnDemandCredentials(
+ google.auth.credentials.Signing,
+ google.auth.credentials.Credentials):
+ """On-demand JWT credentials.
+
+ Like :class:`Credentials`, this class uses a JWT as the bearer token for
+ authentication. However, this class does not require the audience at
+ construction time. Instead, it will generate a new token on-demand for
+ each request using the request URI as the audience. It caches tokens
+ so that multiple requests to the same URI do not incur the overhead
+ of generating a new token every time.
+
+ This behavior is especially useful for `gRPC`_ clients. A gRPC service may
+ have multiple audience and gRPC clients may not know all of the audiences
+ required for accessing a particular service. With these credentials,
+ no knowledge of the audiences is required ahead of time.
+
+ .. _grpc: http://www.grpc.io/
+ """
+
+ def __init__(self, signer, issuer, subject,
+ additional_claims=None,
+ token_lifetime=_DEFAULT_TOKEN_LIFETIME_SECS,
+ max_cache_size=_DEFAULT_MAX_CACHE_SIZE):
+ """
+ Args:
+ signer (google.auth.crypt.Signer): The signer used to sign JWTs.
+ issuer (str): The `iss` claim.
+ subject (str): The `sub` claim.
+ additional_claims (Mapping[str, str]): Any additional claims for
+ the JWT payload.
+ token_lifetime (int): The amount of time in seconds for
+ which the token is valid. Defaults to 1 hour.
+ max_cache_size (int): The maximum number of JWT tokens to keep in
+ cache. Tokens are cached using :class:`cachetools.LRUCache`.
+ """
+ super(OnDemandCredentials, self).__init__()
+ self._signer = signer
+ self._issuer = issuer
+ self._subject = subject
+ self._token_lifetime = token_lifetime
+
+ if additional_claims is None:
+ additional_claims = {}
+
+ self._additional_claims = additional_claims
+ self._cache = cachetools.LRUCache(maxsize=max_cache_size)
+
+ @classmethod
+ def _from_signer_and_info(cls, signer, info, **kwargs):
+ """Creates an OnDemandCredentials instance from a signer and service
+ account info.
+
+ Args:
+ signer (google.auth.crypt.Signer): The signer used to sign JWTs.
+ info (Mapping[str, str]): The service account info.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ google.auth.jwt.OnDemandCredentials: The constructed credentials.
+
+ Raises:
+ ValueError: If the info is not in the expected format.
+ """
+ kwargs.setdefault('subject', info['client_email'])
+ kwargs.setdefault('issuer', info['client_email'])
+ return cls(signer, **kwargs)
+
+ @classmethod
+ def from_service_account_info(cls, info, **kwargs):
+ """Creates an OnDemandCredentials instance from a dictionary.
+
+ Args:
+ info (Mapping[str, str]): The service account info in Google
+ format.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ google.auth.jwt.OnDemandCredentials: The constructed credentials.
+
+ Raises:
+ ValueError: If the info is not in the expected format.
+ """
+ signer = _service_account_info.from_dict(
+ info, require=['client_email'])
+ return cls._from_signer_and_info(signer, info, **kwargs)
+
+ @classmethod
+ def from_service_account_file(cls, filename, **kwargs):
+ """Creates an OnDemandCredentials instance from a service account .json
+ file in Google format.
+
+ Args:
+ filename (str): The path to the service account .json file.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ google.auth.jwt.OnDemandCredentials: The constructed credentials.
+ """
+ info, signer = _service_account_info.from_filename(
+ filename, require=['client_email'])
+ return cls._from_signer_and_info(signer, info, **kwargs)
+
+ @classmethod
+ def from_signing_credentials(cls, credentials, **kwargs):
+ """Creates a new :class:`google.auth.jwt.OnDemandCredentials` instance
+ from an existing :class:`google.auth.credentials.Signing` instance.
+
+ The new instance will use the same signer as the existing instance and
+ will use the existing instance's signer email as the issuer and
+ subject by default.
+
+ Example::
+
+ svc_creds = service_account.Credentials.from_service_account_file(
+ 'service_account.json')
+ jwt_creds = jwt.OnDemandCredentials.from_signing_credentials(
+ svc_creds)
+
+ Args:
+ credentials (google.auth.credentials.Signing): The credentials to
+ use to construct the new credentials.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ google.auth.jwt.Credentials: A new Credentials instance.
+ """
+ kwargs.setdefault('issuer', credentials.signer_email)
+ kwargs.setdefault('subject', credentials.signer_email)
+ return cls(credentials.signer, **kwargs)
+
+ def with_claims(self, issuer=None, subject=None, additional_claims=None):
+ """Returns a copy of these credentials with modified claims.
+
+ Args:
+ issuer (str): The `iss` claim. If unspecified the current issuer
+ claim will be used.
+ subject (str): The `sub` claim. If unspecified the current subject
+ claim will be used.
+ additional_claims (Mapping[str, str]): Any additional claims for
+ the JWT payload. This will be merged with the current
+ additional claims.
+
+ Returns:
+ google.auth.jwt.OnDemandCredentials: A new credentials instance.
+ """
+ new_additional_claims = copy.deepcopy(self._additional_claims)
+ new_additional_claims.update(additional_claims or {})
+
+ return OnDemandCredentials(
+ self._signer,
+ issuer=issuer if issuer is not None else self._issuer,
+ subject=subject if subject is not None else self._subject,
+ additional_claims=new_additional_claims,
+ max_cache_size=self._cache.maxsize)
+
+ @property
+ def valid(self):
+ """Checks the validity of the credentials.
+
+ These credentials are always valid because it generates tokens on
+ demand.
+ """
+ return True
+
+ def _make_jwt_for_audience(self, audience):
+ """Make a new JWT for the given audience.
+
+ Args:
+ audience (str): The intended audience.
+
+ Returns:
+ Tuple[bytes, datetime]: The encoded JWT and the expiration.
+ """
+ now = _helpers.utcnow()
+ lifetime = datetime.timedelta(seconds=self._token_lifetime)
+ expiry = now + lifetime
+
+ payload = {
+ 'iss': self._issuer,
+ 'sub': self._subject,
+ 'iat': _helpers.datetime_to_secs(now),
+ 'exp': _helpers.datetime_to_secs(expiry),
+ 'aud': audience,
+ }
+
+ payload.update(self._additional_claims)
+
+ jwt = encode(self._signer, payload)
+
+ return jwt, expiry
+
+ def _get_jwt_for_audience(self, audience):
+ """Get a JWT For a given audience.
+
+ If there is already an existing, non-expired token in the cache for
+ the audience, that token is used. Otherwise, a new token will be
+ created.
+
+ Args:
+ audience (str): The intended audience.
+
+ Returns:
+ bytes: The encoded JWT.
+ """
+ token, expiry = self._cache.get(audience, (None, None))
+
+ if token is None or expiry < _helpers.utcnow():
+ token, expiry = self._make_jwt_for_audience(audience)
+ self._cache[audience] = token, expiry
+
+ return token
+
+ def refresh(self, request):
+ """Raises an exception, these credentials can not be directly
+ refreshed.
+
+ Args:
+ request (Any): Unused.
+
+ Raises:
+ google.auth.RefreshError
+ """
+ # pylint: disable=unused-argument
+ # (pylint doesn't correctly recognize overridden methods.)
+ raise exceptions.RefreshError(
+ 'OnDemandCredentials can not be directly refreshed.')
+
+ def before_request(self, request, method, url, headers):
+ """Performs credential-specific before request logic.
+
+ Args:
+ request (Any): Unused. JWT credentials do not need to make an
+ HTTP request to refresh.
+ method (str): The request's HTTP method.
+ url (str): The request's URI. This is used as the audience claim
+ when generating the JWT.
+ headers (Mapping): The request's headers.
+ """
+ # pylint: disable=unused-argument
+ # (pylint doesn't correctly recognize overridden methods.)
+ parts = urllib.parse.urlsplit(url)
+ # Strip query string and fragment
+ audience = urllib.parse.urlunsplit(
+ (parts.scheme, parts.netloc, parts.path, None, None))
+ token = self._get_jwt_for_audience(audience)
+ self.apply(headers, token=token)
+
+ @_helpers.copy_docstring(google.auth.credentials.Signing)
+ def sign_bytes(self, message):
+ return self._signer.sign(message)
+
+ @property
+ @_helpers.copy_docstring(google.auth.credentials.Signing)
+ def signer_email(self):
+ return self._issuer
+
+ @property
+ @_helpers.copy_docstring(google.auth.credentials.Signing)
+ def signer(self):
+ return self._signer
diff --git a/setup.py b/setup.py
index aaa13de..bad634a 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,7 @@
'pyasn1-modules>=0.0.5',
'rsa>=3.1.4',
'six>=1.9.0',
+ 'cachetools>=2.0.0',
)
diff --git a/system_tests/test_grpc.py b/system_tests/test_grpc.py
index 4bf1c5b..365bc91 100644
--- a/system_tests/test_grpc.py
+++ b/system_tests/test_grpc.py
@@ -39,7 +39,7 @@
list(list_topics_iter)
-def test_grpc_request_with_jwt_credentials(http_request):
+def test_grpc_request_with_jwt_credentials():
credentials, project_id = google.auth.default()
audience = 'https://{}/google.pubsub.v1.Publisher'.format(
publisher_client.PublisherClient.SERVICE_ADDRESS)
@@ -49,7 +49,27 @@
channel = google.auth.transport.grpc.secure_authorized_channel(
credentials,
- http_request,
+ None,
+ publisher_client.PublisherClient.SERVICE_ADDRESS)
+
+ # Create a pub/sub client.
+ client = publisher_client.PublisherClient(channel=channel)
+
+ # list the topics and drain the iterator to test that an authorized API
+ # call works.
+ list_topics_iter = client.list_topics(
+ project='projects/{}'.format(project_id))
+ list(list_topics_iter)
+
+
+def test_grpc_request_with_on_demand_jwt_credentials():
+ credentials, project_id = google.auth.default()
+ credentials = google.auth.jwt.OnDemandCredentials.from_signing_credentials(
+ credentials)
+
+ channel = google.auth.transport.grpc.secure_authorized_channel(
+ credentials,
+ None,
publisher_client.PublisherClient.SERVICE_ADDRESS)
# Create a pub/sub client.
diff --git a/tests/test_jwt.py b/tests/test_jwt.py
index 59769de..22c5bc5 100644
--- a/tests/test_jwt.py
+++ b/tests/test_jwt.py
@@ -22,6 +22,7 @@
from google.auth import _helpers
from google.auth import crypt
+from google.auth import exceptions
from google.auth import jwt
@@ -196,7 +197,7 @@
assert payload['user'] == 'billy bob'
-class TestCredentials:
+class TestCredentials(object):
SERVICE_ACCOUNT_EMAIL = 'service-account@example.com'
SUBJECT = 'subject'
AUDIENCE = 'audience'
@@ -343,3 +344,135 @@
self.credentials.before_request(
None, 'GET', 'http://example.com?a=1#3', {})
assert self.credentials.valid
+
+
+class TestOnDemandCredentials(object):
+ SERVICE_ACCOUNT_EMAIL = 'service-account@example.com'
+ SUBJECT = 'subject'
+ ADDITIONAL_CLAIMS = {'meta': 'data'}
+ credentials = None
+
+ @pytest.fixture(autouse=True)
+ def credentials_fixture(self, signer):
+ self.credentials = jwt.OnDemandCredentials(
+ signer, self.SERVICE_ACCOUNT_EMAIL, self.SERVICE_ACCOUNT_EMAIL,
+ max_cache_size=2)
+
+ def test_from_service_account_info(self):
+ with open(SERVICE_ACCOUNT_JSON_FILE, 'r') as fh:
+ info = json.load(fh)
+
+ credentials = jwt.OnDemandCredentials.from_service_account_info(info)
+
+ assert credentials._signer.key_id == info['private_key_id']
+ assert credentials._issuer == info['client_email']
+ assert credentials._subject == info['client_email']
+
+ def test_from_service_account_info_args(self):
+ info = SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt.OnDemandCredentials.from_service_account_info(
+ info, subject=self.SUBJECT,
+ additional_claims=self.ADDITIONAL_CLAIMS)
+
+ assert credentials._signer.key_id == info['private_key_id']
+ assert credentials._issuer == info['client_email']
+ assert credentials._subject == self.SUBJECT
+ assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+ def test_from_service_account_file(self):
+ info = SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt.OnDemandCredentials.from_service_account_file(
+ SERVICE_ACCOUNT_JSON_FILE)
+
+ assert credentials._signer.key_id == info['private_key_id']
+ assert credentials._issuer == info['client_email']
+ assert credentials._subject == info['client_email']
+
+ def test_from_service_account_file_args(self):
+ info = SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt.OnDemandCredentials.from_service_account_file(
+ SERVICE_ACCOUNT_JSON_FILE, subject=self.SUBJECT,
+ additional_claims=self.ADDITIONAL_CLAIMS)
+
+ assert credentials._signer.key_id == info['private_key_id']
+ assert credentials._issuer == info['client_email']
+ assert credentials._subject == self.SUBJECT
+ assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+ def test_from_signing_credentials(self):
+ jwt_from_signing = self.credentials.from_signing_credentials(
+ self.credentials)
+ jwt_from_info = jwt.OnDemandCredentials.from_service_account_info(
+ SERVICE_ACCOUNT_INFO)
+
+ assert isinstance(jwt_from_signing, jwt.OnDemandCredentials)
+ assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id
+ assert jwt_from_signing._issuer == jwt_from_info._issuer
+ assert jwt_from_signing._subject == jwt_from_info._subject
+
+ def test_default_state(self):
+ # Credentials are *always* valid.
+ assert self.credentials.valid
+ # Credentials *never* expire.
+ assert not self.credentials.expired
+
+ def test_with_claims(self):
+ new_claims = {'meep': 'moop'}
+ new_credentials = self.credentials.with_claims(
+ additional_claims=new_claims)
+
+ assert new_credentials._signer == self.credentials._signer
+ assert new_credentials._issuer == self.credentials._issuer
+ assert new_credentials._subject == self.credentials._subject
+ assert new_credentials._additional_claims == new_claims
+
+ def test_sign_bytes(self):
+ to_sign = b'123'
+ signature = self.credentials.sign_bytes(to_sign)
+ assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES)
+
+ def test_signer(self):
+ assert isinstance(self.credentials.signer, crypt.RSASigner)
+
+ def test_signer_email(self):
+ assert (self.credentials.signer_email ==
+ SERVICE_ACCOUNT_INFO['client_email'])
+
+ def _verify_token(self, token):
+ payload = jwt.decode(token, PUBLIC_CERT_BYTES)
+ assert payload['iss'] == self.SERVICE_ACCOUNT_EMAIL
+ return payload
+
+ def test_refresh(self):
+ with pytest.raises(exceptions.RefreshError):
+ self.credentials.refresh(None)
+
+ def test_before_request(self):
+ headers = {}
+
+ self.credentials.before_request(
+ None, 'GET', 'http://example.com?a=1#3', headers)
+
+ _, token = headers['authorization'].split(' ')
+ payload = self._verify_token(token)
+
+ assert payload['aud'] == 'http://example.com'
+
+ # Making another request should re-use the same token.
+ self.credentials.before_request(
+ None, 'GET', 'http://example.com?b=2', headers)
+
+ _, new_token = headers['authorization'].split(' ')
+
+ assert new_token == token
+
+ def test_expired_token(self):
+ self.credentials._cache['audience'] = (
+ mock.sentinel.token, datetime.datetime.min)
+
+ token = self.credentials._get_jwt_for_audience('audience')
+
+ assert token != mock.sentinel.token