Add jwt.OnDemandCredentials (#142)

diff --git a/google/auth/jwt.py b/google/auth/jwt.py
index 412f122..b1eb5fb 100644
--- a/google/auth/jwt.py
+++ b/google/auth/jwt.py
@@ -46,13 +46,17 @@
 import datetime
 import json
 
+import cachetools
+from six.moves import urllib
+
 from google.auth import _helpers
 from google.auth import _service_account_info
 from google.auth import crypt
+from google.auth import exceptions
 import google.auth.credentials
 
-
 _DEFAULT_TOKEN_LIFETIME_SECS = 3600  # 1 hour in seconds
+_DEFAULT_MAX_CACHE_SIZE = 10
 
 
 def encode(signer, payload, header=None, key_id=None):
@@ -316,10 +320,10 @@
         self._audience = audience
         self._token_lifetime = token_lifetime
 
-        if additional_claims is not None:
-            self._additional_claims = additional_claims
-        else:
-            self._additional_claims = {}
+        if additional_claims is None:
+            additional_claims = {}
+
+        self._additional_claims = additional_claims
 
     @classmethod
     def _from_signer_and_info(cls, signer, info, **kwargs):
@@ -343,8 +347,7 @@
 
     @classmethod
     def from_service_account_info(cls, info, **kwargs):
-        """Creates a Credentials instance from a dictionary containing service
-        account info in Google format.
+        """Creates an Credentials instance from a dictionary.
 
         Args:
             info (Mapping[str, str]): The service account info in Google
@@ -487,3 +490,266 @@
     @_helpers.copy_docstring(google.auth.credentials.Signing)
     def signer(self):
         return self._signer
+
+
+class OnDemandCredentials(
+        google.auth.credentials.Signing,
+        google.auth.credentials.Credentials):
+    """On-demand JWT credentials.
+
+    Like :class:`Credentials`, this class uses a JWT as the bearer token for
+    authentication. However, this class does not require the audience at
+    construction time. Instead, it will generate a new token on-demand for
+    each request using the request URI as the audience. It caches tokens
+    so that multiple requests to the same URI do not incur the overhead
+    of generating a new token every time.
+
+    This behavior is especially useful for `gRPC`_ clients. A gRPC service may
+    have multiple audience and gRPC clients may not know all of the audiences
+    required for accessing a particular service. With these credentials,
+    no knowledge of the audiences is required ahead of time.
+
+    .. _grpc: http://www.grpc.io/
+    """
+
+    def __init__(self, signer, issuer, subject,
+                 additional_claims=None,
+                 token_lifetime=_DEFAULT_TOKEN_LIFETIME_SECS,
+                 max_cache_size=_DEFAULT_MAX_CACHE_SIZE):
+        """
+        Args:
+            signer (google.auth.crypt.Signer): The signer used to sign JWTs.
+            issuer (str): The `iss` claim.
+            subject (str): The `sub` claim.
+            additional_claims (Mapping[str, str]): Any additional claims for
+                the JWT payload.
+            token_lifetime (int): The amount of time in seconds for
+                which the token is valid. Defaults to 1 hour.
+            max_cache_size (int): The maximum number of JWT tokens to keep in
+                cache. Tokens are cached using :class:`cachetools.LRUCache`.
+        """
+        super(OnDemandCredentials, self).__init__()
+        self._signer = signer
+        self._issuer = issuer
+        self._subject = subject
+        self._token_lifetime = token_lifetime
+
+        if additional_claims is None:
+            additional_claims = {}
+
+        self._additional_claims = additional_claims
+        self._cache = cachetools.LRUCache(maxsize=max_cache_size)
+
+    @classmethod
+    def _from_signer_and_info(cls, signer, info, **kwargs):
+        """Creates an OnDemandCredentials instance from a signer and service
+        account info.
+
+        Args:
+            signer (google.auth.crypt.Signer): The signer used to sign JWTs.
+            info (Mapping[str, str]): The service account info.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            google.auth.jwt.OnDemandCredentials: The constructed credentials.
+
+        Raises:
+            ValueError: If the info is not in the expected format.
+        """
+        kwargs.setdefault('subject', info['client_email'])
+        kwargs.setdefault('issuer', info['client_email'])
+        return cls(signer, **kwargs)
+
+    @classmethod
+    def from_service_account_info(cls, info, **kwargs):
+        """Creates an OnDemandCredentials instance from a dictionary.
+
+        Args:
+            info (Mapping[str, str]): The service account info in Google
+                format.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            google.auth.jwt.OnDemandCredentials: The constructed credentials.
+
+        Raises:
+            ValueError: If the info is not in the expected format.
+        """
+        signer = _service_account_info.from_dict(
+            info, require=['client_email'])
+        return cls._from_signer_and_info(signer, info, **kwargs)
+
+    @classmethod
+    def from_service_account_file(cls, filename, **kwargs):
+        """Creates an OnDemandCredentials instance from a service account .json
+        file in Google format.
+
+        Args:
+            filename (str): The path to the service account .json file.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            google.auth.jwt.OnDemandCredentials: The constructed credentials.
+        """
+        info, signer = _service_account_info.from_filename(
+            filename, require=['client_email'])
+        return cls._from_signer_and_info(signer, info, **kwargs)
+
+    @classmethod
+    def from_signing_credentials(cls, credentials, **kwargs):
+        """Creates a new :class:`google.auth.jwt.OnDemandCredentials` instance
+        from an existing :class:`google.auth.credentials.Signing` instance.
+
+        The new instance will use the same signer as the existing instance and
+        will use the existing instance's signer email as the issuer and
+        subject by default.
+
+        Example::
+
+            svc_creds = service_account.Credentials.from_service_account_file(
+                'service_account.json')
+            jwt_creds = jwt.OnDemandCredentials.from_signing_credentials(
+                svc_creds)
+
+        Args:
+            credentials (google.auth.credentials.Signing): The credentials to
+                use to construct the new credentials.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            google.auth.jwt.Credentials: A new Credentials instance.
+        """
+        kwargs.setdefault('issuer', credentials.signer_email)
+        kwargs.setdefault('subject', credentials.signer_email)
+        return cls(credentials.signer, **kwargs)
+
+    def with_claims(self, issuer=None, subject=None, additional_claims=None):
+        """Returns a copy of these credentials with modified claims.
+
+        Args:
+            issuer (str): The `iss` claim. If unspecified the current issuer
+                claim will be used.
+            subject (str): The `sub` claim. If unspecified the current subject
+                claim will be used.
+            additional_claims (Mapping[str, str]): Any additional claims for
+                the JWT payload. This will be merged with the current
+                additional claims.
+
+        Returns:
+            google.auth.jwt.OnDemandCredentials: A new credentials instance.
+        """
+        new_additional_claims = copy.deepcopy(self._additional_claims)
+        new_additional_claims.update(additional_claims or {})
+
+        return OnDemandCredentials(
+            self._signer,
+            issuer=issuer if issuer is not None else self._issuer,
+            subject=subject if subject is not None else self._subject,
+            additional_claims=new_additional_claims,
+            max_cache_size=self._cache.maxsize)
+
+    @property
+    def valid(self):
+        """Checks the validity of the credentials.
+
+        These credentials are always valid because it generates tokens on
+        demand.
+        """
+        return True
+
+    def _make_jwt_for_audience(self, audience):
+        """Make a new JWT for the given audience.
+
+        Args:
+            audience (str): The intended audience.
+
+        Returns:
+            Tuple[bytes, datetime]: The encoded JWT and the expiration.
+        """
+        now = _helpers.utcnow()
+        lifetime = datetime.timedelta(seconds=self._token_lifetime)
+        expiry = now + lifetime
+
+        payload = {
+            'iss': self._issuer,
+            'sub': self._subject,
+            'iat': _helpers.datetime_to_secs(now),
+            'exp': _helpers.datetime_to_secs(expiry),
+            'aud': audience,
+        }
+
+        payload.update(self._additional_claims)
+
+        jwt = encode(self._signer, payload)
+
+        return jwt, expiry
+
+    def _get_jwt_for_audience(self, audience):
+        """Get a JWT For a given audience.
+
+        If there is already an existing, non-expired token in the cache for
+        the audience, that token is used. Otherwise, a new token will be
+        created.
+
+        Args:
+            audience (str): The intended audience.
+
+        Returns:
+            bytes: The encoded JWT.
+        """
+        token, expiry = self._cache.get(audience, (None, None))
+
+        if token is None or expiry < _helpers.utcnow():
+            token, expiry = self._make_jwt_for_audience(audience)
+            self._cache[audience] = token, expiry
+
+        return token
+
+    def refresh(self, request):
+        """Raises an exception, these credentials can not be directly
+        refreshed.
+
+        Args:
+            request (Any): Unused.
+
+        Raises:
+            google.auth.RefreshError
+        """
+        # pylint: disable=unused-argument
+        # (pylint doesn't correctly recognize overridden methods.)
+        raise exceptions.RefreshError(
+            'OnDemandCredentials can not be directly refreshed.')
+
+    def before_request(self, request, method, url, headers):
+        """Performs credential-specific before request logic.
+
+        Args:
+            request (Any): Unused. JWT credentials do not need to make an
+                HTTP request to refresh.
+            method (str): The request's HTTP method.
+            url (str): The request's URI. This is used as the audience claim
+                when generating the JWT.
+            headers (Mapping): The request's headers.
+        """
+        # pylint: disable=unused-argument
+        # (pylint doesn't correctly recognize overridden methods.)
+        parts = urllib.parse.urlsplit(url)
+        # Strip query string and fragment
+        audience = urllib.parse.urlunsplit(
+            (parts.scheme, parts.netloc, parts.path, None, None))
+        token = self._get_jwt_for_audience(audience)
+        self.apply(headers, token=token)
+
+    @_helpers.copy_docstring(google.auth.credentials.Signing)
+    def sign_bytes(self, message):
+        return self._signer.sign(message)
+
+    @property
+    @_helpers.copy_docstring(google.auth.credentials.Signing)
+    def signer_email(self):
+        return self._issuer
+
+    @property
+    @_helpers.copy_docstring(google.auth.credentials.Signing)
+    def signer(self):
+        return self._signer
diff --git a/setup.py b/setup.py
index aaa13de..bad634a 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,7 @@
     'pyasn1-modules>=0.0.5',
     'rsa>=3.1.4',
     'six>=1.9.0',
+    'cachetools>=2.0.0',
 )
 
 
