feat: fetch id token from GCE metadata server (#462)

feat: fetch id token from GCE metadata server
diff --git a/google/auth/compute_engine/credentials.py b/google/auth/compute_engine/credentials.py
index e35907a..1927c26 100644
--- a/google/auth/compute_engine/credentials.py
+++ b/google/auth/compute_engine/credentials.py
@@ -125,18 +125,24 @@
 
     These credentials relies on the default service account of a GCE instance.
 
-    In order for this to work, the GCE instance must have been started with
+    ID token can be requested from `GCE metadata server identity endpoint`_, IAM
+    token endpoint or other token endpoints you specify. If metadata server
+    identity endpoint is not used, the GCE instance must have been started with
     a service account that has access to the IAM Cloud API.
+
+    .. _GCE metadata server identity endpoint:
+        https://cloud.google.com/compute/docs/instances/verifying-instance-identity
     """
 
     def __init__(
         self,
         request,
         target_audience,
-        token_uri=_DEFAULT_TOKEN_URI,
+        token_uri=None,
         additional_claims=None,
         service_account_email=None,
         signer=None,
+        use_metadata_identity_endpoint=False,
     ):
         """
         Args:
@@ -154,29 +160,54 @@
             signer (google.auth.crypt.Signer): The signer used to sign JWTs.
                 In case the signer is specified, the request argument will be
                 ignored.
+            use_metadata_identity_endpoint (bool): Whether to use GCE metadata
+                identity endpoint. For backward compatibility the default value
+                is False. If set to True, ``token_uri``, ``additional_claims``,
+                ``service_account_email``, ``signer`` argument should not be set;
+                otherwise ValueError will be raised.
+
+        Raises:
+            ValueError:
+                If ``use_metadata_identity_endpoint`` is set to True, and one of
+                ``token_uri``, ``additional_claims``, ``service_account_email``,
+                 ``signer`` arguments is set.
         """
         super(IDTokenCredentials, self).__init__()
 
-        if service_account_email is None:
-            sa_info = _metadata.get_service_account_info(request)
-            service_account_email = sa_info["email"]
-        self._service_account_email = service_account_email
-
-        if signer is None:
-            signer = iam.Signer(
-                request=request,
-                credentials=Credentials(),
-                service_account_email=service_account_email,
-            )
-        self._signer = signer
-
-        self._token_uri = token_uri
+        self._use_metadata_identity_endpoint = use_metadata_identity_endpoint
         self._target_audience = target_audience
 
-        if additional_claims is not None:
-            self._additional_claims = additional_claims
+        if use_metadata_identity_endpoint:
+            if token_uri or additional_claims or service_account_email or signer:
+                raise ValueError(
+                    "If use_metadata_identity_endpoint is set, token_uri, "
+                    "additional_claims, service_account_email, signer arguments"
+                    " must not be set"
+                )
+            self._token_uri = None
+            self._additional_claims = None
+            self._signer = None
+
+        if service_account_email is None:
+            sa_info = _metadata.get_service_account_info(request)
+            self._service_account_email = sa_info["email"]
         else:
-            self._additional_claims = {}
+            self._service_account_email = service_account_email
+
+        if not use_metadata_identity_endpoint:
+            if signer is None:
+                signer = iam.Signer(
+                    request=request,
+                    credentials=Credentials(),
+                    service_account_email=self._service_account_email,
+                )
+            self._signer = signer
+            self._token_uri = token_uri or _DEFAULT_TOKEN_URI
+
+            if additional_claims is not None:
+                self._additional_claims = additional_claims
+            else:
+                self._additional_claims = {}
 
     def with_target_audience(self, target_audience):
         """Create a copy of these credentials with the specified target
@@ -190,14 +221,22 @@
         """
         # since the signer is already instantiated,
         # the request is not needed
-        return self.__class__(
-            None,
-            service_account_email=self._service_account_email,
-            token_uri=self._token_uri,
-            target_audience=target_audience,
-            additional_claims=self._additional_claims.copy(),
-            signer=self.signer,
-        )
+        if self._use_metadata_identity_endpoint:
+            return self.__class__(
+                None,
+                target_audience=target_audience,
+                use_metadata_identity_endpoint=True,
+            )
+        else:
+            return self.__class__(
+                None,
+                service_account_email=self._service_account_email,
+                token_uri=self._token_uri,
+                target_audience=target_audience,
+                additional_claims=self._additional_claims.copy(),
+                signer=self.signer,
+                use_metadata_identity_endpoint=False,
+            )
 
     def _make_authorization_grant_assertion(self):
         """Create the OAuth 2.0 assertion.
@@ -228,22 +267,76 @@
 
         return token
 
-    @_helpers.copy_docstring(credentials.Credentials)
+    def _call_metadata_identity_endpoint(self, request):
+        """Request ID token from metadata identity endpoint.
+
+        Args:
+            request (google.auth.transport.Request): The object used to make
+                HTTP requests.
+
+        Raises:
+            google.auth.exceptions.RefreshError: If the Compute Engine metadata
+                service can't be reached or if the instance has no credentials.
+            ValueError: If extracting expiry from the obtained ID token fails.
+        """
+        try:
+            id_token = _metadata.get(
+                request,
+                "instance/service-accounts/default/identity?audience={}&format=full".format(
+                    self._target_audience
+                ),
+            )
+        except exceptions.TransportError as caught_exc:
+            new_exc = exceptions.RefreshError(caught_exc)
+            six.raise_from(new_exc, caught_exc)
+
+        _, payload, _, _ = jwt._unverified_decode(id_token)
+        return id_token, payload["exp"]
+
     def refresh(self, request):
-        assertion = self._make_authorization_grant_assertion()
-        access_token, expiry, _ = _client.id_token_jwt_grant(
-            request, self._token_uri, assertion
-        )
-        self.token = access_token
-        self.expiry = expiry
+        """Refreshes the ID token.
+
+        Args:
+            request (google.auth.transport.Request): The object used to make
+                HTTP requests.
+
+        Raises:
+            google.auth.exceptions.RefreshError: If the credentials could
+                not be refreshed.
+            ValueError: If extracting expiry from the obtained ID token fails.
+        """
+        if self._use_metadata_identity_endpoint:
+            self.token, self.expiry = self._call_metadata_identity_endpoint(request)
+        else:
+            assertion = self._make_authorization_grant_assertion()
+            access_token, expiry, _ = _client.id_token_jwt_grant(
+                request, self._token_uri, assertion
+            )
+            self.token = access_token
+            self.expiry = expiry
 
     @property
     @_helpers.copy_docstring(credentials.Signing)
     def signer(self):
         return self._signer
 
-    @_helpers.copy_docstring(credentials.Signing)
     def sign_bytes(self, message):
+        """Signs the given message.
+
+        Args:
+            message (bytes): The message to sign.
+
+        Returns:
+            bytes: The message's cryptographic signature.
+
+        Raises:
+            ValueError:
+                Signer is not available if metadata identity endpoint is used.
+        """
+        if self._use_metadata_identity_endpoint:
+            raise ValueError(
+                "Signer is not available if metadata identity endpoint is used"
+            )
         return self._signer.sign(message)
 
     @property
diff --git a/system_tests/test_compute_engine.py b/system_tests/test_compute_engine.py
index 3217c95..bcfdfd6 100644
--- a/system_tests/test_compute_engine.py
+++ b/system_tests/test_compute_engine.py
@@ -18,6 +18,7 @@
 from google.auth import compute_engine
 from google.auth import _helpers
 from google.auth import exceptions
+from google.auth import jwt
 from google.auth.compute_engine import _metadata
 
 
@@ -48,3 +49,14 @@
     assert project_id is not None
     assert isinstance(credentials, compute_engine.Credentials)
     verify_refresh(credentials)
+
+
+def test_id_token_from_metadata(http_request):
+    credentials = compute_engine.IDTokenCredentials(
+        http_request, "target_audience", use_metadata_identity_endpoint=True
+    )
+    credentials.refresh(http_request)
+
+    _, payload, _, _ = jwt._unverified_decode(credentials.token)
+    assert payload["aud"] == "target_audience"
+    assert payload["exp"] == credentials.expiry
diff --git a/tests/compute_engine/test_credentials.py b/tests/compute_engine/test_credentials.py
index b861984..264235e 100644
--- a/tests/compute_engine/test_credentials.py
+++ b/tests/compute_engine/test_credentials.py
@@ -25,6 +25,24 @@
 from google.auth.compute_engine import credentials
 from google.auth.transport import requests
 
+SAMPLE_ID_TOKEN_EXP = 1584393400
+
+# header: {"alg": "RS256", "typ": "JWT", "kid": "1"}
+# payload: {"iss": "issuer", "iat": 1584393348, "sub": "subject",
+#   "exp": 1584393400,"aud": "audience"}
+SAMPLE_ID_TOKEN = (
+    b"eyJhbGciOiAiUlMyNTYiLCAidHlwIjogIkpXVCIsICJraWQiOiAiMSJ9."
+    b"eyJpc3MiOiAiaXNzdWVyIiwgImlhdCI6IDE1ODQzOTMzNDgsICJzdWIiO"
+    b"iAic3ViamVjdCIsICJleHAiOiAxNTg0MzkzNDAwLCAiYXVkIjogImF1ZG"
+    b"llbmNlIn0."
+    b"OquNjHKhTmlgCk361omRo18F_uY-7y0f_AmLbzW062Q1Zr61HAwHYP5FM"
+    b"316CK4_0cH8MUNGASsvZc3VqXAqub6PUTfhemH8pFEwBdAdG0LhrNkU0H"
+    b"WN1YpT55IiQ31esLdL5q-qDsOPpNZJUti1y1lAreM5nIn2srdWzGXGs4i"
+    b"TRQsn0XkNUCL4RErpciXmjfhMrPkcAjKA-mXQm2fa4jmTlEZFqFmUlym1"
+    b"ozJ0yf5grjN6AslN4OGvAv1pS-_Ko_pGBS6IQtSBC6vVKCUuBfaqNjykg"
+    b"bsxbLa6Fp0SYeYwO8ifEnkRvasVpc1WTQqfRB2JCj5pTBDzJpIpFCMmnQ"
+)
+
 
 class TestCredentials(object):
     credentials = None
@@ -238,6 +256,26 @@
             "foo": "bar",
         }
 
