feat: fetch id token from GCE metadata server (#462)
feat: fetch id token from GCE metadata server
diff --git a/google/auth/compute_engine/credentials.py b/google/auth/compute_engine/credentials.py
index e35907a..1927c26 100644
--- a/google/auth/compute_engine/credentials.py
+++ b/google/auth/compute_engine/credentials.py
@@ -125,18 +125,24 @@
These credentials relies on the default service account of a GCE instance.
- In order for this to work, the GCE instance must have been started with
+ ID token can be requested from `GCE metadata server identity endpoint`_, IAM
+ token endpoint or other token endpoints you specify. If metadata server
+ identity endpoint is not used, the GCE instance must have been started with
a service account that has access to the IAM Cloud API.
+
+ .. _GCE metadata server identity endpoint:
+ https://cloud.google.com/compute/docs/instances/verifying-instance-identity
"""
def __init__(
self,
request,
target_audience,
- token_uri=_DEFAULT_TOKEN_URI,
+ token_uri=None,
additional_claims=None,
service_account_email=None,
signer=None,
+ use_metadata_identity_endpoint=False,
):
"""
Args:
@@ -154,29 +160,54 @@
signer (google.auth.crypt.Signer): The signer used to sign JWTs.
In case the signer is specified, the request argument will be
ignored.
+ use_metadata_identity_endpoint (bool): Whether to use GCE metadata
+ identity endpoint. For backward compatibility the default value
+ is False. If set to True, ``token_uri``, ``additional_claims``,
+ ``service_account_email``, ``signer`` argument should not be set;
+ otherwise ValueError will be raised.
+
+ Raises:
+ ValueError:
+ If ``use_metadata_identity_endpoint`` is set to True, and one of
+ ``token_uri``, ``additional_claims``, ``service_account_email``,
+ ``signer`` arguments is set.
"""
super(IDTokenCredentials, self).__init__()
- if service_account_email is None:
- sa_info = _metadata.get_service_account_info(request)
- service_account_email = sa_info["email"]
- self._service_account_email = service_account_email
-
- if signer is None:
- signer = iam.Signer(
- request=request,
- credentials=Credentials(),
- service_account_email=service_account_email,
- )
- self._signer = signer
-
- self._token_uri = token_uri
+ self._use_metadata_identity_endpoint = use_metadata_identity_endpoint
self._target_audience = target_audience
- if additional_claims is not None:
- self._additional_claims = additional_claims
+ if use_metadata_identity_endpoint:
+ if token_uri or additional_claims or service_account_email or signer:
+ raise ValueError(
+ "If use_metadata_identity_endpoint is set, token_uri, "
+ "additional_claims, service_account_email, signer arguments"
+ " must not be set"
+ )
+ self._token_uri = None
+ self._additional_claims = None
+ self._signer = None
+
+ if service_account_email is None:
+ sa_info = _metadata.get_service_account_info(request)
+ self._service_account_email = sa_info["email"]
else:
- self._additional_claims = {}
+ self._service_account_email = service_account_email
+
+ if not use_metadata_identity_endpoint:
+ if signer is None:
+ signer = iam.Signer(
+ request=request,
+ credentials=Credentials(),
+ service_account_email=self._service_account_email,
+ )
+ self._signer = signer
+ self._token_uri = token_uri or _DEFAULT_TOKEN_URI
+
+ if additional_claims is not None:
+ self._additional_claims = additional_claims
+ else:
+ self._additional_claims = {}
def with_target_audience(self, target_audience):
"""Create a copy of these credentials with the specified target
@@ -190,14 +221,22 @@
"""
# since the signer is already instantiated,
# the request is not needed
- return self.__class__(
- None,
- service_account_email=self._service_account_email,
- token_uri=self._token_uri,
- target_audience=target_audience,
- additional_claims=self._additional_claims.copy(),
- signer=self.signer,
- )
+ if self._use_metadata_identity_endpoint:
+ return self.__class__(
+ None,
+ target_audience=target_audience,
+ use_metadata_identity_endpoint=True,
+ )
+ else:
+ return self.__class__(
+ None,
+ service_account_email=self._service_account_email,
+ token_uri=self._token_uri,
+ target_audience=target_audience,
+ additional_claims=self._additional_claims.copy(),
+ signer=self.signer,
+ use_metadata_identity_endpoint=False,
+ )
def _make_authorization_grant_assertion(self):
"""Create the OAuth 2.0 assertion.
@@ -228,22 +267,76 @@
return token
- @_helpers.copy_docstring(credentials.Credentials)
+ def _call_metadata_identity_endpoint(self, request):
+ """Request ID token from metadata identity endpoint.
+
+ Args:
+ request (google.auth.transport.Request): The object used to make
+ HTTP requests.
+
+ Raises:
+ google.auth.exceptions.RefreshError: If the Compute Engine metadata
+ service can't be reached or if the instance has no credentials.
+ ValueError: If extracting expiry from the obtained ID token fails.
+ """
+ try:
+ id_token = _metadata.get(
+ request,
+ "instance/service-accounts/default/identity?audience={}&format=full".format(
+ self._target_audience
+ ),
+ )
+ except exceptions.TransportError as caught_exc:
+ new_exc = exceptions.RefreshError(caught_exc)
+ six.raise_from(new_exc, caught_exc)
+
+ _, payload, _, _ = jwt._unverified_decode(id_token)
+ return id_token, payload["exp"]
+
def refresh(self, request):
- assertion = self._make_authorization_grant_assertion()
- access_token, expiry, _ = _client.id_token_jwt_grant(
- request, self._token_uri, assertion
- )
- self.token = access_token
- self.expiry = expiry
+ """Refreshes the ID token.
+
+ Args:
+ request (google.auth.transport.Request): The object used to make
+ HTTP requests.
+
+ Raises:
+ google.auth.exceptions.RefreshError: If the credentials could
+ not be refreshed.
+ ValueError: If extracting expiry from the obtained ID token fails.
+ """
+ if self._use_metadata_identity_endpoint:
+ self.token, self.expiry = self._call_metadata_identity_endpoint(request)
+ else:
+ assertion = self._make_authorization_grant_assertion()
+ access_token, expiry, _ = _client.id_token_jwt_grant(
+ request, self._token_uri, assertion
+ )
+ self.token = access_token
+ self.expiry = expiry
@property
@_helpers.copy_docstring(credentials.Signing)
def signer(self):
return self._signer
- @_helpers.copy_docstring(credentials.Signing)
def sign_bytes(self, message):
+ """Signs the given message.
+
+ Args:
+ message (bytes): The message to sign.
+
+ Returns:
+ bytes: The message's cryptographic signature.
+
+ Raises:
+ ValueError:
+ Signer is not available if metadata identity endpoint is used.
+ """
+ if self._use_metadata_identity_endpoint:
+ raise ValueError(
+ "Signer is not available if metadata identity endpoint is used"
+ )
return self._signer.sign(message)
@property
diff --git a/system_tests/test_compute_engine.py b/system_tests/test_compute_engine.py
index 3217c95..bcfdfd6 100644
--- a/system_tests/test_compute_engine.py
+++ b/system_tests/test_compute_engine.py
@@ -18,6 +18,7 @@
from google.auth import compute_engine
from google.auth import _helpers
from google.auth import exceptions
+from google.auth import jwt
from google.auth.compute_engine import _metadata
@@ -48,3 +49,14 @@
assert project_id is not None
assert isinstance(credentials, compute_engine.Credentials)
verify_refresh(credentials)
+
+
+def test_id_token_from_metadata(http_request):
+ credentials = compute_engine.IDTokenCredentials(
+ http_request, "target_audience", use_metadata_identity_endpoint=True
+ )
+ credentials.refresh(http_request)
+
+ _, payload, _, _ = jwt._unverified_decode(credentials.token)
+ assert payload["aud"] == "target_audience"
+ assert payload["exp"] == credentials.expiry
diff --git a/tests/compute_engine/test_credentials.py b/tests/compute_engine/test_credentials.py
index b861984..264235e 100644
--- a/tests/compute_engine/test_credentials.py
+++ b/tests/compute_engine/test_credentials.py
@@ -25,6 +25,24 @@
from google.auth.compute_engine import credentials
from google.auth.transport import requests
+SAMPLE_ID_TOKEN_EXP = 1584393400
+
+# header: {"alg": "RS256", "typ": "JWT", "kid": "1"}
+# payload: {"iss": "issuer", "iat": 1584393348, "sub": "subject",
+# "exp": 1584393400,"aud": "audience"}
+SAMPLE_ID_TOKEN = (
+ b"eyJhbGciOiAiUlMyNTYiLCAidHlwIjogIkpXVCIsICJraWQiOiAiMSJ9."
+ b"eyJpc3MiOiAiaXNzdWVyIiwgImlhdCI6IDE1ODQzOTMzNDgsICJzdWIiO"
+ b"iAic3ViamVjdCIsICJleHAiOiAxNTg0MzkzNDAwLCAiYXVkIjogImF1ZG"
+ b"llbmNlIn0."
+ b"OquNjHKhTmlgCk361omRo18F_uY-7y0f_AmLbzW062Q1Zr61HAwHYP5FM"
+ b"316CK4_0cH8MUNGASsvZc3VqXAqub6PUTfhemH8pFEwBdAdG0LhrNkU0H"
+ b"WN1YpT55IiQ31esLdL5q-qDsOPpNZJUti1y1lAreM5nIn2srdWzGXGs4i"
+ b"TRQsn0XkNUCL4RErpciXmjfhMrPkcAjKA-mXQm2fa4jmTlEZFqFmUlym1"
+ b"ozJ0yf5grjN6AslN4OGvAv1pS-_Ko_pGBS6IQtSBC6vVKCUuBfaqNjykg"
+ b"bsxbLa6Fp0SYeYwO8ifEnkRvasVpc1WTQqfRB2JCj5pTBDzJpIpFCMmnQ"
+)
+
class TestCredentials(object):
credentials = None
@@ -238,6 +256,26 @@
"foo": "bar",
}
+ def test_token_uri(self):
+ request = mock.create_autospec(transport.Request, instance=True)
+
+ self.credentials = credentials.IDTokenCredentials(
+ request=request,
+ signer=mock.Mock(),
+ service_account_email="foo@example.com",
+ target_audience="https://audience.com",
+ )
+ assert self.credentials._token_uri == credentials._DEFAULT_TOKEN_URI
+
+ self.credentials = credentials.IDTokenCredentials(
+ request=request,
+ signer=mock.Mock(),
+ service_account_email="foo@example.com",
+ target_audience="https://audience.com",
+ token_uri="https://example.com/token",
+ )
+ assert self.credentials._token_uri == "https://example.com/token"
+
@mock.patch(
"google.auth._helpers.utcnow",
return_value=datetime.datetime.utcfromtimestamp(0),
@@ -469,3 +507,104 @@
# The JWT token signature is 'signature' encoded in base 64:
assert signature == b"signature"
+
+ @mock.patch(
+ "google.auth.compute_engine._metadata.get_service_account_info", autospec=True
+ )
+ @mock.patch("google.auth.compute_engine._metadata.get", autospec=True)
+ def test_get_id_token_from_metadata(self, get, get_service_account_info):
+ get.return_value = SAMPLE_ID_TOKEN
+ get_service_account_info.return_value = {"email": "foo@example.com"}
+
+ cred = credentials.IDTokenCredentials(
+ mock.Mock(), "audience", use_metadata_identity_endpoint=True
+ )
+ cred.refresh(request=mock.Mock())
+
+ assert cred.token == SAMPLE_ID_TOKEN
+ assert cred.expiry == SAMPLE_ID_TOKEN_EXP
+ assert cred._use_metadata_identity_endpoint
+ assert cred._signer is None
+ assert cred._token_uri is None
+ assert cred._service_account_email == "foo@example.com"
+ assert cred._target_audience == "audience"
+ with pytest.raises(ValueError):
+ cred.sign_bytes(b"bytes")
+
+ @mock.patch(
+ "google.auth.compute_engine._metadata.get_service_account_info", autospec=True
+ )
+ def test_with_target_audience_for_metadata(self, get_service_account_info):
+ get_service_account_info.return_value = {"email": "foo@example.com"}
+
+ cred = credentials.IDTokenCredentials(
+ mock.Mock(), "audience", use_metadata_identity_endpoint=True
+ )
+ cred = cred.with_target_audience("new_audience")
+
+ assert cred._target_audience == "new_audience"
+ assert cred._use_metadata_identity_endpoint
+ assert cred._signer is None
+ assert cred._token_uri is None
+ assert cred._service_account_email == "foo@example.com"
+
+ @mock.patch(
+ "google.auth.compute_engine._metadata.get_service_account_info", autospec=True
+ )
+ @mock.patch("google.auth.compute_engine._metadata.get", autospec=True)
+ def test_invalid_id_token_from_metadata(self, get, get_service_account_info):
+ get.return_value = "invalid_id_token"
+ get_service_account_info.return_value = {"email": "foo@example.com"}
+
+ cred = credentials.IDTokenCredentials(
+ mock.Mock(), "audience", use_metadata_identity_endpoint=True
+ )
+
+ with pytest.raises(ValueError):
+ cred.refresh(request=mock.Mock())
+
+ @mock.patch(
+ "google.auth.compute_engine._metadata.get_service_account_info", autospec=True
+ )
+ @mock.patch("google.auth.compute_engine._metadata.get", autospec=True)
+ def test_transport_error_from_metadata(self, get, get_service_account_info):
+ get.side_effect = exceptions.TransportError("transport error")
+ get_service_account_info.return_value = {"email": "foo@example.com"}
+
+ cred = credentials.IDTokenCredentials(
+ mock.Mock(), "audience", use_metadata_identity_endpoint=True
+ )
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ cred.refresh(request=mock.Mock())
+ assert excinfo.match(r"transport error")
+
+ def test_get_id_token_from_metadata_constructor(self):
+ with pytest.raises(ValueError):
+ credentials.IDTokenCredentials(
+ mock.Mock(),
+ "audience",
+ use_metadata_identity_endpoint=True,
+ token_uri="token_uri",
+ )
+ with pytest.raises(ValueError):
+ credentials.IDTokenCredentials(
+ mock.Mock(),
+ "audience",
+ use_metadata_identity_endpoint=True,
+ signer=mock.Mock(),
+ )
+ with pytest.raises(ValueError):
+ credentials.IDTokenCredentials(
+ mock.Mock(),
+ "audience",
+ use_metadata_identity_endpoint=True,
+ additional_claims={"key", "value"},
+ )
+ with pytest.raises(ValueError):
+ credentials.IDTokenCredentials(
+ mock.Mock(),
+ "audience",
+ use_metadata_identity_endpoint=True,
+ service_account_email="foo@example.com",
+ )