diff --git a/system_tests/test_grpc.py b/system_tests/test_grpc.py
index 4bf1c5b..365bc91 100644
--- a/system_tests/test_grpc.py
+++ b/system_tests/test_grpc.py
@@ -39,7 +39,7 @@
     list(list_topics_iter)
 
 
-def test_grpc_request_with_jwt_credentials(http_request):
+def test_grpc_request_with_jwt_credentials():
     credentials, project_id = google.auth.default()
     audience = 'https://{}/google.pubsub.v1.Publisher'.format(
         publisher_client.PublisherClient.SERVICE_ADDRESS)
@@ -49,7 +49,27 @@
 
     channel = google.auth.transport.grpc.secure_authorized_channel(
         credentials,
-        http_request,
+        None,
+        publisher_client.PublisherClient.SERVICE_ADDRESS)
+
+    # Create a pub/sub client.
+    client = publisher_client.PublisherClient(channel=channel)
+
+    # list the topics and drain the iterator to test that an authorized API
+    # call works.
+    list_topics_iter = client.list_topics(
+        project='projects/{}'.format(project_id))
+    list(list_topics_iter)
+
+
+def test_grpc_request_with_on_demand_jwt_credentials():
+    credentials, project_id = google.auth.default()
+    credentials = google.auth.jwt.OnDemandCredentials.from_signing_credentials(
+        credentials)
+
+    channel = google.auth.transport.grpc.secure_authorized_channel(
+        credentials,
+        None,
         publisher_client.PublisherClient.SERVICE_ADDRESS)
 
     # Create a pub/sub client.