+    def test_token_uri(self):
+        request = mock.create_autospec(transport.Request, instance=True)
+
+        self.credentials = credentials.IDTokenCredentials(
+            request=request,
+            signer=mock.Mock(),
+            service_account_email="foo@example.com",
+            target_audience="https://audience.com",
+        )
+        assert self.credentials._token_uri == credentials._DEFAULT_TOKEN_URI
+
+        self.credentials = credentials.IDTokenCredentials(
+            request=request,
+            signer=mock.Mock(),
+            service_account_email="foo@example.com",
+            target_audience="https://audience.com",
+            token_uri="https://example.com/token",
+        )
+        assert self.credentials._token_uri == "https://example.com/token"
+
     @mock.patch(
         "google.auth._helpers.utcnow",
         return_value=datetime.datetime.utcfromtimestamp(0),
@@ -469,3 +507,104 @@
 
         # The JWT token signature is 'signature' encoded in base 64:
         assert signature == b"signature"
+
+    @mock.patch(
+        "google.auth.compute_engine._metadata.get_service_account_info", autospec=True
+    )
+    @mock.patch("google.auth.compute_engine._metadata.get", autospec=True)
+    def test_get_id_token_from_metadata(self, get, get_service_account_info):
+        get.return_value = SAMPLE_ID_TOKEN
+        get_service_account_info.return_value = {"email": "foo@example.com"}
+
+        cred = credentials.IDTokenCredentials(
+            mock.Mock(), "audience", use_metadata_identity_endpoint=True
+        )
+        cred.refresh(request=mock.Mock())
+
+        assert cred.token == SAMPLE_ID_TOKEN
+        assert cred.expiry == SAMPLE_ID_TOKEN_EXP
+        assert cred._use_metadata_identity_endpoint
+        assert cred._signer is None
+        assert cred._token_uri is None
+        assert cred._service_account_email == "foo@example.com"
+        assert cred._target_audience == "audience"
+        with pytest.raises(ValueError):
+            cred.sign_bytes(b"bytes")
+
+    @mock.patch(
+        "google.auth.compute_engine._metadata.get_service_account_info", autospec=True
+    )
+    def test_with_target_audience_for_metadata(self, get_service_account_info):
+        get_service_account_info.return_value = {"email": "foo@example.com"}
+
+        cred = credentials.IDTokenCredentials(
+            mock.Mock(), "audience", use_metadata_identity_endpoint=True
+        )
+        cred = cred.with_target_audience("new_audience")
+
+        assert cred._target_audience == "new_audience"
+        assert cred._use_metadata_identity_endpoint
+        assert cred._signer is None
+        assert cred._token_uri is None
+        assert cred._service_account_email == "foo@example.com"
+
+    @mock.patch(
+        "google.auth.compute_engine._metadata.get_service_account_info", autospec=True
+    )
+    @mock.patch("google.auth.compute_engine._metadata.get", autospec=True)
+    def test_invalid_id_token_from_metadata(self, get, get_service_account_info):
+        get.return_value = "invalid_id_token"
+        get_service_account_info.return_value = {"email": "foo@example.com"}
+
+        cred = credentials.IDTokenCredentials(
+            mock.Mock(), "audience", use_metadata_identity_endpoint=True
+        )
+
+        with pytest.raises(ValueError):
+            cred.refresh(request=mock.Mock())
+
+    @mock.patch(
+        "google.auth.compute_engine._metadata.get_service_account_info", autospec=True
+    )
+    @mock.patch("google.auth.compute_engine._metadata.get", autospec=True)
+    def test_transport_error_from_metadata(self, get, get_service_account_info):
+        get.side_effect = exceptions.TransportError("transport error")
+        get_service_account_info.return_value = {"email": "foo@example.com"}
+
+        cred = credentials.IDTokenCredentials(
+            mock.Mock(), "audience", use_metadata_identity_endpoint=True
+        )
+
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            cred.refresh(request=mock.Mock())
+        assert excinfo.match(r"transport error")
+
+    def test_get_id_token_from_metadata_constructor(self):
+        with pytest.raises(ValueError):
+            credentials.IDTokenCredentials(
+                mock.Mock(),
+                "audience",
+                use_metadata_identity_endpoint=True,
+                token_uri="token_uri",
+            )
+        with pytest.raises(ValueError):
+            credentials.IDTokenCredentials(
+                mock.Mock(),
+                "audience",
+                use_metadata_identity_endpoint=True,
+                signer=mock.Mock(),
+            )
+        with pytest.raises(ValueError):
+            credentials.IDTokenCredentials(
+                mock.Mock(),
+                "audience",
+                use_metadata_identity_endpoint=True,
+                additional_claims={"key", "value"},
+            )
+        with pytest.raises(ValueError):
+            credentials.IDTokenCredentials(
+                mock.Mock(),
+                "audience",
+                use_metadata_identity_endpoint=True,
+                service_account_email="foo@example.com",
+            )