Add compute engine-based IDTokenCredentials (#236)


diff --git a/google/auth/compute_engine/__init__.py b/google/auth/compute_engine/__init__.py
index 3794be2..ca31b46 100644
--- a/google/auth/compute_engine/__init__.py
+++ b/google/auth/compute_engine/__init__.py
@@ -15,8 +15,10 @@
 """Google Compute Engine authentication."""
 
 from google.auth.compute_engine.credentials import Credentials
+from google.auth.compute_engine.credentials import IDTokenCredentials
 
 
 __all__ = [
-    'Credentials'
+    'Credentials',
+    'IDTokenCredentials',
 ]
diff --git a/google/auth/compute_engine/credentials.py b/google/auth/compute_engine/credentials.py
index 3841df2..d9c6e26 100644
--- a/google/auth/compute_engine/credentials.py
+++ b/google/auth/compute_engine/credentials.py
@@ -19,11 +19,17 @@
 
 """
 
+import datetime
+
 import six
 
+from google.auth import _helpers
 from google.auth import credentials
 from google.auth import exceptions
+from google.auth import iam
+from google.auth import jwt
 from google.auth.compute_engine import _metadata
+from google.oauth2 import _client
 
 
 class Credentials(credentials.ReadOnlyScoped, credentials.Credentials):
@@ -108,3 +114,126 @@
     def requires_scopes(self):
         """False: Compute Engine credentials can not be scoped."""
         return False
+
+
+_DEFAULT_TOKEN_LIFETIME_SECS = 3600  # 1 hour in seconds
+_DEFAULT_TOKEN_URI = 'https://www.googleapis.com/oauth2/v4/token'
+
+
+class IDTokenCredentials(credentials.Credentials, credentials.Signing):
+    """Open ID Connect ID Token-based service account credentials.
+
+    These credentials relies on the default service account of a GCE instance.
+
+    In order for this to work, the GCE instance must have been started with
+    a service account that has access to the IAM Cloud API.
+    """
+    def __init__(self, request, target_audience,
+                 token_uri=_DEFAULT_TOKEN_URI,
+                 additional_claims=None,
+                 service_account_email=None):
+        """
+        Args:
+            request (google.auth.transport.Request): The object used to make
+                HTTP requests.
+            target_audience (str): The intended audience for these credentials,
+                used when requesting the ID Token. The ID Token's ``aud`` claim
+                will be set to this string.
+            token_uri (str): The OAuth 2.0 Token URI.
+            additional_claims (Mapping[str, str]): Any additional claims for
+                the JWT assertion used in the authorization grant.
+            service_account_email (str): Optional explicit service account to
+                use to sign JWT tokens.
+                By default, this is the default GCE service account.
+        """
+        super(IDTokenCredentials, self).__init__()
+
+        if service_account_email is None:
+            sa_info = _metadata.get_service_account_info(request)
+            service_account_email = sa_info['email']
+        self._service_account_email = service_account_email
+
+        self._signer = iam.Signer(
+            request=request,
+            credentials=Credentials(),
+            service_account_email=service_account_email)
+
+        self._token_uri = token_uri
+        self._target_audience = target_audience
+
+        if additional_claims is not None:
+            self._additional_claims = additional_claims
+        else:
+            self._additional_claims = {}
+
+    def with_target_audience(self, target_audience):
+        """Create a copy of these credentials with the specified target
+        audience.
+        Args:
+            target_audience (str): The intended audience for these credentials,
+            used when requesting the ID Token.
+        Returns:
+            google.auth.service_account.IDTokenCredentials: A new credentials
+                instance.
+        """
+        return self.__class__(
+            self._signer,
+            service_account_email=self._service_account_email,
+            token_uri=self._token_uri,
+            target_audience=target_audience,
+            additional_claims=self._additional_claims.copy())
+
+    def _make_authorization_grant_assertion(self):
+        """Create the OAuth 2.0 assertion.
+        This assertion is used during the OAuth 2.0 grant to acquire an
+        ID token.
+        Returns:
+            bytes: The authorization grant assertion.
+        """
+        now = _helpers.utcnow()
+        lifetime = datetime.timedelta(seconds=_DEFAULT_TOKEN_LIFETIME_SECS)
+        expiry = now + lifetime
+
+        payload = {
+            'iat': _helpers.datetime_to_secs(now),
+            'exp': _helpers.datetime_to_secs(expiry),
+            # The issuer must be the service account email.
+            'iss': self.service_account_email,
+            # The audience must be the auth token endpoint's URI
+            'aud': self._token_uri,
+            # The target audience specifies which service the ID token is
+            # intended for.
+            'target_audience': self._target_audience
+        }
+
+        payload.update(self._additional_claims)
+
+        token = jwt.encode(self._signer, payload)
+
+        return token
+
+    @_helpers.copy_docstring(credentials.Credentials)
+    def refresh(self, request):
+        assertion = self._make_authorization_grant_assertion()
+        access_token, expiry, _ = _client.id_token_jwt_grant(
+            request, self._token_uri, assertion)
+        self.token = access_token
+        self.expiry = expiry
+
+    @property
+    @_helpers.copy_docstring(credentials.Signing)
+    def signer(self):
+        return self._signer
+
+    @_helpers.copy_docstring(credentials.Signing)
+    def sign_bytes(self, message):
+        return self._signer.sign(message)
+
+    @property
+    def service_account_email(self):
+        """The service account email."""
+        return self._service_account_email
+
+    @property
+    def signer_email(self):
+        return self._service_account_email
diff --git a/tests/compute_engine/test_credentials.py b/tests/compute_engine/test_credentials.py
index ae2597d..ee415db 100644
--- a/tests/compute_engine/test_credentials.py
+++ b/tests/compute_engine/test_credentials.py
@@ -19,6 +19,7 @@
 
 from google.auth import _helpers
 from google.auth import exceptions
+from google.auth import jwt
 from google.auth import transport
 from google.auth.compute_engine import credentials
 
@@ -105,3 +106,278 @@
 
         # Credentials should now be valid.
         assert self.credentials.valid
+
+
+class TestIDTokenCredentials(object):
+    credentials = None
+
+    @mock.patch('google.auth.compute_engine._metadata.get', autospec=True)
+    def test_default_state(self, get):
+        get.side_effect = [{
+            'email': 'service-account@example.com',
+            'scope': ['one', 'two'],
+        }]
+
+        request = mock.create_autospec(transport.Request, instance=True)
+        self.credentials = credentials.IDTokenCredentials(
+            request=request, target_audience="https://example.com")
+
+        assert not self.credentials.valid
+        # Expiration hasn't been set yet
+        assert not self.credentials.expired
+        # Service account email hasn't been populated
+        assert (self.credentials.service_account_email
+                == 'service-account@example.com')
+        # Signer is initialized
+        assert self.credentials.signer
+        assert self.credentials.signer_email == 'service-account@example.com'
+
+    @mock.patch(
+        'google.auth._helpers.utcnow',
+        return_value=datetime.datetime.utcfromtimestamp(0))
+    @mock.patch('google.auth.compute_engine._metadata.get', autospec=True)
+    @mock.patch('google.auth.iam.Signer.sign', autospec=True)
+    def test_make_authorization_grant_assertion(self, sign, get, utcnow):
+        get.side_effect = [{
+            'email': 'service-account@example.com',
+            'scopes': ['one', 'two']
+        }]
+        sign.side_effect = [b'signature']
+
+        request = mock.create_autospec(transport.Request, instance=True)
+        self.credentials = credentials.IDTokenCredentials(
+            request=request, target_audience="https://audience.com")
+
+        # Generate authorization grant:
+        token = self.credentials._make_authorization_grant_assertion()
+        payload = jwt.decode(token, verify=False)
+
+        # The JWT token signature is 'signature' encoded in base 64:
+        assert token.endswith(b'.c2lnbmF0dXJl')
+
+        # Check that the credentials have the token and proper expiration
+        assert payload == {
+            'aud': 'https://www.googleapis.com/oauth2/v4/token',
+            'exp': 3600,
+            'iat': 0,
+            'iss': 'service-account@example.com',
+            'target_audience': 'https://audience.com'}
+
+    @mock.patch(
+        'google.auth._helpers.utcnow',
+        return_value=datetime.datetime.utcfromtimestamp(0))
+    @mock.patch('google.auth.compute_engine._metadata.get', autospec=True)
+    @mock.patch('google.auth.iam.Signer.sign', autospec=True)
+    def test_with_service_account(self, sign, get, utcnow):
+        sign.side_effect = [b'signature']
+
+        request = mock.create_autospec(transport.Request, instance=True)
+        self.credentials = credentials.IDTokenCredentials(
+            request=request, target_audience="https://audience.com",
+            service_account_email="service-account@other.com")
+
+        # Generate authorization grant:
+        token = self.credentials._make_authorization_grant_assertion()
+        payload = jwt.decode(token, verify=False)
+
+        # The JWT token signature is 'signature' encoded in base 64:
+        assert token.endswith(b'.c2lnbmF0dXJl')
+
+        # Check that the credentials have the token and proper expiration
+        assert payload == {
+            'aud': 'https://www.googleapis.com/oauth2/v4/token',
+            'exp': 3600,
+            'iat': 0,
+            'iss': 'service-account@other.com',
+            'target_audience': 'https://audience.com'}
+
+    @mock.patch(
+        'google.auth._helpers.utcnow',
+        return_value=datetime.datetime.utcfromtimestamp(0))
+    @mock.patch('google.auth.compute_engine._metadata.get', autospec=True)
+    @mock.patch('google.auth.iam.Signer.sign', autospec=True)
+    def test_additional_claims(self, sign, get, utcnow):
+        get.side_effect = [{
+            'email': 'service-account@example.com',
+            'scopes': ['one', 'two']
+        }]
+        sign.side_effect = [b'signature']
+
+        request = mock.create_autospec(transport.Request, instance=True)
+        self.credentials = credentials.IDTokenCredentials(
+            request=request, target_audience="https://audience.com",
+            additional_claims={'foo': 'bar'})
+
+        # Generate authorization grant:
+        token = self.credentials._make_authorization_grant_assertion()
+        payload = jwt.decode(token, verify=False)
+
+        # The JWT token signature is 'signature' encoded in base 64:
+        assert token.endswith(b'.c2lnbmF0dXJl')
+
+        # Check that the credentials have the token and proper expiration
+        assert payload == {
+            'aud': 'https://www.googleapis.com/oauth2/v4/token',
+            'exp': 3600,
+            'iat': 0,
+            'iss': 'service-account@example.com',
+            'target_audience': 'https://audience.com',
+            'foo': 'bar'}
+
+    @mock.patch(
+        'google.auth._helpers.utcnow',
+        return_value=datetime.datetime.utcfromtimestamp(0))
+    @mock.patch('google.auth.compute_engine._metadata.get', autospec=True)
+    @mock.patch('google.auth.iam.Signer.sign', autospec=True)
+    def test_with_target_audience(self, sign, get, utcnow):
+        get.side_effect = [{
+            'email': 'service-account@example.com',
+            'scopes': ['one', 'two']
+        }]
+        sign.side_effect = [b'signature']
+
+        request = mock.create_autospec(transport.Request, instance=True)
+        self.credentials = credentials.IDTokenCredentials(
+            request=request, target_audience="https://audience.com")
+        self.credentials = (
+            self.credentials.with_target_audience("https://actually.not"))
+
+        # Generate authorization grant:
+        token = self.credentials._make_authorization_grant_assertion()
+        payload = jwt.decode(token, verify=False)
+
+        # The JWT token signature is 'signature' encoded in base 64:
+        assert token.endswith(b'.c2lnbmF0dXJl')
+
+        # Check that the credentials have the token and proper expiration
+        assert payload == {
+            'aud': 'https://www.googleapis.com/oauth2/v4/token',
+            'exp': 3600,
+            'iat': 0,
+            'iss': 'service-account@example.com',
+            'target_audience': 'https://actually.not'}
+
+    @mock.patch(
+        'google.auth._helpers.utcnow',
+        return_value=datetime.datetime.utcfromtimestamp(0))
+    @mock.patch('google.auth.compute_engine._metadata.get', autospec=True)
+    @mock.patch('google.auth.iam.Signer.sign', autospec=True)
+    @mock.patch('google.oauth2._client.id_token_jwt_grant', autospec=True)
+    def test_refresh_success(self, id_token_jwt_grant, sign, get, utcnow):
+        get.side_effect = [{
+            'email': 'service-account@example.com',
+            'scopes': ['one', 'two']
+        }]
+        sign.side_effect = [b'signature']
+        id_token_jwt_grant.side_effect = [(
+            'idtoken',
+            datetime.datetime.utcfromtimestamp(3600),
+            {},
+        )]
+
+        request = mock.create_autospec(transport.Request, instance=True)
+        self.credentials = credentials.IDTokenCredentials(
+            request=request, target_audience="https://audience.com")
+
+        # Refresh credentials
+        self.credentials.refresh(None)
+
+        # Check that the credentials have the token and proper expiration
+        assert self.credentials.token == 'idtoken'
+        assert self.credentials.expiry == (
+            datetime.datetime.utcfromtimestamp(3600))
+
+        # Check the credential info
+        assert (self.credentials.service_account_email ==
+                'service-account@example.com')
+
+        # Check that the credentials are valid (have a token and are not
+        # expired)
+        assert self.credentials.valid
+
+    @mock.patch(
+        'google.auth._helpers.utcnow',
+        return_value=datetime.datetime.utcfromtimestamp(0))
+    @mock.patch('google.auth.compute_engine._metadata.get', autospec=True)
+    @mock.patch('google.auth.iam.Signer.sign', autospec=True)
+    def test_refresh_error(self, sign, get, utcnow):
+        get.side_effect = [{
+            'email': 'service-account@example.com',
+            'scopes': ['one', 'two'],
+        }]
+        sign.side_effect = [b'signature']
+
+        request = mock.create_autospec(transport.Request, instance=True)
+        response = mock.Mock()
+        response.data = b'{"error": "http error"}'
+        response.status = 500
+        request.side_effect = [response]
+
+        self.credentials = credentials.IDTokenCredentials(
+            request=request, target_audience="https://audience.com")
+
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            self.credentials.refresh(request)
+
+        assert excinfo.match(r'http error')
+
+    @mock.patch(
+        'google.auth._helpers.utcnow',
+        return_value=datetime.datetime.utcfromtimestamp(0))
+    @mock.patch('google.auth.compute_engine._metadata.get', autospec=True)
+    @mock.patch('google.auth.iam.Signer.sign', autospec=True)
+    @mock.patch('google.oauth2._client.id_token_jwt_grant', autospec=True)
+    def test_before_request_refreshes(
+            self, id_token_jwt_grant, sign, get, utcnow):
+        get.side_effect = [{
+            'email': 'service-account@example.com',
+            'scopes': 'one two'
+        }]
+        sign.side_effect = [b'signature']
+        id_token_jwt_grant.side_effect = [(
+            'idtoken',
+            datetime.datetime.utcfromtimestamp(3600),
+            {},
+        )]
+
+        request = mock.create_autospec(transport.Request, instance=True)
+        self.credentials = credentials.IDTokenCredentials(
+            request=request, target_audience="https://audience.com")
+
+        # Credentials should start as invalid
+        assert not self.credentials.valid
+
+        # before_request should cause a refresh
+        request = mock.create_autospec(transport.Request, instance=True)
+        self.credentials.before_request(
+            request, 'GET', 'http://example.com?a=1#3', {})
+
+        # The refresh endpoint should've been called.
+        assert get.called
+
+        # Credentials should now be valid.
+        assert self.credentials.valid
+
+    @mock.patch('google.auth.compute_engine._metadata.get', autospec=True)
+    @mock.patch('google.auth.iam.Signer.sign', autospec=True)
+    def test_sign_bytes(self, sign, get):
+        get.side_effect = [{
+            'email': 'service-account@example.com',
+            'scopes': ['one', 'two']
+        }]
+        sign.side_effect = [b'signature']
+
+        request = mock.create_autospec(transport.Request, instance=True)
+        response = mock.Mock()
+        response.data = b'{"signature": "c2lnbmF0dXJl"}'
+        response.status = 200
+        request.side_effect = [response]
+
+        self.credentials = credentials.IDTokenCredentials(
+            request=request, target_audience="https://audience.com")
+
+        # Generate authorization grant:
+        signature = self.credentials.sign_bytes(b"some bytes")
+
+        # The JWT token signature is 'signature' encoded in base 64:
+        assert signature == b'signature'