diff --git a/tests/test_jwt.py b/tests/test_jwt.py
index 59769de..22c5bc5 100644
--- a/tests/test_jwt.py
+++ b/tests/test_jwt.py
@@ -22,6 +22,7 @@
 
 from google.auth import _helpers
 from google.auth import crypt
+from google.auth import exceptions
 from google.auth import jwt
 
 
@@ -196,7 +197,7 @@
     assert payload['user'] == 'billy bob'
 
 
-class TestCredentials:
+class TestCredentials(object):
     SERVICE_ACCOUNT_EMAIL = 'service-account@example.com'
     SUBJECT = 'subject'
     AUDIENCE = 'audience'
@@ -343,3 +344,135 @@
         self.credentials.before_request(
             None, 'GET', 'http://example.com?a=1#3', {})
         assert self.credentials.valid
+
+
+class TestOnDemandCredentials(object):
+    SERVICE_ACCOUNT_EMAIL = 'service-account@example.com'
+    SUBJECT = 'subject'
+    ADDITIONAL_CLAIMS = {'meta': 'data'}
+    credentials = None
+
+    @pytest.fixture(autouse=True)
+    def credentials_fixture(self, signer):
+        self.credentials = jwt.OnDemandCredentials(
+            signer, self.SERVICE_ACCOUNT_EMAIL, self.SERVICE_ACCOUNT_EMAIL,
+            max_cache_size=2)
+
+    def test_from_service_account_info(self):
+        with open(SERVICE_ACCOUNT_JSON_FILE, 'r') as fh:
+            info = json.load(fh)
+
+        credentials = jwt.OnDemandCredentials.from_service_account_info(info)
+
+        assert credentials._signer.key_id == info['private_key_id']
+        assert credentials._issuer == info['client_email']
+        assert credentials._subject == info['client_email']
+
+    def test_from_service_account_info_args(self):
+        info = SERVICE_ACCOUNT_INFO.copy()
+
+        credentials = jwt.OnDemandCredentials.from_service_account_info(
+            info, subject=self.SUBJECT,
+            additional_claims=self.ADDITIONAL_CLAIMS)
+
+        assert credentials._signer.key_id == info['private_key_id']
+        assert credentials._issuer == info['client_email']
+        assert credentials._subject == self.SUBJECT
+        assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+    def test_from_service_account_file(self):
+        info = SERVICE_ACCOUNT_INFO.copy()
+
+        credentials = jwt.OnDemandCredentials.from_service_account_file(
+            SERVICE_ACCOUNT_JSON_FILE)
+
+        assert credentials._signer.key_id == info['private_key_id']
+        assert credentials._issuer == info['client_email']
+        assert credentials._subject == info['client_email']
+
+    def test_from_service_account_file_args(self):
+        info = SERVICE_ACCOUNT_INFO.copy()
+
+        credentials = jwt.OnDemandCredentials.from_service_account_file(
+            SERVICE_ACCOUNT_JSON_FILE, subject=self.SUBJECT,
+            additional_claims=self.ADDITIONAL_CLAIMS)
+
+        assert credentials._signer.key_id == info['private_key_id']
+        assert credentials._issuer == info['client_email']
+        assert credentials._subject == self.SUBJECT
+        assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+    def test_from_signing_credentials(self):
+        jwt_from_signing = self.credentials.from_signing_credentials(
+            self.credentials)
+        jwt_from_info = jwt.OnDemandCredentials.from_service_account_info(
+            SERVICE_ACCOUNT_INFO)
+
+        assert isinstance(jwt_from_signing, jwt.OnDemandCredentials)
+        assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id
+        assert jwt_from_signing._issuer == jwt_from_info._issuer
+        assert jwt_from_signing._subject == jwt_from_info._subject
+
+    def test_default_state(self):
+        # Credentials are *always* valid.
+        assert self.credentials.valid
+        # Credentials *never* expire.
+        assert not self.credentials.expired
+
+    def test_with_claims(self):
+        new_claims = {'meep': 'moop'}
+        new_credentials = self.credentials.with_claims(
+            additional_claims=new_claims)
+
+        assert new_credentials._signer == self.credentials._signer
+        assert new_credentials._issuer == self.credentials._issuer
+        assert new_credentials._subject == self.credentials._subject
+        assert new_credentials._additional_claims == new_claims
+
+    def test_sign_bytes(self):
+        to_sign = b'123'
+        signature = self.credentials.sign_bytes(to_sign)
+        assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES)
+
+    def test_signer(self):
+        assert isinstance(self.credentials.signer, crypt.RSASigner)
+
+    def test_signer_email(self):
+        assert (self.credentials.signer_email ==
+                SERVICE_ACCOUNT_INFO['client_email'])
+
+    def _verify_token(self, token):
+        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
+        assert payload['iss'] == self.SERVICE_ACCOUNT_EMAIL
+        return payload
+
+    def test_refresh(self):
+        with pytest.raises(exceptions.RefreshError):
+            self.credentials.refresh(None)
+
+    def test_before_request(self):
+        headers = {}
+
+        self.credentials.before_request(
+            None, 'GET', 'http://example.com?a=1#3', headers)
+
+        _, token = headers['authorization'].split(' ')
+        payload = self._verify_token(token)
+
+        assert payload['aud'] == 'http://example.com'
+
+        # Making another request should re-use the same token.
+        self.credentials.before_request(
+            None, 'GET', 'http://example.com?b=2', headers)
+
+        _, new_token = headers['authorization'].split(' ')
+
+        assert new_token == token
+
+    def test_expired_token(self):
+        self.credentials._cache['audience'] = (
+            mock.sentinel.token, datetime.datetime.min)
+
+        token = self.credentials._get_jwt_for_audience('audience')
+
+        assert token != mock.sentinel.token