feat: workload identity federation support (#686)
Using workload identity federation, applications can access Google Cloud resources from Amazon Web Services (AWS), Microsoft Azure, or any identity provider that supports OpenID Connect (OIDC). Workload identity federation is recommended for non-Google Cloud environments because it avoids the need to download, manage, and store service account private keys locally.
diff --git a/tests/data/external_subject_token.json b/tests/data/external_subject_token.json
new file mode 100644
index 0000000..a47ec34
--- /dev/null
+++ b/tests/data/external_subject_token.json
@@ -0,0 +1,3 @@
+{
+ "access_token": "HEADER.SIMULATED_JWT_PAYLOAD.SIGNATURE"
+}
\ No newline at end of file
diff --git a/tests/data/external_subject_token.txt b/tests/data/external_subject_token.txt
new file mode 100644
index 0000000..c668d8f
--- /dev/null
+++ b/tests/data/external_subject_token.txt
@@ -0,0 +1 @@
+HEADER.SIMULATED_JWT_PAYLOAD.SIGNATURE
\ No newline at end of file
diff --git a/tests/oauth2/test_sts.py b/tests/oauth2/test_sts.py
new file mode 100644
index 0000000..8792bd6
--- /dev/null
+++ b/tests/oauth2/test_sts.py
@@ -0,0 +1,395 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import mock
+import pytest
+from six.moves import http_client
+from six.moves import urllib
+
+from google.auth import exceptions
+from google.auth import transport
+from google.oauth2 import sts
+from google.oauth2 import utils
+
+CLIENT_ID = "username"
+CLIENT_SECRET = "password"
+# Base64 encoding of "username:password"
+BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ="
+
+
+class TestStsClient(object):
+ GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"
+ RESOURCE = "https://api.example.com/"
+ AUDIENCE = "urn:example:cooperation-context"
+ SCOPES = ["scope1", "scope2"]
+ REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token"
+ SUBJECT_TOKEN = "HEADER.SUBJECT_TOKEN_PAYLOAD.SIGNATURE"
+ SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt"
+ ACTOR_TOKEN = "HEADER.ACTOR_TOKEN_PAYLOAD.SIGNATURE"
+ ACTOR_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt"
+ TOKEN_EXCHANGE_ENDPOINT = "https://example.com/token.oauth2"
+ ADDON_HEADERS = {"x-client-version": "0.1.2"}
+ ADDON_OPTIONS = {"additional": {"non-standard": ["options"], "other": "some-value"}}
+ SUCCESS_RESPONSE = {
+ "access_token": "ACCESS_TOKEN",
+ "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
+ "token_type": "Bearer",
+ "expires_in": 3600,
+ "scope": "scope1 scope2",
+ }
+ ERROR_RESPONSE = {
+ "error": "invalid_request",
+ "error_description": "Invalid subject token",
+ "error_uri": "https://tools.ietf.org/html/rfc6749",
+ }
+ CLIENT_AUTH_BASIC = utils.ClientAuthentication(
+ utils.ClientAuthType.basic, CLIENT_ID, CLIENT_SECRET
+ )
+ CLIENT_AUTH_REQUEST_BODY = utils.ClientAuthentication(
+ utils.ClientAuthType.request_body, CLIENT_ID, CLIENT_SECRET
+ )
+
+ @classmethod
+ def make_client(cls, client_auth=None):
+ return sts.Client(cls.TOKEN_EXCHANGE_ENDPOINT, client_auth)
+
+ @classmethod
+ def make_mock_request(cls, data, status=http_client.OK):
+ response = mock.create_autospec(transport.Response, instance=True)
+ response.status = status
+ response.data = json.dumps(data).encode("utf-8")
+
+ request = mock.create_autospec(transport.Request)
+ request.return_value = response
+
+ return request
+
+ @classmethod
+ def assert_request_kwargs(cls, request_kwargs, headers, request_data):
+ """Asserts the request was called with the expected parameters.
+ """
+ assert request_kwargs["url"] == cls.TOKEN_EXCHANGE_ENDPOINT
+ assert request_kwargs["method"] == "POST"
+ assert request_kwargs["headers"] == headers
+ assert request_kwargs["body"] is not None
+ body_tuples = urllib.parse.parse_qsl(request_kwargs["body"])
+ for (k, v) in body_tuples:
+ assert v.decode("utf-8") == request_data[k.decode("utf-8")]
+ assert len(body_tuples) == len(request_data.keys())
+
+ def test_exchange_token_full_success_without_auth(self):
+ """Test token exchange success without client authentication using full
+ parameters.
+ """
+ client = self.make_client()
+ headers = self.ADDON_HEADERS.copy()
+ headers["Content-Type"] = "application/x-www-form-urlencoded"
+ request_data = {
+ "grant_type": self.GRANT_TYPE,
+ "resource": self.RESOURCE,
+ "audience": self.AUDIENCE,
+ "scope": " ".join(self.SCOPES),
+ "requested_token_type": self.REQUESTED_TOKEN_TYPE,
+ "subject_token": self.SUBJECT_TOKEN,
+ "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+ "actor_token": self.ACTOR_TOKEN,
+ "actor_token_type": self.ACTOR_TOKEN_TYPE,
+ "options": urllib.parse.quote(json.dumps(self.ADDON_OPTIONS)),
+ }
+ request = self.make_mock_request(
+ status=http_client.OK, data=self.SUCCESS_RESPONSE
+ )
+
+ response = client.exchange_token(
+ request,
+ self.GRANT_TYPE,
+ self.SUBJECT_TOKEN,
+ self.SUBJECT_TOKEN_TYPE,
+ self.RESOURCE,
+ self.AUDIENCE,
+ self.SCOPES,
+ self.REQUESTED_TOKEN_TYPE,
+ self.ACTOR_TOKEN,
+ self.ACTOR_TOKEN_TYPE,
+ self.ADDON_OPTIONS,
+ self.ADDON_HEADERS,
+ )
+
+ self.assert_request_kwargs(request.call_args.kwargs, headers, request_data)
+ assert response == self.SUCCESS_RESPONSE
+
+ def test_exchange_token_partial_success_without_auth(self):
+ """Test token exchange success without client authentication using
+ partial (required only) parameters.
+ """
+ client = self.make_client()
+ headers = {"Content-Type": "application/x-www-form-urlencoded"}
+ request_data = {
+ "grant_type": self.GRANT_TYPE,
+ "audience": self.AUDIENCE,
+ "requested_token_type": self.REQUESTED_TOKEN_TYPE,
+ "subject_token": self.SUBJECT_TOKEN,
+ "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+ }
+ request = self.make_mock_request(
+ status=http_client.OK, data=self.SUCCESS_RESPONSE
+ )
+
+ response = client.exchange_token(
+ request,
+ grant_type=self.GRANT_TYPE,
+ subject_token=self.SUBJECT_TOKEN,
+ subject_token_type=self.SUBJECT_TOKEN_TYPE,
+ audience=self.AUDIENCE,
+ requested_token_type=self.REQUESTED_TOKEN_TYPE,
+ )
+
+ self.assert_request_kwargs(request.call_args.kwargs, headers, request_data)
+ assert response == self.SUCCESS_RESPONSE
+
+ def test_exchange_token_non200_without_auth(self):
+ """Test token exchange without client auth responding with non-200 status.
+ """
+ client = self.make_client()
+ request = self.make_mock_request(
+ status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE
+ )
+
+ with pytest.raises(exceptions.OAuthError) as excinfo:
+ client.exchange_token(
+ request,
+ self.GRANT_TYPE,
+ self.SUBJECT_TOKEN,
+ self.SUBJECT_TOKEN_TYPE,
+ self.RESOURCE,
+ self.AUDIENCE,
+ self.SCOPES,
+ self.REQUESTED_TOKEN_TYPE,
+ self.ACTOR_TOKEN,
+ self.ACTOR_TOKEN_TYPE,
+ self.ADDON_OPTIONS,
+ self.ADDON_HEADERS,
+ )
+
+ assert excinfo.match(
+ r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749"
+ )
+
+ def test_exchange_token_full_success_with_basic_auth(self):
+ """Test token exchange success with basic client authentication using full
+ parameters.
+ """
+ client = self.make_client(self.CLIENT_AUTH_BASIC)
+ headers = self.ADDON_HEADERS.copy()
+ headers["Content-Type"] = "application/x-www-form-urlencoded"
+ headers["Authorization"] = "Basic {}".format(BASIC_AUTH_ENCODING)
+ request_data = {
+ "grant_type": self.GRANT_TYPE,
+ "resource": self.RESOURCE,
+ "audience": self.AUDIENCE,
+ "scope": " ".join(self.SCOPES),
+ "requested_token_type": self.REQUESTED_TOKEN_TYPE,
+ "subject_token": self.SUBJECT_TOKEN,
+ "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+ "actor_token": self.ACTOR_TOKEN,
+ "actor_token_type": self.ACTOR_TOKEN_TYPE,
+ "options": urllib.parse.quote(json.dumps(self.ADDON_OPTIONS)),
+ }
+ request = self.make_mock_request(
+ status=http_client.OK, data=self.SUCCESS_RESPONSE
+ )
+
+ response = client.exchange_token(
+ request,
+ self.GRANT_TYPE,
+ self.SUBJECT_TOKEN,
+ self.SUBJECT_TOKEN_TYPE,
+ self.RESOURCE,
+ self.AUDIENCE,
+ self.SCOPES,
+ self.REQUESTED_TOKEN_TYPE,
+ self.ACTOR_TOKEN,
+ self.ACTOR_TOKEN_TYPE,
+ self.ADDON_OPTIONS,
+ self.ADDON_HEADERS,
+ )
+
+ self.assert_request_kwargs(request.call_args.kwargs, headers, request_data)
+ assert response == self.SUCCESS_RESPONSE
+
+ def test_exchange_token_partial_success_with_basic_auth(self):
+ """Test token exchange success with basic client authentication using
+ partial (required only) parameters.
+ """
+ client = self.make_client(self.CLIENT_AUTH_BASIC)
+ headers = {
+ "Content-Type": "application/x-www-form-urlencoded",
+ "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING),
+ }
+ request_data = {
+ "grant_type": self.GRANT_TYPE,
+ "audience": self.AUDIENCE,
+ "requested_token_type": self.REQUESTED_TOKEN_TYPE,
+ "subject_token": self.SUBJECT_TOKEN,
+ "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+ }
+ request = self.make_mock_request(
+ status=http_client.OK, data=self.SUCCESS_RESPONSE
+ )
+
+ response = client.exchange_token(
+ request,
+ grant_type=self.GRANT_TYPE,
+ subject_token=self.SUBJECT_TOKEN,
+ subject_token_type=self.SUBJECT_TOKEN_TYPE,
+ audience=self.AUDIENCE,
+ requested_token_type=self.REQUESTED_TOKEN_TYPE,
+ )
+
+ self.assert_request_kwargs(request.call_args.kwargs, headers, request_data)
+ assert response == self.SUCCESS_RESPONSE
+
+ def test_exchange_token_non200_with_basic_auth(self):
+ """Test token exchange with basic client auth responding with non-200
+ status.
+ """
+ client = self.make_client(self.CLIENT_AUTH_BASIC)
+ request = self.make_mock_request(
+ status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE
+ )
+
+ with pytest.raises(exceptions.OAuthError) as excinfo:
+ client.exchange_token(
+ request,
+ self.GRANT_TYPE,
+ self.SUBJECT_TOKEN,
+ self.SUBJECT_TOKEN_TYPE,
+ self.RESOURCE,
+ self.AUDIENCE,
+ self.SCOPES,
+ self.REQUESTED_TOKEN_TYPE,
+ self.ACTOR_TOKEN,
+ self.ACTOR_TOKEN_TYPE,
+ self.ADDON_OPTIONS,
+ self.ADDON_HEADERS,
+ )
+
+ assert excinfo.match(
+ r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749"
+ )
+
+ def test_exchange_token_full_success_with_reqbody_auth(self):
+ """Test token exchange success with request body client authentication
+ using full parameters; client_id and client_secret go in the form body.
+ """
+ client = self.make_client(self.CLIENT_AUTH_REQUEST_BODY)
+ headers = self.ADDON_HEADERS.copy()
+ headers["Content-Type"] = "application/x-www-form-urlencoded"
+ request_data = {
+ "grant_type": self.GRANT_TYPE,
+ "resource": self.RESOURCE,
+ "audience": self.AUDIENCE,
+ "scope": " ".join(self.SCOPES),
+ "requested_token_type": self.REQUESTED_TOKEN_TYPE,
+ "subject_token": self.SUBJECT_TOKEN,
+ "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+ "actor_token": self.ACTOR_TOKEN,
+ "actor_token_type": self.ACTOR_TOKEN_TYPE,
+ "options": urllib.parse.quote(json.dumps(self.ADDON_OPTIONS)),
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ }
+ request = self.make_mock_request(
+ status=http_client.OK, data=self.SUCCESS_RESPONSE
+ )
+
+ response = client.exchange_token(
+ request,
+ self.GRANT_TYPE,
+ self.SUBJECT_TOKEN,
+ self.SUBJECT_TOKEN_TYPE,
+ self.RESOURCE,
+ self.AUDIENCE,
+ self.SCOPES,
+ self.REQUESTED_TOKEN_TYPE,
+ self.ACTOR_TOKEN,
+ self.ACTOR_TOKEN_TYPE,
+ self.ADDON_OPTIONS,
+ self.ADDON_HEADERS,
+ )
+
+ self.assert_request_kwargs(request.call_args.kwargs, headers, request_data)
+ assert response == self.SUCCESS_RESPONSE
+
+ def test_exchange_token_partial_success_with_reqbody_auth(self):
+ """Test token exchange success with request body client authentication
+ using partial (required only) parameters.
+ """
+ client = self.make_client(self.CLIENT_AUTH_REQUEST_BODY)
+ headers = {"Content-Type": "application/x-www-form-urlencoded"}
+ request_data = {
+ "grant_type": self.GRANT_TYPE,
+ "audience": self.AUDIENCE,
+ "requested_token_type": self.REQUESTED_TOKEN_TYPE,
+ "subject_token": self.SUBJECT_TOKEN,
+ "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ }
+ request = self.make_mock_request(
+ status=http_client.OK, data=self.SUCCESS_RESPONSE
+ )
+
+ response = client.exchange_token(
+ request,
+ grant_type=self.GRANT_TYPE,
+ subject_token=self.SUBJECT_TOKEN,
+ subject_token_type=self.SUBJECT_TOKEN_TYPE,
+ audience=self.AUDIENCE,
+ requested_token_type=self.REQUESTED_TOKEN_TYPE,
+ )
+
+ self.assert_request_kwargs(request.call_args.kwargs, headers, request_data)
+ assert response == self.SUCCESS_RESPONSE
+
+ def test_exchange_token_non200_with_reqbody_auth(self):
+ """Test token exchange with POST request body client auth responding
+ with non-200 status.
+ """
+ client = self.make_client(self.CLIENT_AUTH_REQUEST_BODY)
+ request = self.make_mock_request(
+ status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE
+ )
+
+ with pytest.raises(exceptions.OAuthError) as excinfo:
+ client.exchange_token(
+ request,
+ self.GRANT_TYPE,
+ self.SUBJECT_TOKEN,
+ self.SUBJECT_TOKEN_TYPE,
+ self.RESOURCE,
+ self.AUDIENCE,
+ self.SCOPES,
+ self.REQUESTED_TOKEN_TYPE,
+ self.ACTOR_TOKEN,
+ self.ACTOR_TOKEN_TYPE,
+ self.ADDON_OPTIONS,
+ self.ADDON_HEADERS,
+ )
+
+ assert excinfo.match(
+ r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749"
+ )
diff --git a/tests/oauth2/test_utils.py b/tests/oauth2/test_utils.py
new file mode 100644
index 0000000..6de9ff5
--- /dev/null
+++ b/tests/oauth2/test_utils.py
@@ -0,0 +1,264 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import pytest
+
+from google.auth import exceptions
+from google.oauth2 import utils
+
+
+CLIENT_ID = "username"
+CLIENT_SECRET = "password"
+# Base64 encoding of "username:password"
+BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ="
+# Base64 encoding of "username:"
+BASIC_AUTH_ENCODING_SECRETLESS = "dXNlcm5hbWU6"
+
+
+class AuthHandler(utils.OAuthClientAuthHandler):
+ def __init__(self, client_auth=None):
+ super(AuthHandler, self).__init__(client_auth)
+
+ def apply_client_authentication_options(
+ self, headers, request_body=None, bearer_token=None
+ ):
+ return super(AuthHandler, self).apply_client_authentication_options(
+ headers, request_body, bearer_token
+ )
+
+
+class TestClientAuthentication(object):
+ @classmethod
+ def make_client_auth(cls, client_secret=None):
+ return utils.ClientAuthentication(
+ utils.ClientAuthType.basic, CLIENT_ID, client_secret
+ )
+
+ def test_initialization_with_client_secret(self):
+ client_auth = self.make_client_auth(CLIENT_SECRET)
+
+ assert client_auth.client_auth_type == utils.ClientAuthType.basic
+ assert client_auth.client_id == CLIENT_ID
+ assert client_auth.client_secret == CLIENT_SECRET
+
+ def test_initialization_no_client_secret(self):
+ client_auth = self.make_client_auth()
+
+ assert client_auth.client_auth_type == utils.ClientAuthType.basic
+ assert client_auth.client_id == CLIENT_ID
+ assert client_auth.client_secret is None
+
+
+class TestOAuthClientAuthHandler(object):
+ CLIENT_AUTH_BASIC = utils.ClientAuthentication(
+ utils.ClientAuthType.basic, CLIENT_ID, CLIENT_SECRET
+ )
+ CLIENT_AUTH_BASIC_SECRETLESS = utils.ClientAuthentication(
+ utils.ClientAuthType.basic, CLIENT_ID
+ )
+ CLIENT_AUTH_REQUEST_BODY = utils.ClientAuthentication(
+ utils.ClientAuthType.request_body, CLIENT_ID, CLIENT_SECRET
+ )
+ CLIENT_AUTH_REQUEST_BODY_SECRETLESS = utils.ClientAuthentication(
+ utils.ClientAuthType.request_body, CLIENT_ID
+ )
+
+ @classmethod
+ def make_oauth_client_auth_handler(cls, client_auth=None):
+ return AuthHandler(client_auth)
+
+ def test_apply_client_authentication_options_none(self):
+ headers = {"Content-Type": "application/json"}
+ request_body = {"foo": "bar"}
+ auth_handler = self.make_oauth_client_auth_handler()
+
+ auth_handler.apply_client_authentication_options(headers, request_body)
+
+ assert headers == {"Content-Type": "application/json"}
+ assert request_body == {"foo": "bar"}
+
+ def test_apply_client_authentication_options_basic(self):
+ headers = {"Content-Type": "application/json"}
+ request_body = {"foo": "bar"}
+ auth_handler = self.make_oauth_client_auth_handler(self.CLIENT_AUTH_BASIC)
+
+ auth_handler.apply_client_authentication_options(headers, request_body)
+
+ assert headers == {
+ "Content-Type": "application/json",
+ "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING),
+ }
+ assert request_body == {"foo": "bar"}
+
+ def test_apply_client_authentication_options_basic_nosecret(self):
+ headers = {"Content-Type": "application/json"}
+ request_body = {"foo": "bar"}
+ auth_handler = self.make_oauth_client_auth_handler(
+ self.CLIENT_AUTH_BASIC_SECRETLESS
+ )
+
+ auth_handler.apply_client_authentication_options(headers, request_body)
+
+ assert headers == {
+ "Content-Type": "application/json",
+ "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING_SECRETLESS),
+ }
+ assert request_body == {"foo": "bar"}
+
+ def test_apply_client_authentication_options_request_body(self):
+ headers = {"Content-Type": "application/json"}
+ request_body = {"foo": "bar"}
+ auth_handler = self.make_oauth_client_auth_handler(
+ self.CLIENT_AUTH_REQUEST_BODY
+ )
+
+ auth_handler.apply_client_authentication_options(headers, request_body)
+
+ assert headers == {"Content-Type": "application/json"}
+ assert request_body == {
+ "foo": "bar",
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ }
+
+ def test_apply_client_authentication_options_request_body_nosecret(self):
+ headers = {"Content-Type": "application/json"}
+ request_body = {"foo": "bar"}
+ auth_handler = self.make_oauth_client_auth_handler(
+ self.CLIENT_AUTH_REQUEST_BODY_SECRETLESS
+ )
+
+ auth_handler.apply_client_authentication_options(headers, request_body)
+
+ assert headers == {"Content-Type": "application/json"}
+ assert request_body == {
+ "foo": "bar",
+ "client_id": CLIENT_ID,
+ "client_secret": "",
+ }
+
+ def test_apply_client_authentication_options_request_body_no_body(self):
+ headers = {"Content-Type": "application/json"}
+ auth_handler = self.make_oauth_client_auth_handler(
+ self.CLIENT_AUTH_REQUEST_BODY
+ )
+
+ with pytest.raises(exceptions.OAuthError) as excinfo:
+ auth_handler.apply_client_authentication_options(headers)
+
+ assert excinfo.match(r"HTTP request does not support request-body")
+
+ def test_apply_client_authentication_options_bearer_token(self):
+ bearer_token = "ACCESS_TOKEN"
+ headers = {"Content-Type": "application/json"}
+ request_body = {"foo": "bar"}
+ auth_handler = self.make_oauth_client_auth_handler()
+
+ auth_handler.apply_client_authentication_options(
+ headers, request_body, bearer_token
+ )
+
+ assert headers == {
+ "Content-Type": "application/json",
+ "Authorization": "Bearer {}".format(bearer_token),
+ }
+ assert request_body == {"foo": "bar"}
+
+ def test_apply_client_authentication_options_bearer_and_basic(self):
+ bearer_token = "ACCESS_TOKEN"
+ headers = {"Content-Type": "application/json"}
+ request_body = {"foo": "bar"}
+ auth_handler = self.make_oauth_client_auth_handler(self.CLIENT_AUTH_BASIC)
+
+ auth_handler.apply_client_authentication_options(
+ headers, request_body, bearer_token
+ )
+
+ # Bearer token should have higher priority.
+ assert headers == {
+ "Content-Type": "application/json",
+ "Authorization": "Bearer {}".format(bearer_token),
+ }
+ assert request_body == {"foo": "bar"}
+
+ def test_apply_client_authentication_options_bearer_and_request_body(self):
+ bearer_token = "ACCESS_TOKEN"
+ headers = {"Content-Type": "application/json"}
+ request_body = {"foo": "bar"}
+ auth_handler = self.make_oauth_client_auth_handler(
+ self.CLIENT_AUTH_REQUEST_BODY
+ )
+
+ auth_handler.apply_client_authentication_options(
+ headers, request_body, bearer_token
+ )
+
+ # Bearer token should have higher priority.
+ assert headers == {
+ "Content-Type": "application/json",
+ "Authorization": "Bearer {}".format(bearer_token),
+ }
+ assert request_body == {"foo": "bar"}
+
+
+def test__handle_error_response_code_only():
+ error_resp = {"error": "unsupported_grant_type"}
+ response_data = json.dumps(error_resp)
+
+ with pytest.raises(exceptions.OAuthError) as excinfo:
+ utils.handle_error_response(response_data)
+
+ assert excinfo.match(r"Error code unsupported_grant_type")
+
+
+def test__handle_error_response_code_description():
+ error_resp = {
+ "error": "unsupported_grant_type",
+ "error_description": "The provided grant_type is unsupported",
+ }
+ response_data = json.dumps(error_resp)
+
+ with pytest.raises(exceptions.OAuthError) as excinfo:
+ utils.handle_error_response(response_data)
+
+ assert excinfo.match(
+ r"Error code unsupported_grant_type: The provided grant_type is unsupported"
+ )
+
+
+def test__handle_error_response_code_description_uri():
+ error_resp = {
+ "error": "unsupported_grant_type",
+ "error_description": "The provided grant_type is unsupported",
+ "error_uri": "https://tools.ietf.org/html/rfc6749",
+ }
+ response_data = json.dumps(error_resp)
+
+ with pytest.raises(exceptions.OAuthError) as excinfo:
+ utils.handle_error_response(response_data)
+
+ assert excinfo.match(
+ r"Error code unsupported_grant_type: The provided grant_type is unsupported - https://tools.ietf.org/html/rfc6749"
+ )
+
+
+def test__handle_error_response_non_json():
+ response_data = "Oops, something wrong happened"
+
+ with pytest.raises(exceptions.OAuthError) as excinfo:
+ utils.handle_error_response(response_data)
+
+ assert excinfo.match(r"Oops, something wrong happened")
diff --git a/tests/test__default.py b/tests/test__default.py
index 74511f9..ef6cb78 100644
--- a/tests/test__default.py
+++ b/tests/test__default.py
@@ -20,10 +20,13 @@
from google.auth import _default
from google.auth import app_engine
+from google.auth import aws
from google.auth import compute_engine
from google.auth import credentials
from google.auth import environment_vars
from google.auth import exceptions
+from google.auth import external_account
+from google.auth import identity_pool
from google.oauth2 import service_account
import google.oauth2.credentials
@@ -49,6 +52,34 @@
with open(SERVICE_ACCOUNT_FILE) as fh:
SERVICE_ACCOUNT_FILE_DATA = json.load(fh)
+SUBJECT_TOKEN_TEXT_FILE = os.path.join(DATA_DIR, "external_subject_token.txt")
+TOKEN_URL = "https://sts.googleapis.com/v1/token"
+AUDIENCE = "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID"
+REGION_URL = "http://169.254.169.254/latest/meta-data/placement/availability-zone"
+SECURITY_CREDS_URL = "http://169.254.169.254/latest/meta-data/iam/security-credentials"
+CRED_VERIFICATION_URL = (
+ "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"
+)
+IDENTITY_POOL_DATA = {
+ "type": "external_account",
+ "audience": AUDIENCE,
+ "subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
+ "token_url": TOKEN_URL,
+ "credential_source": {"file": SUBJECT_TOKEN_TEXT_FILE},
+}
+AWS_DATA = {
+ "type": "external_account",
+ "audience": AUDIENCE,
+ "subject_token_type": "urn:ietf:params:aws:token-type:aws4_request",
+ "token_url": TOKEN_URL,
+ "credential_source": {
+ "environment_id": "aws1",
+ "region_url": REGION_URL,
+ "url": SECURITY_CREDS_URL,
+ "regional_cred_verification_url": CRED_VERIFICATION_URL,
+ },
+}
+
MOCK_CREDENTIALS = mock.Mock(spec=credentials.CredentialsWithQuotaProject)
MOCK_CREDENTIALS.with_quota_project.return_value = MOCK_CREDENTIALS
@@ -57,6 +88,12 @@
return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
autospec=True,
)
+EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH = mock.patch.object(
+ external_account.Credentials,
+ "get_project_id",
+ return_value=mock.sentinel.project_id,
+ autospec=True,
+)
def test_load_credentials_from_missing_file():
@@ -185,6 +222,92 @@
assert excinfo.match(r"missing fields")
+@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH
+def test_load_credentials_from_file_external_account_identity_pool(
+ get_project_id, tmpdir
+):
+ config_file = tmpdir.join("config.json")
+ config_file.write(json.dumps(IDENTITY_POOL_DATA))
+ credentials, project_id = _default.load_credentials_from_file(str(config_file))
+
+ assert isinstance(credentials, identity_pool.Credentials)
+ assert project_id is mock.sentinel.project_id
+ assert get_project_id.called
+
+
+@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH
+def test_load_credentials_from_file_external_account_aws(get_project_id, tmpdir):
+ config_file = tmpdir.join("config.json")
+ config_file.write(json.dumps(AWS_DATA))
+ credentials, project_id = _default.load_credentials_from_file(str(config_file))
+
+ assert isinstance(credentials, aws.Credentials)
+ assert project_id is mock.sentinel.project_id
+ assert get_project_id.called
+
+
+@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH
+def test_load_credentials_from_file_external_account_with_user_and_default_scopes(
+ get_project_id, tmpdir
+):
+ config_file = tmpdir.join("config.json")
+ config_file.write(json.dumps(IDENTITY_POOL_DATA))
+ credentials, project_id = _default.load_credentials_from_file(
+ str(config_file),
+ scopes=["https://www.google.com/calendar/feeds"],
+ default_scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+
+ assert isinstance(credentials, identity_pool.Credentials)
+ assert project_id is mock.sentinel.project_id
+ assert credentials.scopes == ["https://www.google.com/calendar/feeds"]
+ assert credentials.default_scopes == [
+ "https://www.googleapis.com/auth/cloud-platform"
+ ]
+
+
+@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH
+def test_load_credentials_from_file_external_account_with_quota_project(
+ get_project_id, tmpdir
+):
+ config_file = tmpdir.join("config.json")
+ config_file.write(json.dumps(IDENTITY_POOL_DATA))
+ credentials, project_id = _default.load_credentials_from_file(
+ str(config_file), quota_project_id="project-foo"
+ )
+
+ assert isinstance(credentials, identity_pool.Credentials)
+ assert project_id is mock.sentinel.project_id
+ assert credentials.quota_project_id == "project-foo"
+
+
+def test_load_credentials_from_file_external_account_bad_format(tmpdir):
+    filename = tmpdir.join("external_account_bad.json")
+    filename.write(json.dumps({"type": "external_account"}))
+    # Only "type" is set; the remaining external account fields are absent.
+    with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+        _default.load_credentials_from_file(str(filename))
+    # Substring check, not regex: tmp paths may hold regex metacharacters.
+    assert "Failed to load external account credentials from {}".format(
+        str(filename)
+    ) in str(excinfo.value)
+
+
+@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH
+def test_load_credentials_from_file_external_account_explicit_request(
+ get_project_id, tmpdir
+):
+ config_file = tmpdir.join("config.json")
+ config_file.write(json.dumps(IDENTITY_POOL_DATA))
+ credentials, project_id = _default.load_credentials_from_file(
+ str(config_file), request=mock.sentinel.request
+ )
+
+ assert isinstance(credentials, identity_pool.Credentials)
+ assert project_id is mock.sentinel.project_id
+ get_project_id.assert_called_with(credentials, request=mock.sentinel.request)
+
+
@mock.patch.dict(os.environ, {}, clear=True)
def test__get_explicit_environ_credentials_no_env():
assert _default._get_explicit_environ_credentials() == (None, None)
@@ -198,7 +321,34 @@
assert credentials is MOCK_CREDENTIALS
assert project_id is mock.sentinel.project_id
- load.assert_called_with("filename")
+ load.assert_called_with(
+ "filename",
+ scopes=None,
+ default_scopes=None,
+ quota_project_id=None,
+ request=None,
+ )
+
+
+@LOAD_FILE_PATCH
+def test__get_explicit_environ_credentials_with_scopes_and_request(load, monkeypatch):
+ scopes = ["one", "two"]
+ monkeypatch.setenv(environment_vars.CREDENTIALS, "filename")
+
+ credentials, project_id = _default._get_explicit_environ_credentials(
+ request=mock.sentinel.request, scopes=scopes
+ )
+
+ assert credentials is MOCK_CREDENTIALS
+ assert project_id is mock.sentinel.project_id
+ # Request and scopes should be propagated.
+ load.assert_called_with(
+ "filename",
+ scopes=scopes,
+ default_scopes=None,
+ quota_project_id=None,
+ request=mock.sentinel.request,
+ )
@LOAD_FILE_PATCH
@@ -503,3 +653,70 @@
sys.modules["google.auth.compute_engine"] = None
sys.modules["google.auth.app_engine"] = None
assert _default.default() == (MOCK_CREDENTIALS, mock.sentinel.project_id)
+
+
+@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH
+def test_default_environ_external_credentials(get_project_id, monkeypatch, tmpdir):
+ config_file = tmpdir.join("config.json")
+ config_file.write(json.dumps(IDENTITY_POOL_DATA))
+ monkeypatch.setenv(environment_vars.CREDENTIALS, str(config_file))
+
+ credentials, project_id = _default.default()
+
+ assert isinstance(credentials, identity_pool.Credentials)
+ assert project_id is mock.sentinel.project_id
+
+
+@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH
+def test_default_environ_external_credentials_with_user_and_default_scopes_and_quota_project_id(
+ get_project_id, monkeypatch, tmpdir
+):
+ config_file = tmpdir.join("config.json")
+ config_file.write(json.dumps(IDENTITY_POOL_DATA))
+ monkeypatch.setenv(environment_vars.CREDENTIALS, str(config_file))
+
+ credentials, project_id = _default.default(
+ scopes=["https://www.google.com/calendar/feeds"],
+ default_scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ quota_project_id="project-foo",
+ )
+
+ assert isinstance(credentials, identity_pool.Credentials)
+ assert project_id is mock.sentinel.project_id
+ assert credentials.quota_project_id == "project-foo"
+ assert credentials.scopes == ["https://www.google.com/calendar/feeds"]
+ assert credentials.default_scopes == [
+ "https://www.googleapis.com/auth/cloud-platform"
+ ]
+
+
+@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH
+def test_default_environ_external_credentials_explicit_request(
+ get_project_id, monkeypatch, tmpdir
+):
+ config_file = tmpdir.join("config.json")
+ config_file.write(json.dumps(IDENTITY_POOL_DATA))
+ monkeypatch.setenv(environment_vars.CREDENTIALS, str(config_file))
+
+ credentials, project_id = _default.default(request=mock.sentinel.request)
+
+ assert isinstance(credentials, identity_pool.Credentials)
+ assert project_id is mock.sentinel.project_id
+ # default() will initialize new credentials via with_scopes_if_required
+ # and potentially with_quota_project.
+ # As a result the caller of get_project_id() will not match the returned
+ # credentials.
+ get_project_id.assert_called_with(mock.ANY, request=mock.sentinel.request)
+
+
+def test_default_environ_external_credentials_bad_format(monkeypatch, tmpdir):
+    filename = tmpdir.join("external_account_bad.json")
+    filename.write(json.dumps({"type": "external_account"}))
+    monkeypatch.setenv(environment_vars.CREDENTIALS, str(filename))
+    # default() must surface the load failure for the env-var config path.
+    with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+        _default.default()
+    # Substring check, not regex: tmp paths may hold regex metacharacters.
+    assert "Failed to load external account credentials from {}".format(
+        str(filename)
+    ) in str(excinfo.value)
diff --git a/tests/test_aws.py b/tests/test_aws.py
new file mode 100644
index 0000000..9a8f98e
--- /dev/null
+++ b/tests/test_aws.py
@@ -0,0 +1,1434 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+
+import mock
+import pytest
+from six.moves import http_client
+from six.moves import urllib
+
+from google.auth import _helpers
+from google.auth import aws
+from google.auth import environment_vars
+from google.auth import exceptions
+from google.auth import transport
+
+
+CLIENT_ID = "username"
+CLIENT_SECRET = "password"
+# Base64 encoding of "username:password".
+BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ="
+SERVICE_ACCOUNT_EMAIL = "service-1234@service-name.iam.gserviceaccount.com"
+SERVICE_ACCOUNT_IMPERSONATION_URL = (
+    "https://us-east1-iamcredentials.googleapis.com/v1/projects/-"
+    "/serviceAccounts/{}:generateAccessToken"
+).format(SERVICE_ACCOUNT_EMAIL)
+QUOTA_PROJECT_ID = "QUOTA_PROJECT_ID"
+SCOPES = ["scope1", "scope2"]
+TOKEN_URL = "https://sts.googleapis.com/v1/token"
+SUBJECT_TOKEN_TYPE = "urn:ietf:params:aws:token-type:aws4_request"
+AUDIENCE = (
+    "//iam.googleapis.com/projects/123456/locations/global"
+    "/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID"
+)
+# AWS metadata server endpoints for the region and the role credentials.
+_AWS_METADATA_URL = "http://169.254.169.254/latest/meta-data"
+REGION_URL = _AWS_METADATA_URL + "/placement/availability-zone"
+SECURITY_CREDS_URL = _AWS_METADATA_URL + "/iam/security-credentials"
+# Regional AWS STS GetCallerIdentity URL template ({region} placeholder).
+CRED_VERIFICATION_URL = (
+    "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"
+)
+# Sample AWS security credentials to be used with tests that require a session token.
+ACCESS_KEY_ID = "ASIARD4OQDT6A77FR3CL"
+SECRET_ACCESS_KEY = "Y8AfSaucF37G4PpvfguKZ3/l7Id4uocLXxX0+VTx"
+TOKEN = "IQoJb3JpZ2luX2VjEIz//////////wEaCXVzLWVhc3QtMiJGMEQCIH7MHX/Oy/OB8OlLQa9GrqU1B914+iMikqWQW7vPCKlgAiA/Lsv8Jcafn14owfxXn95FURZNKaaphj0ykpmS+Ki+CSq0AwhlEAAaDDA3NzA3MTM5MTk5NiIMx9sAeP1ovlMTMKLjKpEDwuJQg41/QUKx0laTZYjPlQvjwSqS3OB9P1KAXPWSLkliVMMqaHqelvMF/WO/glv3KwuTfQsavRNs3v5pcSEm4SPO3l7mCs7KrQUHwGP0neZhIKxEXy+Ls//1C/Bqt53NL+LSbaGv6RPHaX82laz2qElphg95aVLdYgIFY6JWV5fzyjgnhz0DQmy62/Vi8pNcM2/VnxeCQ8CC8dRDSt52ry2v+nc77vstuI9xV5k8mPtnaPoJDRANh0bjwY5Sdwkbp+mGRUJBAQRlNgHUJusefXQgVKBCiyJY4w3Csd8Bgj9IyDV+Azuy1jQqfFZWgP68LSz5bURyIjlWDQunO82stZ0BgplKKAa/KJHBPCp8Qi6i99uy7qh76FQAqgVTsnDuU6fGpHDcsDSGoCls2HgZjZFPeOj8mmRhFk1Xqvkbjuz8V1cJk54d3gIJvQt8gD2D6yJQZecnuGWd5K2e2HohvCc8Fc9kBl1300nUJPV+k4tr/A5R/0QfEKOZL1/k5lf1g9CREnrM8LVkGxCgdYMxLQow1uTL+QU67AHRRSp5PhhGX4Rek+01vdYSnJCMaPhSEgcLqDlQkhk6MPsyT91QMXcWmyO+cAZwUPwnRamFepuP4K8k2KVXs/LIJHLELwAZ0ekyaS7CptgOqS7uaSTFG3U+vzFZLEnGvWQ7y9IPNQZ+Dffgh4p3vF4J68y9049sI6Sr5d5wbKkcbm8hdCDHZcv4lnqohquPirLiFQ3q7B17V9krMPu3mz1cg4Ekgcrn/E09NTsxAqD8NcZ7C7ECom9r+X3zkDOxaajW6hu3Az8hGlyylDaMiFfRbBJpTIlxp7jfa7CxikNgNtEKLH9iCzvuSg2vhA=="
+# To avoid json.dumps() differing behavior from one version to other,
+# the JSON payload is hardcoded.
+REQUEST_PARAMS = '{"KeySchema":[{"KeyType":"HASH","AttributeName":"Id"}],"TableName":"TestTable","AttributeDefinitions":[{"AttributeName":"Id","AttributeType":"S"}],"ProvisionedThroughput":{"WriteCapacityUnits":5,"ReadCapacityUnits":5}}'
+
+
+def _authorization(key_id, scope, signed_header_names, signature):
+    """Return an AWS4-HMAC-SHA256 Authorization header value.
+
+    ``scope`` is the "<date>/<region>/<service>" part of the credential scope.
+    """
+    return (
+        "AWS4-HMAC-SHA256 Credential={}/{}/aws4_request, "
+        "SignedHeaders={}, Signature={}"
+    ).format(key_id, scope, signed_header_names, signature)
+
+
+def _botocore_case(
+    method, url, signed_header_names, signature, extra_headers=None, data=None
+):
+    """Build one fixture tuple taken from botocore's aws4_testsuite.
+
+    Every such case is signed in us-east-1 at 2011-09-09T23:36:00Z with the
+    AKIDEXAMPLE key pair, sends a ``date`` header and targets host.foo.com.
+    Fresh dicts are built on every call so fixtures never share state.
+    """
+    headers = {"date": "Mon, 09 Sep 2011 23:36:00 GMT"}
+    headers.update(extra_headers or {})
+    original = {"method": method, "url": url, "headers": dict(headers)}
+    signed = {
+        "url": url,
+        "method": method,
+        "headers": dict(
+            headers,
+            Authorization=_authorization(
+                "AKIDEXAMPLE", "20110909/us-east-1/host", signed_header_names, signature
+            ),
+            host="host.foo.com",
+        ),
+    }
+    if data is not None:
+        original["data"] = data
+        signed["data"] = data
+    credentials = {
+        "access_key_id": "AKIDEXAMPLE",
+        "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
+    }
+    return ("us-east-1", "2011-09-09T23:36:00Z", credentials, original, signed)
+
+
+def _us_east_2_case(
+    method,
+    url,
+    service,
+    signed_header_names,
+    signature,
+    with_token=True,
+    extra_headers=None,
+    data=None,
+):
+    """Build one fixture signed in us-east-2 at 2020-08-11T06:55:22Z.
+
+    Uses ACCESS_KEY_ID/SECRET_ACCESS_KEY, plus TOKEN as the session token
+    (and the matching x-amz-security-token header) when ``with_token``.
+    The signed host is "<service>.us-east-2.amazonaws.com".
+    """
+    credentials = {
+        "access_key_id": ACCESS_KEY_ID,
+        "secret_access_key": SECRET_ACCESS_KEY,
+    }
+    original = {"method": method, "url": url}
+    signed_headers = {
+        "Authorization": _authorization(
+            ACCESS_KEY_ID,
+            "20200811/us-east-2/" + service,
+            signed_header_names,
+            signature,
+        ),
+        "host": "{}.us-east-2.amazonaws.com".format(service),
+        "x-amz-date": "20200811T065522Z",
+    }
+    if with_token:
+        credentials["security_token"] = TOKEN
+        signed_headers["x-amz-security-token"] = TOKEN
+    if extra_headers is not None:
+        original["headers"] = dict(extra_headers)
+        signed_headers.update(extra_headers)
+    signed = {"url": url, "method": method, "headers": signed_headers}
+    if data is not None:
+        original["data"] = data
+        signed["data"] = data
+    return ("us-east-2", "2020-08-11T06:55:22Z", credentials, original, signed)
+
+
+# Each tuple contains the following entries:
+# region, time, credentials, original_request, signed_request
+#
+# The "(name)" in each botocore case comment refers to the file pair
+# https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/<name>.req
+# https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/<name>.sreq
+TEST_FIXTURES = [
+    # GET request (AWS botocore tests: get-vanilla).
+    _botocore_case(
+        "GET",
+        "https://host.foo.com",
+        "date;host",
+        "b27ccfbfa7df52a200ff74193ca6e32d4b48b8856fab7ebf1c595d0670a7e470",
+    ),
+    # GET request with relative path (get-relative-relative).
+    _botocore_case(
+        "GET",
+        "https://host.foo.com/foo/bar/../..",
+        "date;host",
+        "b27ccfbfa7df52a200ff74193ca6e32d4b48b8856fab7ebf1c595d0670a7e470",
+    ),
+    # GET request with /./ path (get-slash-dot-slash).
+    _botocore_case(
+        "GET",
+        "https://host.foo.com/./",
+        "date;host",
+        "b27ccfbfa7df52a200ff74193ca6e32d4b48b8856fab7ebf1c595d0670a7e470",
+    ),
+    # GET request with pointless dot path (get-slash-pointless-dot).
+    _botocore_case(
+        "GET",
+        "https://host.foo.com/./foo",
+        "date;host",
+        "910e4d6c9abafaf87898e1eb4c929135782ea25bb0279703146455745391e63a",
+    ),
+    # GET request with utf8 path (get-utf8).
+    _botocore_case(
+        "GET",
+        "https://host.foo.com/%E1%88%B4",
+        "date;host",
+        "8d6634c189aa8c75c2e51e106b6b5121bed103fdb351f7d7d4381c738823af74",
+    ),
+    # GET request with duplicate query key (get-vanilla-query-order-key-case).
+    _botocore_case(
+        "GET",
+        "https://host.foo.com/?foo=Zoo&foo=aha",
+        "date;host",
+        "be7148d34ebccdc6423b19085378aa0bee970bdc61d144bd1a8c48c33079ab09",
+    ),
+    # GET request with duplicate out of order query key
+    # (get-vanilla-query-order-value).
+    _botocore_case(
+        "GET",
+        "https://host.foo.com/?foo=b&foo=a",
+        "date;host",
+        "feb926e49e382bec75c9d7dcb2a1b6dc8aa50ca43c25d2bc51143768c0875acc",
+    ),
+    # GET request with utf8 query (get-vanilla-ut8-query).
+    _botocore_case(
+        "GET",
+        "https://host.foo.com/?{}=bar".format(urllib.parse.unquote("%E1%88%B4")),
+        "date;host",
+        "6fb359e9a05394cc7074e0feb42573a2601abc0c869a953e8c5c12e4e01f1a8c",
+    ),
+    # POST request with sorted headers (post-header-key-sort).
+    _botocore_case(
+        "POST",
+        "https://host.foo.com/",
+        "date;host;zoo",
+        "b7a95a52518abbca0964a999a880429ab734f35ebbf1235bd79a5de87756dc4a",
+        extra_headers={"ZOO": "zoobar"},
+    ),
+    # POST request with upper case header value from AWS Python test harness
+    # (post-header-value-case).
+    _botocore_case(
+        "POST",
+        "https://host.foo.com/",
+        "date;host;zoo",
+        "273313af9d0c265c531e11db70bbd653f3ba074c1009239e8559d3987039cad7",
+        extra_headers={"zoo": "ZOOBAR"},
+    ),
+    # POST request with header and no body (get-header-value-trim).
+    _botocore_case(
+        "POST",
+        "https://host.foo.com/",
+        "date;host;p",
+        "debf546796015d6f6ded8626f5ce98597c33b47b9164cf6b17b4642036fcb592",
+        extra_headers={"p": "phfft"},
+    ),
+    # POST request with body and no header (post-x-www-form-urlencoded).
+    _botocore_case(
+        "POST",
+        "https://host.foo.com/",
+        "content-type;date;host",
+        "5a15b22cf462f047318703b92e6f4f38884e4a7ab7b1d6426ca46a8bd1c26cbc",
+        extra_headers={"Content-Type": "application/x-www-form-urlencoded"},
+        data="foo=bar",
+    ),
+    # POST request with querystring (post-vanilla-query).
+    _botocore_case(
+        "POST",
+        "https://host.foo.com/?foo=bar",
+        "date;host",
+        "b6e3b79003ce0743a491606ba1035a804593b0efb1e20a11cba83f8c25a57a92",
+    ),
+    # GET request with session token credentials.
+    _us_east_2_case(
+        "GET",
+        "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15",
+        "ec2",
+        "host;x-amz-date;x-amz-security-token",
+        "631ea80cddfaa545fdadb120dc92c9f18166e38a5c47b50fab9fce476e022855",
+    ),
+    # POST request with session token credentials.
+    _us_east_2_case(
+        "POST",
+        "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+        "sts",
+        "host;x-amz-date;x-amz-security-token",
+        "73452984e4a880ffdc5c392355733ec3f5ba310d5e0609a89244440cadfe7a7a",
+    ),
+    # POST request with computed x-amz-date and no data.
+    _us_east_2_case(
+        "POST",
+        "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+        "sts",
+        "host;x-amz-date",
+        "d095ba304919cd0d5570ba8a3787884ee78b860f268ed040ba23831d55536d56",
+        with_token=False,
+    ),
+    # POST request with session token and additional headers/data.
+    _us_east_2_case(
+        "POST",
+        "https://dynamodb.us-east-2.amazonaws.com/",
+        "dynamodb",
+        "content-type;host;x-amz-date;x-amz-security-token;x-amz-target",
+        "fdaa5b9cc9c86b80fe61eaf504141c0b3523780349120f2bd8145448456e0385",
+        extra_headers={
+            "Content-Type": "application/x-amz-json-1.0",
+            "x-amz-target": "DynamoDB_20120810.CreateTable",
+        },
+        data=REQUEST_PARAMS,
+    ),
+]
+
+
+class TestRequestSigner(object):
+    @pytest.mark.parametrize(
+        "region, time, credentials, original_request, signed_request", TEST_FIXTURES
+    )
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_get_request_options(
+        self, utcnow, region, time, credentials, original_request, signed_request
+    ):
+        """Signing a fixture request at its frozen time reproduces the fixture."""
+        # Freeze the clock so the signing timestamp equals the fixture's time.
+        utcnow.return_value = datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%SZ")
+        signer = aws.RequestSigner(region)
+        url, method, data, headers = (
+            original_request.get(field) for field in ("url", "method", "data", "headers")
+        )
+
+        assert (
+            signer.get_request_options(credentials, url, method, data, headers)
+            == signed_request
+        )
+
+    @staticmethod
+    def _assert_invalid_url(url):
+        # Each malformed-URL case signs a POST in us-east-2 with a token-less
+        # key pair and expects the same ValueError message.
+        signer = aws.RequestSigner("us-east-2")
+        credentials = {
+            "access_key_id": ACCESS_KEY_ID,
+            "secret_access_key": SECRET_ACCESS_KEY,
+        }
+        with pytest.raises(ValueError) as excinfo:
+            signer.get_request_options(credentials, url, "POST")
+
+        assert excinfo.match(r"Invalid AWS service URL")
+
+    def test_get_request_options_with_missing_scheme_url(self):
+        self._assert_invalid_url("invalid")
+
+    def test_get_request_options_with_invalid_scheme_url(self):
+        self._assert_invalid_url("http://invalid")
+
+    def test_get_request_options_with_missing_hostname_url(self):
+        self._assert_invalid_url("https://")
+
+
+class TestCredentials(object):
+    AWS_REGION = "us-east-2"
+    AWS_ROLE = "gcp-aws-role"
+    # Security credentials payload built from the module's sample key pair.
+    AWS_SECURITY_CREDENTIALS_RESPONSE = dict(
+        AccessKeyId=ACCESS_KEY_ID, SecretAccessKey=SECRET_ACCESS_KEY, Token=TOKEN
+    )
+    AWS_SIGNATURE_TIME = "2020-08-11T06:55:22Z"
+    # Baseline valid "aws1" credential_source; tests copy and tweak it.
+    CREDENTIAL_SOURCE = dict(
+        environment_id="aws1",
+        region_url=REGION_URL,
+        url=SECURITY_CREDS_URL,
+        regional_cred_verification_url=CRED_VERIFICATION_URL,
+    )
+    # Sample successful token response payload.
+    SUCCESS_RESPONSE = dict(
+        access_token="ACCESS_TOKEN",
+        issued_token_type="urn:ietf:params:oauth:token-type:access_token",
+        token_type="Bearer",
+        expires_in=3600,
+        scope=" ".join(SCOPES),
+    )
+
+    @classmethod
+    def make_serialized_aws_signed_request(
+        cls,
+        aws_security_credentials,
+        region_name="us-east-2",
+        url="https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
+    ):
+        """Utility to generate serialized AWS signed requests.
+
+        This makes it easy to assert generated subject tokens based on the
+        provided AWS security credentials, regions and AWS STS endpoint.
+
+        Args:
+            aws_security_credentials (Mapping[str, str]): Credentials handed
+                to ``aws.RequestSigner.get_request_options``. The
+                ``x-amz-security-token`` header is emitted only when this
+                mapping contains ``security_token``.
+            region_name (str): Region the request is signed for.
+            url (str): The endpoint URL that is signed with a POST method.
+
+        Returns:
+            str: URL-quoted, compact, key-sorted JSON of the signed request:
+                its url, method and a list of ``{"key", "value"}`` headers in
+                the order Authorization, host, x-amz-date, optional
+                x-amz-security-token, x-goog-cloud-target-resource.
+        """
+        request_signer = aws.RequestSigner(region_name)
+        signed_request = request_signer.get_request_options(
+            aws_security_credentials, url, "POST"
+        )
+        signed_headers = signed_request.get("headers")
+
+        header_names = ["Authorization", "host", "x-amz-date"]
+        # Include security token if available.
+        if "security_token" in aws_security_credentials:
+            header_names.append("x-amz-security-token")
+        headers = [
+            {"key": name, "value": signed_headers.get(name)} for name in header_names
+        ]
+        # Append x-goog-cloud-target-resource header.
+        headers.append({"key": "x-goog-cloud-target-resource", "value": AUDIENCE})
+
+        reformatted_signed_request = {
+            "url": signed_request.get("url"),
+            "method": signed_request.get("method"),
+            "headers": headers,
+        }
+        return urllib.parse.quote(
+            json.dumps(
+                reformatted_signed_request, separators=(",", ":"), sort_keys=True
+            )
+        )
+
+    @classmethod
+    def make_mock_request(
+        cls,
+        region_status=None,
+        region_name=None,
+        role_status=None,
+        role_name=None,
+        security_credentials_status=None,
+        security_credentials_data=None,
+        token_status=None,
+        token_data=None,
+        impersonation_status=None,
+        impersonation_data=None,
+    ):
+        """Utility function to generate a mock HTTP request object.
+
+        This will facilitate testing various edge cases by specifying how the
+        various endpoints will respond while generating a Google Access token
+        in an AWS environment.
+
+        Every truthy ``*_status`` argument contributes one canned response,
+        always in this order: AWS region, AWS role name, AWS security
+        credentials, GCP token exchange, service account impersonation. The
+        returned mock ``transport.Request`` yields them on successive calls.
+
+        Response bodies:
+            * region: ``region_name`` with a ``b`` suffix, UTF-8 encoded
+              (left unset when ``region_name`` is falsy);
+            * role: ``role_name`` UTF-8 encoded (unset when falsy);
+            * security credentials: JSON of ``security_credentials_data``
+              (unset when falsy);
+            * token / impersonation: JSON of ``token_data`` /
+              ``impersonation_data``, always set (``null`` for None).
+        """
+
+        def make_response(status, body=None):
+            # One autospec'd transport.Response; ``data`` stays unset when no
+            # body is supplied.
+            response = mock.create_autospec(transport.Response, instance=True)
+            response.status = status
+            if body is not None:
+                response.data = body
+            return response
+
+        responses = []
+        if region_status:
+            # AWS region request.
+            body = "{}b".format(region_name).encode("utf-8") if region_name else None
+            responses.append(make_response(region_status, body))
+
+        if role_status:
+            # AWS role name request.
+            body = role_name.encode("utf-8") if role_name else None
+            responses.append(make_response(role_status, body))
+
+        if security_credentials_status:
+            # AWS security credentials request.
+            body = (
+                json.dumps(security_credentials_data).encode("utf-8")
+                if security_credentials_data
+                else None
+            )
+            responses.append(make_response(security_credentials_status, body))
+
+        if token_status:
+            # GCP token exchange request.
+            body = json.dumps(token_data).encode("utf-8")
+            responses.append(make_response(token_status, body))
+
+        if impersonation_status:
+            # Service account impersonation request.
+            body = json.dumps(impersonation_data).encode("utf-8")
+            responses.append(make_response(impersonation_status, body))
+
+        request = mock.create_autospec(transport.Request)
+        request.side_effect = responses
+        return request
+
+    @classmethod
+    def make_credentials(
+        cls,
+        credential_source,
+        client_id=None,
+        client_secret=None,
+        quota_project_id=None,
+        scopes=None,
+        default_scopes=None,
+        service_account_impersonation_url=None,
+    ):
+        """Construct aws.Credentials for tests.
+
+        Audience, subject token type and token URL are pinned to the module
+        constants; every other setting is taken from the arguments.
+        """
+        pinned = dict(
+            audience=AUDIENCE,
+            subject_token_type=SUBJECT_TOKEN_TYPE,
+            token_url=TOKEN_URL,
+        )
+        return aws.Credentials(
+            service_account_impersonation_url=service_account_impersonation_url,
+            credential_source=credential_source,
+            client_id=client_id,
+            client_secret=client_secret,
+            quota_project_id=quota_project_id,
+            scopes=scopes,
+            default_scopes=default_scopes,
+            **pinned
+        )
+
+    @classmethod
+    def assert_aws_metadata_request_kwargs(cls, request_kwargs, url, headers=None):
+        """Check one recorded call to an AWS metadata endpoint.
+
+        The call must be a body-less GET to ``url``; it carries ``headers``
+        when they are given and truthy, otherwise no ``headers`` kwarg at all.
+        """
+        assert request_kwargs["url"] == url
+        # All used AWS metadata server endpoints use GET HTTP method.
+        assert request_kwargs["method"] == "GET"
+        if not headers:
+            assert "headers" not in request_kwargs
+        else:
+            assert request_kwargs["headers"] == headers
+        # None of the endpoints used require any data in request.
+        assert "body" not in request_kwargs
+
+    @classmethod
+    def assert_token_request_kwargs(
+        cls, request_kwargs, headers, request_data, token_url=TOKEN_URL
+    ):
+        """Check one recorded token exchange request.
+
+        It must be a POST to ``token_url`` with exactly ``headers`` and a
+        URL-encoded (bytes) body whose UTF-8-decoded key/value pairs match
+        ``request_data`` one for one.
+        """
+        assert request_kwargs["url"] == token_url
+        assert request_kwargs["method"] == "POST"
+        assert request_kwargs["headers"] == headers
+        raw_body = request_kwargs["body"]
+        assert raw_body is not None
+        pairs = urllib.parse.parse_qsl(raw_body)
+        assert len(pairs) == len(request_data.keys())
+        for key, value in pairs:
+            assert value.decode("utf-8") == request_data[key.decode("utf-8")]
+
+    @classmethod
+    def assert_impersonation_request_kwargs(
+        cls,
+        request_kwargs,
+        headers,
+        request_data,
+        service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+    ):
+        """Check one recorded service account impersonation request.
+
+        It must be a POST to ``service_account_impersonation_url`` with
+        exactly ``headers`` and a UTF-8 JSON body equal to ``request_data``.
+        """
+        assert request_kwargs["url"] == service_account_impersonation_url
+        assert request_kwargs["method"] == "POST"
+        assert request_kwargs["headers"] == headers
+        raw_body = request_kwargs["body"]
+        assert raw_body is not None
+        assert json.loads(raw_body.decode("utf-8")) == request_data
+
+    @mock.patch.object(aws.Credentials, "__init__", return_value=None)
+    def test_from_info_full_options(self, mock_init):
+        """from_info() forwards every supported field to the constructor."""
+        info = {
+            "audience": AUDIENCE,
+            "subject_token_type": SUBJECT_TOKEN_TYPE,
+            "token_url": TOKEN_URL,
+            "service_account_impersonation_url": SERVICE_ACCOUNT_IMPERSONATION_URL,
+            "client_id": CLIENT_ID,
+            "client_secret": CLIENT_SECRET,
+            "quota_project_id": QUOTA_PROJECT_ID,
+            "credential_source": self.CREDENTIAL_SOURCE,
+        }
+        # Snapshot the expectation before from_info() sees the dict.
+        expected_kwargs = dict(info)
+
+        credentials = aws.Credentials.from_info(info)
+
+        # Confirm aws.Credentials instance initialized with the expected parameters.
+        assert isinstance(credentials, aws.Credentials)
+        mock_init.assert_called_once_with(**expected_kwargs)
+
+    @mock.patch.object(aws.Credentials, "__init__", return_value=None)
+    def test_from_info_required_options_only(self, mock_init):
+        """Optional fields missing from the info dict are passed as None."""
+        info = {
+            "audience": AUDIENCE,
+            "subject_token_type": SUBJECT_TOKEN_TYPE,
+            "token_url": TOKEN_URL,
+            "credential_source": self.CREDENTIAL_SOURCE,
+        }
+        expected_kwargs = dict(
+            info,
+            service_account_impersonation_url=None,
+            client_id=None,
+            client_secret=None,
+            quota_project_id=None,
+        )
+
+        credentials = aws.Credentials.from_info(info)
+
+        # Confirm aws.Credentials instance initialized with the expected parameters.
+        assert isinstance(credentials, aws.Credentials)
+        mock_init.assert_called_once_with(**expected_kwargs)
+
+    @mock.patch.object(aws.Credentials, "__init__", return_value=None)
+    def test_from_file_full_options(self, mock_init, tmpdir):
+        """from_file() forwards every field of a JSON config to the constructor."""
+        info = {
+            "audience": AUDIENCE,
+            "subject_token_type": SUBJECT_TOKEN_TYPE,
+            "token_url": TOKEN_URL,
+            "service_account_impersonation_url": SERVICE_ACCOUNT_IMPERSONATION_URL,
+            "client_id": CLIENT_ID,
+            "client_secret": CLIENT_SECRET,
+            "quota_project_id": QUOTA_PROJECT_ID,
+            "credential_source": self.CREDENTIAL_SOURCE,
+        }
+        config_path = tmpdir.join("config.json")
+        config_path.write(json.dumps(info))
+
+        credentials = aws.Credentials.from_file(str(config_path))
+
+        # Confirm aws.Credentials instance initialized with the expected parameters.
+        assert isinstance(credentials, aws.Credentials)
+        mock_init.assert_called_once_with(**info)
+
+    @mock.patch.object(aws.Credentials, "__init__", return_value=None)
+    def test_from_file_required_options_only(self, mock_init, tmpdir):
+        """Optional fields missing from the JSON config are passed as None."""
+        info = {
+            "audience": AUDIENCE,
+            "subject_token_type": SUBJECT_TOKEN_TYPE,
+            "token_url": TOKEN_URL,
+            "credential_source": self.CREDENTIAL_SOURCE,
+        }
+        config_path = tmpdir.join("config.json")
+        config_path.write(json.dumps(info))
+
+        credentials = aws.Credentials.from_file(str(config_path))
+
+        # Confirm aws.Credentials instance initialized with the expected parameters.
+        assert isinstance(credentials, aws.Credentials)
+        mock_init.assert_called_once_with(
+            service_account_impersonation_url=None,
+            client_id=None,
+            client_secret=None,
+            quota_project_id=None,
+            **info
+        )
+
+    def test_constructor_invalid_credential_source(self):
+        """A credential_source with no recognizable AWS fields is rejected."""
+        with pytest.raises(ValueError) as excinfo:
+            self.make_credentials(credential_source={"unsupported": "value"})
+
+        assert excinfo.match(r"No valid AWS 'credential_source' provided")
+
+    def test_constructor_invalid_environment_id(self):
+        """A non-AWS environment_id (here "azure1") is rejected."""
+        # Shallow copy with the single field overridden.
+        bad_source = dict(self.CREDENTIAL_SOURCE, environment_id="azure1")
+
+        with pytest.raises(ValueError) as excinfo:
+            self.make_credentials(credential_source=bad_source)
+
+        assert excinfo.match(r"No valid AWS 'credential_source' provided")
+
+    def test_constructor_missing_cred_verification_url(self):
+        """regional_cred_verification_url is a mandatory credential_source field."""
+        incomplete_source = dict(self.CREDENTIAL_SOURCE)
+        del incomplete_source["regional_cred_verification_url"]
+
+        with pytest.raises(ValueError) as excinfo:
+            self.make_credentials(credential_source=incomplete_source)
+
+        assert excinfo.match(r"No valid AWS 'credential_source' provided")
+
+    def test_constructor_invalid_environment_id_version(self):
+        """An AWS environment_id with an unsupported version ("aws3") is rejected."""
+        credential_source = dict(self.CREDENTIAL_SOURCE, environment_id="aws3")
+
+        with pytest.raises(ValueError) as excinfo:
+            self.make_credentials(credential_source=credential_source)
+
+        # excinfo.match() treats its argument as a regex (re.search), so the
+        # final period is escaped to match literally rather than any character.
+        assert excinfo.match(
+            r"aws version '3' is not supported in the current build\."
+        )
+
+    def test_retrieve_subject_token_missing_region_url(self):
+        """Without AWS_REGION, a credential_source lacking region_url fails."""
+        # The region comes from the AWS_REGION envvar or, failing that, from
+        # region_url. This test sets neither.
+        source_without_region_url = dict(self.CREDENTIAL_SOURCE)
+        del source_without_region_url["region_url"]
+        credentials = self.make_credentials(
+            credential_source=source_without_region_url
+        )
+
+        # A None request means the error must surface before any HTTP call.
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            credentials.retrieve_subject_token(None)
+
+        assert excinfo.match(r"Unable to determine AWS region")
+
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_retrieve_subject_token_success_temp_creds_no_environment_vars(
+        self, utcnow
+    ):
+        """Region, role and temporary credentials come from the metadata server.
+
+        A second call must reuse the cached region and only re-query the role
+        name and the security credentials.
+        """
+        utcnow.return_value = datetime.datetime.strptime(
+            self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
+        )
+        request = self.make_mock_request(
+            region_status=http_client.OK,
+            region_name=self.AWS_REGION,
+            role_status=http_client.OK,
+            role_name=self.AWS_ROLE,
+            security_credentials_status=http_client.OK,
+            security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE,
+        )
+        credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE)
+
+        subject_token = credentials.retrieve_subject_token(request)
+
+        assert subject_token == self.make_serialized_aws_signed_request(
+            {
+                "access_key_id": ACCESS_KEY_ID,
+                "secret_access_key": SECRET_ACCESS_KEY,
+                "security_token": TOKEN,
+            }
+        )
+        # Exactly three metadata server requests: region, role, credentials.
+        assert len(request.call_args_list) == 3
+        # Index [1] of a recorded call is its kwargs dict. The `.kwargs`
+        # attribute only exists in unittest.mock >= 3.8 / recent mock backports.
+        # Assert region request.
+        self.assert_aws_metadata_request_kwargs(
+            request.call_args_list[0][1], REGION_URL
+        )
+        # Assert role request.
+        self.assert_aws_metadata_request_kwargs(
+            request.call_args_list[1][1], SECURITY_CREDS_URL
+        )
+        # Assert security credentials request.
+        self.assert_aws_metadata_request_kwargs(
+            request.call_args_list[2][1],
+            "{}/{}".format(SECURITY_CREDS_URL, self.AWS_ROLE),
+            {"Content-Type": "application/json"},
+        )
+
+        # Retrieve subject_token again. Region should not be queried again.
+        new_request = self.make_mock_request(
+            role_status=http_client.OK,
+            role_name=self.AWS_ROLE,
+            security_credentials_status=http_client.OK,
+            security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE,
+        )
+
+        credentials.retrieve_subject_token(new_request)
+
+        # Only 2 requests should be sent as the region is cached.
+        assert len(new_request.call_args_list) == 2
+        # Assert role request.
+        self.assert_aws_metadata_request_kwargs(
+            new_request.call_args_list[0][1], SECURITY_CREDS_URL
+        )
+        # Assert security credentials request.
+        self.assert_aws_metadata_request_kwargs(
+            new_request.call_args_list[1][1],
+            "{}/{}".format(SECURITY_CREDS_URL, self.AWS_ROLE),
+            {"Content-Type": "application/json"},
+        )
+
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_retrieve_subject_token_success_permanent_creds_no_environment_vars(
+        self, utcnow
+    ):
+        """Permanent metadata-server credentials (no session token) are signed."""
+        # Simulate the security-credentials endpoint returning a permanent
+        # credential, i.e. one without a "Token" session token.
+        permanent_creds = dict(self.AWS_SECURITY_CREDENTIALS_RESPONSE)
+        del permanent_creds["Token"]
+        utcnow.return_value = datetime.datetime.strptime(
+            self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
+        )
+        request = self.make_mock_request(
+            region_status=http_client.OK,
+            region_name=self.AWS_REGION,
+            role_status=http_client.OK,
+            role_name=self.AWS_ROLE,
+            security_credentials_status=http_client.OK,
+            security_credentials_data=permanent_creds,
+        )
+        credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE)
+
+        subject_token = credentials.retrieve_subject_token(request)
+
+        # The expected signed request is built without a security_token.
+        expected_token = self.make_serialized_aws_signed_request(
+            {"access_key_id": ACCESS_KEY_ID, "secret_access_key": SECRET_ACCESS_KEY}
+        )
+        assert subject_token == expected_token
+
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_retrieve_subject_token_success_environment_vars(self, utcnow, monkeypatch):
+        """Keys, session token and region all come from environment variables."""
+        for name, value in (
+            (environment_vars.AWS_ACCESS_KEY_ID, ACCESS_KEY_ID),
+            (environment_vars.AWS_SECRET_ACCESS_KEY, SECRET_ACCESS_KEY),
+            (environment_vars.AWS_SESSION_TOKEN, TOKEN),
+            (environment_vars.AWS_REGION, self.AWS_REGION),
+        ):
+            monkeypatch.setenv(name, value)
+        utcnow.return_value = datetime.datetime.strptime(
+            self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
+        )
+        credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE)
+
+        # A None request: nothing may be fetched from the metadata server.
+        subject_token = credentials.retrieve_subject_token(None)
+
+        expected_token = self.make_serialized_aws_signed_request(
+            {
+                "access_key_id": ACCESS_KEY_ID,
+                "secret_access_key": SECRET_ACCESS_KEY,
+                "security_token": TOKEN,
+            }
+        )
+        assert subject_token == expected_token
+
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_retrieve_subject_token_success_environment_vars_no_session_token(
+        self, utcnow, monkeypatch
+    ):
+        """Envvar credentials without AWS_SESSION_TOKEN are signed correctly."""
+        for name, value in (
+            (environment_vars.AWS_ACCESS_KEY_ID, ACCESS_KEY_ID),
+            (environment_vars.AWS_SECRET_ACCESS_KEY, SECRET_ACCESS_KEY),
+            (environment_vars.AWS_REGION, self.AWS_REGION),
+        ):
+            monkeypatch.setenv(name, value)
+        utcnow.return_value = datetime.datetime.strptime(
+            self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
+        )
+        credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE)
+
+        # A None request: nothing may be fetched from the metadata server.
+        subject_token = credentials.retrieve_subject_token(None)
+
+        expected_token = self.make_serialized_aws_signed_request(
+            {"access_key_id": ACCESS_KEY_ID, "secret_access_key": SECRET_ACCESS_KEY}
+        )
+        assert subject_token == expected_token
+
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_retrieve_subject_token_success_environment_vars_except_region(
+        self, utcnow, monkeypatch
+    ):
+        """Credentials come from envvars while the region is fetched over HTTP."""
+        for name, value in (
+            (environment_vars.AWS_ACCESS_KEY_ID, ACCESS_KEY_ID),
+            (environment_vars.AWS_SECRET_ACCESS_KEY, SECRET_ACCESS_KEY),
+            (environment_vars.AWS_SESSION_TOKEN, TOKEN),
+        ):
+            monkeypatch.setenv(name, value)
+        utcnow.return_value = datetime.datetime.strptime(
+            self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
+        )
+        # AWS_REGION is not set by this test, so only the region lookup is mocked.
+        region_only_request = self.make_mock_request(
+            region_status=http_client.OK, region_name=self.AWS_REGION
+        )
+        credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE)
+
+        subject_token = credentials.retrieve_subject_token(region_only_request)
+
+        expected_token = self.make_serialized_aws_signed_request(
+            {
+                "access_key_id": ACCESS_KEY_ID,
+                "secret_access_key": SECRET_ACCESS_KEY,
+                "security_token": TOKEN,
+            }
+        )
+        assert subject_token == expected_token
+
+    def test_retrieve_subject_token_error_determining_aws_region(self):
+        """A failing region lookup surfaces as a RefreshError."""
+        failing_request = self.make_mock_request(
+            region_status=http_client.BAD_REQUEST
+        )
+        credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE)
+
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            credentials.retrieve_subject_token(failing_request)
+
+        assert excinfo.match(r"Unable to retrieve AWS region")
+
+    def test_retrieve_subject_token_error_determining_aws_role(self):
+        """A failing role-name lookup (after a good region) raises RefreshError."""
+        metadata_responses = dict(
+            region_status=http_client.OK,
+            region_name=self.AWS_REGION,
+            role_status=http_client.BAD_REQUEST,
+        )
+        failing_request = self.make_mock_request(**metadata_responses)
+        credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE)
+
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            credentials.retrieve_subject_token(failing_request)
+
+        assert excinfo.match(r"Unable to retrieve AWS role name")
+
+    def test_retrieve_subject_token_error_determining_security_creds_url(self):
+        """Without credential_source "url", metadata-server creds can't be found."""
+        # The "url" field is the security-credentials endpoint, which is needed
+        # whenever the AWS credentials are not provided through envvars.
+        source_without_url = dict(self.CREDENTIAL_SOURCE)
+        del source_without_url["url"]
+        request = self.make_mock_request(
+            region_status=http_client.OK, region_name=self.AWS_REGION
+        )
+        credentials = self.make_credentials(credential_source=source_without_url)
+
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            credentials.retrieve_subject_token(request)
+
+        assert excinfo.match(
+            r"Unable to determine the AWS metadata server security credentials endpoint"
+        )
+
+    def test_retrieve_subject_token_error_determining_aws_security_creds(self):
+        """A failing security-credentials lookup surfaces as a RefreshError."""
+        metadata_responses = dict(
+            region_status=http_client.OK,
+            region_name=self.AWS_REGION,
+            role_status=http_client.OK,
+            role_name=self.AWS_ROLE,
+            security_credentials_status=http_client.BAD_REQUEST,
+        )
+        failing_request = self.make_mock_request(**metadata_responses)
+        credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE)
+
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            credentials.retrieve_subject_token(failing_request)
+
+        assert excinfo.match(r"Unable to retrieve AWS security credentials")
+
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_refresh_success_without_impersonation_ignore_default_scopes(self, utcnow):
+        """refresh() exchanges the signed AWS request for an STS access token.
+
+        User-supplied scopes take precedence, so default_scopes must not be sent.
+        """
+        utcnow.return_value = datetime.datetime.strptime(
+            self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
+        )
+        expected_subject_token = self.make_serialized_aws_signed_request(
+            {
+                "access_key_id": ACCESS_KEY_ID,
+                "secret_access_key": SECRET_ACCESS_KEY,
+                "security_token": TOKEN,
+            }
+        )
+        # Client credentials are sent to the STS endpoint with HTTP Basic auth.
+        token_headers = {
+            "Content-Type": "application/x-www-form-urlencoded",
+            "Authorization": "Basic " + BASIC_AUTH_ENCODING,
+        }
+        token_request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "scope": " ".join(SCOPES),
+            "subject_token": expected_subject_token,
+            "subject_token_type": SUBJECT_TOKEN_TYPE,
+        }
+        request = self.make_mock_request(
+            region_status=http_client.OK,
+            region_name=self.AWS_REGION,
+            role_status=http_client.OK,
+            role_name=self.AWS_ROLE,
+            security_credentials_status=http_client.OK,
+            security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE,
+            token_status=http_client.OK,
+            token_data=self.SUCCESS_RESPONSE,
+        )
+        credentials = self.make_credentials(
+            client_id=CLIENT_ID,
+            client_secret=CLIENT_SECRET,
+            credential_source=self.CREDENTIAL_SOURCE,
+            quota_project_id=QUOTA_PROJECT_ID,
+            scopes=SCOPES,
+            # Default scopes should be ignored.
+            default_scopes=["ignored"],
+        )
+
+        credentials.refresh(request)
+
+        # Region, role and security-credential lookups precede the STS call.
+        assert len(request.call_args_list) == 4
+        # Fourth request should be sent to GCP STS endpoint. Index [1] of a
+        # recorded call is its kwargs dict; `.kwargs` needs unittest.mock >= 3.8.
+        self.assert_token_request_kwargs(
+            request.call_args_list[3][1], token_headers, token_request_data
+        )
+        assert credentials.token == self.SUCCESS_RESPONSE["access_token"]
+        assert credentials.quota_project_id == QUOTA_PROJECT_ID
+        assert credentials.scopes == SCOPES
+        assert credentials.default_scopes == ["ignored"]
+
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_refresh_success_without_impersonation_use_default_scopes(self, utcnow):
+        """refresh() falls back to default_scopes when no user scopes are set."""
+        utcnow.return_value = datetime.datetime.strptime(
+            self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
+        )
+        expected_subject_token = self.make_serialized_aws_signed_request(
+            {
+                "access_key_id": ACCESS_KEY_ID,
+                "secret_access_key": SECRET_ACCESS_KEY,
+                "security_token": TOKEN,
+            }
+        )
+        # Client credentials are sent to the STS endpoint with HTTP Basic auth.
+        token_headers = {
+            "Content-Type": "application/x-www-form-urlencoded",
+            "Authorization": "Basic " + BASIC_AUTH_ENCODING,
+        }
+        token_request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "scope": " ".join(SCOPES),
+            "subject_token": expected_subject_token,
+            "subject_token_type": SUBJECT_TOKEN_TYPE,
+        }
+        request = self.make_mock_request(
+            region_status=http_client.OK,
+            region_name=self.AWS_REGION,
+            role_status=http_client.OK,
+            role_name=self.AWS_ROLE,
+            security_credentials_status=http_client.OK,
+            security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE,
+            token_status=http_client.OK,
+            token_data=self.SUCCESS_RESPONSE,
+        )
+        credentials = self.make_credentials(
+            client_id=CLIENT_ID,
+            client_secret=CLIENT_SECRET,
+            credential_source=self.CREDENTIAL_SOURCE,
+            quota_project_id=QUOTA_PROJECT_ID,
+            scopes=None,
+            # Default scopes should be used since user specified scopes are none.
+            default_scopes=SCOPES,
+        )
+
+        credentials.refresh(request)
+
+        # Region, role and security-credential lookups precede the STS call.
+        assert len(request.call_args_list) == 4
+        # Fourth request should be sent to GCP STS endpoint. Index [1] of a
+        # recorded call is its kwargs dict; `.kwargs` needs unittest.mock >= 3.8.
+        self.assert_token_request_kwargs(
+            request.call_args_list[3][1], token_headers, token_request_data
+        )
+        assert credentials.token == self.SUCCESS_RESPONSE["access_token"]
+        assert credentials.quota_project_id == QUOTA_PROJECT_ID
+        assert credentials.scopes is None
+        assert credentials.default_scopes == SCOPES
+
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_refresh_success_with_impersonation_ignore_default_scopes(self, utcnow):
+        """refresh() exchanges via STS, then impersonates with the user scopes.
+
+        The STS request asks only for the IAM scope; the user scopes (not the
+        default scopes) are sent in the impersonation request.
+        """
+        utcnow.return_value = datetime.datetime.strptime(
+            self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
+        )
+        # Impersonated token expires one hour after the (mocked) current time.
+        expire_time = (
+            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600)
+        ).isoformat("T") + "Z"
+        expected_subject_token = self.make_serialized_aws_signed_request(
+            {
+                "access_key_id": ACCESS_KEY_ID,
+                "secret_access_key": SECRET_ACCESS_KEY,
+                "security_token": TOKEN,
+            }
+        )
+        token_headers = {
+            "Content-Type": "application/x-www-form-urlencoded",
+            "Authorization": "Basic " + BASIC_AUTH_ENCODING,
+        }
+        token_request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "scope": "https://www.googleapis.com/auth/iam",
+            "subject_token": expected_subject_token,
+            "subject_token_type": SUBJECT_TOKEN_TYPE,
+        }
+        # Service account impersonation request/response.
+        impersonation_response = {
+            "accessToken": "SA_ACCESS_TOKEN",
+            "expireTime": expire_time,
+        }
+        impersonation_headers = {
+            "Content-Type": "application/json",
+            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]),
+            "x-goog-user-project": QUOTA_PROJECT_ID,
+        }
+        impersonation_request_data = {
+            "delegates": None,
+            "scope": SCOPES,
+            "lifetime": "3600s",
+        }
+        request = self.make_mock_request(
+            region_status=http_client.OK,
+            region_name=self.AWS_REGION,
+            role_status=http_client.OK,
+            role_name=self.AWS_ROLE,
+            security_credentials_status=http_client.OK,
+            security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE,
+            token_status=http_client.OK,
+            token_data=self.SUCCESS_RESPONSE,
+            impersonation_status=http_client.OK,
+            impersonation_data=impersonation_response,
+        )
+        credentials = self.make_credentials(
+            client_id=CLIENT_ID,
+            client_secret=CLIENT_SECRET,
+            credential_source=self.CREDENTIAL_SOURCE,
+            service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+            quota_project_id=QUOTA_PROJECT_ID,
+            scopes=SCOPES,
+            # Default scopes should be ignored.
+            default_scopes=["ignored"],
+        )
+
+        credentials.refresh(request)
+
+        assert len(request.call_args_list) == 5
+        # Fourth request should be sent to GCP STS endpoint. Index [1] of a
+        # recorded call is its kwargs dict; `.kwargs` needs unittest.mock >= 3.8.
+        self.assert_token_request_kwargs(
+            request.call_args_list[3][1], token_headers, token_request_data
+        )
+        # Fifth request should be sent to iamcredentials endpoint for service
+        # account impersonation.
+        self.assert_impersonation_request_kwargs(
+            request.call_args_list[4][1],
+            impersonation_headers,
+            impersonation_request_data,
+        )
+        assert credentials.token == impersonation_response["accessToken"]
+        assert credentials.quota_project_id == QUOTA_PROJECT_ID
+        assert credentials.scopes == SCOPES
+        assert credentials.default_scopes == ["ignored"]
+
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_refresh_success_with_impersonation_use_default_scopes(self, utcnow):
+        """refresh() impersonates with default_scopes when no user scopes are set.
+
+        The STS request asks only for the IAM scope; the default scopes are
+        sent in the impersonation request.
+        """
+        utcnow.return_value = datetime.datetime.strptime(
+            self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
+        )
+        # Impersonated token expires one hour after the (mocked) current time.
+        expire_time = (
+            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600)
+        ).isoformat("T") + "Z"
+        expected_subject_token = self.make_serialized_aws_signed_request(
+            {
+                "access_key_id": ACCESS_KEY_ID,
+                "secret_access_key": SECRET_ACCESS_KEY,
+                "security_token": TOKEN,
+            }
+        )
+        token_headers = {
+            "Content-Type": "application/x-www-form-urlencoded",
+            "Authorization": "Basic " + BASIC_AUTH_ENCODING,
+        }
+        token_request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "scope": "https://www.googleapis.com/auth/iam",
+            "subject_token": expected_subject_token,
+            "subject_token_type": SUBJECT_TOKEN_TYPE,
+        }
+        # Service account impersonation request/response.
+        impersonation_response = {
+            "accessToken": "SA_ACCESS_TOKEN",
+            "expireTime": expire_time,
+        }
+        impersonation_headers = {
+            "Content-Type": "application/json",
+            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]),
+            "x-goog-user-project": QUOTA_PROJECT_ID,
+        }
+        impersonation_request_data = {
+            "delegates": None,
+            "scope": SCOPES,
+            "lifetime": "3600s",
+        }
+        request = self.make_mock_request(
+            region_status=http_client.OK,
+            region_name=self.AWS_REGION,
+            role_status=http_client.OK,
+            role_name=self.AWS_ROLE,
+            security_credentials_status=http_client.OK,
+            security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE,
+            token_status=http_client.OK,
+            token_data=self.SUCCESS_RESPONSE,
+            impersonation_status=http_client.OK,
+            impersonation_data=impersonation_response,
+        )
+        credentials = self.make_credentials(
+            client_id=CLIENT_ID,
+            client_secret=CLIENT_SECRET,
+            credential_source=self.CREDENTIAL_SOURCE,
+            service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+            quota_project_id=QUOTA_PROJECT_ID,
+            scopes=None,
+            # Default scopes should be used since user specified scopes are none.
+            default_scopes=SCOPES,
+        )
+
+        credentials.refresh(request)
+
+        assert len(request.call_args_list) == 5
+        # Fourth request should be sent to GCP STS endpoint. Index [1] of a
+        # recorded call is its kwargs dict; `.kwargs` needs unittest.mock >= 3.8.
+        self.assert_token_request_kwargs(
+            request.call_args_list[3][1], token_headers, token_request_data
+        )
+        # Fifth request should be sent to iamcredentials endpoint for service
+        # account impersonation.
+        self.assert_impersonation_request_kwargs(
+            request.call_args_list[4][1],
+            impersonation_headers,
+            impersonation_request_data,
+        )
+        assert credentials.token == impersonation_response["accessToken"]
+        assert credentials.quota_project_id == QUOTA_PROJECT_ID
+        assert credentials.scopes is None
+        assert credentials.default_scopes == SCOPES
+
+    def test_refresh_with_retrieve_subject_token_error(self):
+        """A subject-token failure during refresh() surfaces as RefreshError."""
+        failing_request = self.make_mock_request(
+            region_status=http_client.BAD_REQUEST
+        )
+        credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE)
+
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            credentials.refresh(failing_request)
+
+        assert excinfo.match(r"Unable to retrieve AWS region")
diff --git a/tests/test_external_account.py b/tests/test_external_account.py
new file mode 100644
index 0000000..42e53ec
--- /dev/null
+++ b/tests/test_external_account.py
@@ -0,0 +1,1095 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+
+import mock
+import pytest
+from six.moves import http_client
+from six.moves import urllib
+
+from google.auth import _helpers
+from google.auth import exceptions
+from google.auth import external_account
+from google.auth import transport
+
+
+SERVICE_ACCOUNT_EMAIL = "service-1234@service-name.iam.gserviceaccount.com"
+
+# OAuth client credentials used by the tests.
+CLIENT_ID = "username"
+CLIENT_SECRET = "password"
+# Pre-computed base64 of "<CLIENT_ID>:<CLIENT_SECRET>" ("username:password"),
+# i.e. the value carried by an "Authorization: Basic ..." header.
+BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ="
+
+
+class CredentialsImpl(external_account.Credentials):
+    """Minimal concrete external_account.Credentials used to test the base class.
+
+    Each call to retrieve_subject_token() returns a fresh numbered token
+    ("subject_token_0", "subject_token_1", ...), which lets tests assert exactly
+    which subject token was sent in a given token exchange.
+    """
+
+    def __init__(
+        self,
+        audience,
+        subject_token_type,
+        token_url,
+        credential_source,
+        service_account_impersonation_url=None,
+        client_id=None,
+        client_secret=None,
+        quota_project_id=None,
+        scopes=None,
+        default_scopes=None,
+    ):
+        super(CredentialsImpl, self).__init__(
+            audience=audience,
+            subject_token_type=subject_token_type,
+            token_url=token_url,
+            credential_source=credential_source,
+            service_account_impersonation_url=service_account_impersonation_url,
+            client_id=client_id,
+            client_secret=client_secret,
+            quota_project_id=quota_project_id,
+            scopes=scopes,
+            default_scopes=default_scopes,
+        )
+        # Number of subject tokens handed out so far.
+        self._counter = 0
+
+    def retrieve_subject_token(self, request):
+        # The request is ignored: tokens are synthesized locally.
+        token = "subject_token_{}".format(self._counter)
+        self._counter += 1
+        return token
+
+
+class TestCredentials(object):
+    TOKEN_URL = "https://sts.googleapis.com/v1/token"
+    # Workload identity pool provider that AUDIENCE refers to.
+    PROJECT_NUMBER = "123456"
+    POOL_ID = "POOL_ID"
+    PROVIDER_ID = "PROVIDER_ID"
+    AUDIENCE = "/".join(
+        [
+            "//iam.googleapis.com/projects",
+            PROJECT_NUMBER,
+            "locations/global/workloadIdentityPools",
+            POOL_ID,
+            "providers",
+            PROVIDER_ID,
+        ]
+    )
+    SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt"
+    CREDENTIAL_SOURCE = {"file": "/var/run/secrets/goog.id/token"}
+    # Canned token exchange payloads: a successful response and an OAuth 2.0
+    # style error response.
+    SUCCESS_RESPONSE = {
+        "access_token": "ACCESS_TOKEN",
+        "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
+        "token_type": "Bearer",
+        "expires_in": 3600,
+        "scope": "scope1 scope2",
+    }
+    ERROR_RESPONSE = {
+        "error": "invalid_request",
+        "error_description": "Invalid subject token",
+        "error_uri": "https://tools.ietf.org/html/rfc6749",
+    }
+    QUOTA_PROJECT_ID = "QUOTA_PROJECT_ID"
+    SERVICE_ACCOUNT_IMPERSONATION_URL = (
+        "https://us-east1-iamcredentials.googleapis.com/v1/projects/-"
+        "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL)
+    )
+    SCOPES = ["scope1", "scope2"]
+    # Canned error payload for a failed impersonation request.
+    IMPERSONATION_ERROR_RESPONSE = {
+        "error": {
+            "code": 400,
+            "message": "Request contains an invalid argument",
+            "status": "INVALID_ARGUMENT",
+        }
+    }
+    # Cloud Resource Manager project lookup fixtures.
+    PROJECT_ID = "my-proj-id"
+    CLOUD_RESOURCE_MANAGER_URL = (
+        "https://cloudresourcemanager.googleapis.com/v1/projects/"
+    )
+    CLOUD_RESOURCE_MANAGER_SUCCESS_RESPONSE = {
+        "projectNumber": PROJECT_NUMBER,
+        "projectId": PROJECT_ID,
+        "lifecycleState": "ACTIVE",
+        "name": "project-name",
+        "createTime": "2018-11-06T04:42:54.109Z",
+        "parent": {"type": "folder", "id": "12345678901"},
+    }
+
+    @classmethod
+    def make_credentials(
+        cls,
+        client_id=None,
+        client_secret=None,
+        quota_project_id=None,
+        scopes=None,
+        default_scopes=None,
+        service_account_impersonation_url=None,
+    ):
+        """Build a CredentialsImpl from the class-level fixture values.
+
+        Audience, subject token type, token URL and credential source are fixed
+        by the test class; every other option is forwarded as given.
+        """
+        return CredentialsImpl(
+            cls.AUDIENCE,
+            cls.SUBJECT_TOKEN_TYPE,
+            cls.TOKEN_URL,
+            cls.CREDENTIAL_SOURCE,
+            service_account_impersonation_url=service_account_impersonation_url,
+            client_id=client_id,
+            client_secret=client_secret,
+            quota_project_id=quota_project_id,
+            scopes=scopes,
+            default_scopes=default_scopes,
+        )
+
+    @classmethod
+    def make_mock_request(
+        cls,
+        status=http_client.OK,
+        data=None,
+        impersonation_status=None,
+        impersonation_data=None,
+        cloud_resource_manager_status=None,
+        cloud_resource_manager_data=None,
+    ):
+        """Build a mock transport.Request that replays canned responses in order.
+
+        The STS token exchange response always comes first. It is followed by
+        the impersonation response and then the Cloud Resource Manager
+        response, each included only when its status is truthy.
+        """
+
+        def fake_response(response_status, payload):
+            # Autospec'd Response whose body is the JSON-encoded payload.
+            response = mock.create_autospec(transport.Response, instance=True)
+            response.status = response_status
+            response.data = json.dumps(payload).encode("utf-8")
+            return response
+
+        optional_responses = (
+            (impersonation_status, impersonation_data),
+            (cloud_resource_manager_status, cloud_resource_manager_data),
+        )
+        responses = [fake_response(status, data)]
+        responses.extend(
+            fake_response(extra_status, extra_data)
+            for extra_status, extra_data in optional_responses
+            if extra_status
+        )
+
+        request = mock.create_autospec(transport.Request)
+        request.side_effect = responses
+        return request
+
+    @classmethod
+    def assert_token_request_kwargs(cls, request_kwargs, headers, request_data):
+        """Assert an STS token exchange POST to TOKEN_URL carrying request_data.
+
+        The form-encoded body must contain exactly the fields of request_data
+        (none missing, none extra, none duplicated) with matching values.
+        """
+        assert request_kwargs["url"] == cls.TOKEN_URL
+        assert request_kwargs["method"] == "POST"
+        assert request_kwargs["headers"] == headers
+        assert request_kwargs["body"] is not None
+        body_tuples = [
+            (key.decode("utf-8"), value.decode("utf-8"))
+            for key, value in urllib.parse.parse_qsl(request_kwargs["body"])
+        ]
+        # A dict comparison also catches an expected field that is missing
+        # (which a per-tuple check misses when another field is duplicated);
+        # the length check rejects duplicates that dict() would collapse.
+        assert dict(body_tuples) == request_data
+        assert len(body_tuples) == len(request_data)
+
+    @classmethod
+    def assert_impersonation_request_kwargs(cls, request_kwargs, headers, request_data):
+        """Assert a JSON POST to the impersonation URL carrying request_data."""
+        assert request_kwargs["url"] == cls.SERVICE_ACCOUNT_IMPERSONATION_URL
+        assert request_kwargs["method"] == "POST"
+        assert request_kwargs["headers"] == headers
+        raw_body = request_kwargs["body"]
+        assert raw_body is not None
+        # The body is UTF-8 encoded JSON; compare it structurally.
+        assert json.loads(raw_body.decode("utf-8")) == request_data
+
+    @classmethod
+    def assert_resource_manager_request_kwargs(
+        cls, request_kwargs, project_number, headers
+    ):
+        """Assert a body-less GET for project_number to Cloud Resource Manager."""
+        expected_url = cls.CLOUD_RESOURCE_MANAGER_URL + project_number
+        assert request_kwargs["url"] == expected_url
+        assert request_kwargs["method"] == "GET"
+        assert request_kwargs["headers"] == headers
+        # The request must not pass a body argument at all.
+        assert "body" not in request_kwargs
+
+    def test_default_state(self):
+        """New credentials have no token, expiry, scopes or quota project."""
+        credentials = self.make_credentials()
+
+        # No token has been acquired and no expiry recorded yet, and neither
+        # scopes nor a quota project were configured: all must be falsy.
+        for attribute in (
+            "token",
+            "valid",
+            "expiry",
+            "expired",
+            "scopes",
+            "quota_project_id",
+        ):
+            assert not getattr(credentials, attribute), attribute
+        # The credentials report that scopes are required.
+        assert credentials.requires_scopes
+
+    def test_with_scopes(self):
+        """with_scopes() yields credentials that no longer require scopes."""
+        unscoped = self.make_credentials()
+        assert not unscoped.scopes
+        assert unscoped.requires_scopes
+
+        scoped = unscoped.with_scopes(["email"])
+
+        assert scoped.has_scopes(["email"])
+        assert not scoped.requires_scopes
+
+    def test_with_scopes_using_user_and_default_scopes(self):
+        """User scopes take precedence over default scopes; both are retained."""
+        unscoped = self.make_credentials()
+        assert not unscoped.scopes
+        assert unscoped.requires_scopes
+
+        scoped = unscoped.with_scopes(["email"], default_scopes=["profile"])
+
+        # Only the user-specified scope is effective...
+        assert scoped.has_scopes(["email"])
+        assert not scoped.has_scopes(["profile"])
+        assert not scoped.requires_scopes
+        # ...but both lists are kept on the credentials.
+        assert scoped.scopes == ["email"]
+        assert scoped.default_scopes == ["profile"]
+
+    def test_with_scopes_using_default_scopes_only(self):
+        """Default scopes become effective when no user scopes are given."""
+        unscoped = self.make_credentials()
+        assert not unscoped.scopes
+        assert unscoped.requires_scopes
+
+        scoped = unscoped.with_scopes(None, default_scopes=["profile"])
+
+        assert scoped.has_scopes(["profile"])
+        assert not scoped.requires_scopes
+
+    def test_with_scopes_full_options_propagated(self):
+        """with_scopes() carries over every option and replaces only the scopes."""
+        preserved = dict(
+            client_id=CLIENT_ID,
+            client_secret=CLIENT_SECRET,
+            quota_project_id=self.QUOTA_PROJECT_ID,
+            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
+        )
+        credentials = self.make_credentials(
+            scopes=self.SCOPES, default_scopes=["default1"], **preserved
+        )
+
+        with mock.patch.object(
+            external_account.Credentials, "__init__", return_value=None
+        ) as mock_init:
+            credentials.with_scopes(["email"], ["default2"])
+
+        # The new instance receives the fixed fixture values, the preserved
+        # options and the replacement scopes.
+        mock_init.assert_called_once_with(
+            audience=self.AUDIENCE,
+            subject_token_type=self.SUBJECT_TOKEN_TYPE,
+            token_url=self.TOKEN_URL,
+            credential_source=self.CREDENTIAL_SOURCE,
+            scopes=["email"],
+            default_scopes=["default2"],
+            **preserved
+        )
+
+    def test_with_quota_project(self):
+        """with_quota_project() returns credentials bound to the new project."""
+        original = self.make_credentials()
+        assert not original.scopes
+        assert not original.quota_project_id
+
+        updated = original.with_quota_project("project-foo")
+
+        assert updated.quota_project_id == "project-foo"
+
+    def test_with_quota_project_full_options_propagated(self):
+        """with_quota_project() carries over every option but the quota project."""
+        preserved = dict(
+            client_id=CLIENT_ID,
+            client_secret=CLIENT_SECRET,
+            scopes=self.SCOPES,
+            default_scopes=["default1"],
+            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
+        )
+        credentials = self.make_credentials(
+            quota_project_id=self.QUOTA_PROJECT_ID, **preserved
+        )
+
+        with mock.patch.object(
+            external_account.Credentials, "__init__", return_value=None
+        ) as mock_init:
+            credentials.with_quota_project("project-foo")
+
+        # The new instance receives the fixed fixture values, the preserved
+        # options and the replacement quota project ID.
+        mock_init.assert_called_once_with(
+            audience=self.AUDIENCE,
+            subject_token_type=self.SUBJECT_TOKEN_TYPE,
+            token_url=self.TOKEN_URL,
+            credential_source=self.CREDENTIAL_SOURCE,
+            quota_project_id="project-foo",
+            **preserved
+        )
+
+    def test_with_invalid_impersonation_target_principal(self):
+        """An impersonation URL with no target principal fails at construction."""
+        invalid_url = "https://iamcredentials.googleapis.com/v1/invalid"
+
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            self.make_credentials(service_account_impersonation_url=invalid_url)
+
+        # excinfo.match() treats its argument as a regex (re.search), so the
+        # final period is escaped to match literally rather than any character.
+        assert excinfo.match(
+            r"Unable to determine target principal from service account impersonation URL\."
+        )
+
+    @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+    def test_refresh_without_client_auth_success(self, unused_utcnow):
+        """Refresh without client credentials sends no Authorization header.
+
+        The expiry is derived from the response's expires_in, relative to the
+        mocked current time (datetime.min).
+        """
+        response = self.SUCCESS_RESPONSE.copy()
+        # Test custom expiration to confirm expiry is set correctly.
+        response["expires_in"] = 2800
+        expected_expiry = datetime.datetime.min + datetime.timedelta(
+            seconds=response["expires_in"]
+        )
+        headers = {"Content-Type": "application/x-www-form-urlencoded"}
+        request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": self.AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "subject_token": "subject_token_0",
+            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+        }
+        request = self.make_mock_request(status=http_client.OK, data=response)
+        credentials = self.make_credentials()
+
+        credentials.refresh(request)
+
+        # Index [1] of a recorded call is its kwargs dict; the `.kwargs`
+        # attribute only exists in unittest.mock >= 3.8 / recent mock backports.
+        self.assert_token_request_kwargs(request.call_args[1], headers, request_data)
+        assert credentials.valid
+        assert credentials.expiry == expected_expiry
+        assert not credentials.expired
+        assert credentials.token == response["access_token"]
+
+    def test_refresh_impersonation_without_client_auth_success(self):
+        """Refresh exchanges a token via STS, then impersonates the service account.
+
+        The STS request asks only for the IAM scope; the caller's scopes go on the
+        impersonation request, whose access token becomes the credential's token.
+        """
+        # Simulate service account access token expires in 2800 seconds.
+        expire_time = (
+            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800)
+        ).isoformat("T") + "Z"
+        expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ")
+        # STS token exchange request/response.
+        token_response = self.SUCCESS_RESPONSE.copy()
+        token_headers = {"Content-Type": "application/x-www-form-urlencoded"}
+        token_request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": self.AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "subject_token": "subject_token_0",
+            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+            "scope": "https://www.googleapis.com/auth/iam",
+        }
+        # Service account impersonation request/response.
+        impersonation_response = {
+            "accessToken": "SA_ACCESS_TOKEN",
+            "expireTime": expire_time,
+        }
+        impersonation_headers = {
+            "Content-Type": "application/json",
+            "authorization": "Bearer {}".format(token_response["access_token"]),
+        }
+        impersonation_request_data = {
+            "delegates": None,
+            "scope": self.SCOPES,
+            "lifetime": "3600s",
+        }
+        # Initialize mock request to handle token exchange and service account
+        # impersonation request.
+        request = self.make_mock_request(
+            status=http_client.OK,
+            data=token_response,
+            impersonation_status=http_client.OK,
+            impersonation_data=impersonation_response,
+        )
+        # Initialize credentials with service account impersonation.
+        credentials = self.make_credentials(
+            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
+            scopes=self.SCOPES,
+        )
+
+        credentials.refresh(request)
+
+        # Only 2 requests should be processed.
+        assert len(request.call_args_list) == 2
+        # Verify token exchange request parameters.
+        self.assert_token_request_kwargs(
+            request.call_args_list[0].kwargs, token_headers, token_request_data
+        )
+        # Verify service account impersonation request parameters.
+        self.assert_impersonation_request_kwargs(
+            request.call_args_list[1].kwargs,
+            impersonation_headers,
+            impersonation_request_data,
+        )
+        assert credentials.valid
+        assert credentials.expiry == expected_expiry
+        assert not credentials.expired
+        assert credentials.token == impersonation_response["accessToken"]
+
+    def test_refresh_without_client_auth_success_explicit_user_scopes_ignore_default_scopes(
+        self
+    ):
+        """Explicit user scopes take precedence over default scopes on refresh."""
+        headers = {"Content-Type": "application/x-www-form-urlencoded"}
+        request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": self.AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "scope": "scope1 scope2",
+            "subject_token": "subject_token_0",
+            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+        }
+        request = self.make_mock_request(
+            status=http_client.OK, data=self.SUCCESS_RESPONSE
+        )
+        credentials = self.make_credentials(
+            scopes=["scope1", "scope2"],
+            # Default scopes will be ignored in favor of user scopes.
+            default_scopes=["ignored"],
+        )
+
+        credentials.refresh(request)
+
+        self.assert_token_request_kwargs(
+            request.call_args.kwargs, headers, request_data
+        )
+        assert credentials.valid
+        assert not credentials.expired
+        assert credentials.token == self.SUCCESS_RESPONSE["access_token"]
+        assert credentials.has_scopes(["scope1", "scope2"])
+        assert not credentials.has_scopes(["ignored"])
+
+    def test_refresh_without_client_auth_success_explicit_default_scopes_only(self):
+        """With no user scopes, the default scopes are sent on the STS request."""
+        headers = {"Content-Type": "application/x-www-form-urlencoded"}
+        request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": self.AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "scope": "scope1 scope2",
+            "subject_token": "subject_token_0",
+            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+        }
+        request = self.make_mock_request(
+            status=http_client.OK, data=self.SUCCESS_RESPONSE
+        )
+        credentials = self.make_credentials(
+            scopes=None,
+            # Default scopes will be used since user scopes are none.
+            default_scopes=["scope1", "scope2"],
+        )
+
+        credentials.refresh(request)
+
+        self.assert_token_request_kwargs(
+            request.call_args.kwargs, headers, request_data
+        )
+        assert credentials.valid
+        assert not credentials.expired
+        assert credentials.token == self.SUCCESS_RESPONSE["access_token"]
+        assert credentials.has_scopes(["scope1", "scope2"])
+
+    def test_refresh_without_client_auth_error(self):
+        """An STS error response raises OAuthError and leaves no token set."""
+        request = self.make_mock_request(
+            status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE
+        )
+        credentials = self.make_credentials()
+
+        with pytest.raises(exceptions.OAuthError) as excinfo:
+            credentials.refresh(request)
+
+        assert excinfo.match(
+            r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749"
+        )
+        assert not credentials.expired
+        assert credentials.token is None
+
+    def test_refresh_impersonation_without_client_auth_error(self):
+        """A failed impersonation call raises RefreshError and leaves no token set."""
+        # The STS exchange succeeds; only the impersonation request fails.
+        request = self.make_mock_request(
+            status=http_client.OK,
+            data=self.SUCCESS_RESPONSE,
+            impersonation_status=http_client.BAD_REQUEST,
+            impersonation_data=self.IMPERSONATION_ERROR_RESPONSE,
+        )
+        credentials = self.make_credentials(
+            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
+            scopes=self.SCOPES,
+        )
+
+        with pytest.raises(exceptions.RefreshError) as excinfo:
+            credentials.refresh(request)
+
+        assert excinfo.match(r"Unable to acquire impersonated credentials")
+        assert not credentials.expired
+        assert credentials.token is None
+
+    def test_refresh_with_client_auth_success(self):
+        """Client ID and secret are sent as HTTP Basic auth on the STS request."""
+        headers = {
+            "Content-Type": "application/x-www-form-urlencoded",
+            "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING),
+        }
+        request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": self.AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "subject_token": "subject_token_0",
+            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+        }
+        request = self.make_mock_request(
+            status=http_client.OK, data=self.SUCCESS_RESPONSE
+        )
+        credentials = self.make_credentials(
+            client_id=CLIENT_ID, client_secret=CLIENT_SECRET
+        )
+
+        credentials.refresh(request)
+
+        self.assert_token_request_kwargs(
+            request.call_args.kwargs, headers, request_data
+        )
+        assert credentials.valid
+        assert not credentials.expired
+        assert credentials.token == self.SUCCESS_RESPONSE["access_token"]
+
+    def test_refresh_impersonation_with_client_auth_success_ignore_default_scopes(self):
+        """Impersonation with Basic auth sends user scopes and ignores defaults."""
+        # Simulate service account access token expires in 2800 seconds.
+        expire_time = (
+            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800)
+        ).isoformat("T") + "Z"
+        expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ")
+        # STS token exchange request/response.
+        token_response = self.SUCCESS_RESPONSE.copy()
+        token_headers = {
+            "Content-Type": "application/x-www-form-urlencoded",
+            "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING),
+        }
+        token_request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": self.AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "subject_token": "subject_token_0",
+            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+            "scope": "https://www.googleapis.com/auth/iam",
+        }
+        # Service account impersonation request/response.
+        impersonation_response = {
+            "accessToken": "SA_ACCESS_TOKEN",
+            "expireTime": expire_time,
+        }
+        impersonation_headers = {
+            "Content-Type": "application/json",
+            "authorization": "Bearer {}".format(token_response["access_token"]),
+        }
+        impersonation_request_data = {
+            "delegates": None,
+            "scope": self.SCOPES,
+            "lifetime": "3600s",
+        }
+        # Initialize mock request to handle token exchange and service account
+        # impersonation request.
+        request = self.make_mock_request(
+            status=http_client.OK,
+            data=token_response,
+            impersonation_status=http_client.OK,
+            impersonation_data=impersonation_response,
+        )
+        # Initialize credentials with service account impersonation and basic auth.
+        credentials = self.make_credentials(
+            client_id=CLIENT_ID,
+            client_secret=CLIENT_SECRET,
+            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
+            scopes=self.SCOPES,
+            # Default scopes will be ignored since user scopes are specified.
+            default_scopes=["ignored"],
+        )
+
+        credentials.refresh(request)
+
+        # Only 2 requests should be processed.
+        assert len(request.call_args_list) == 2
+        # Verify token exchange request parameters.
+        self.assert_token_request_kwargs(
+            request.call_args_list[0].kwargs, token_headers, token_request_data
+        )
+        # Verify service account impersonation request parameters.
+        self.assert_impersonation_request_kwargs(
+            request.call_args_list[1].kwargs,
+            impersonation_headers,
+            impersonation_request_data,
+        )
+        assert credentials.valid
+        assert credentials.expiry == expected_expiry
+        assert not credentials.expired
+        assert credentials.token == impersonation_response["accessToken"]
+
+    def test_refresh_impersonation_with_client_auth_success_use_default_scopes(self):
+        """Impersonation with Basic auth falls back to default scopes when scopes is None."""
+        # Simulate service account access token expires in 2800 seconds.
+        expire_time = (
+            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800)
+        ).isoformat("T") + "Z"
+        expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ")
+        # STS token exchange request/response.
+        token_response = self.SUCCESS_RESPONSE.copy()
+        token_headers = {
+            "Content-Type": "application/x-www-form-urlencoded",
+            "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING),
+        }
+        token_request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": self.AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "subject_token": "subject_token_0",
+            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+            "scope": "https://www.googleapis.com/auth/iam",
+        }
+        # Service account impersonation request/response.
+        impersonation_response = {
+            "accessToken": "SA_ACCESS_TOKEN",
+            "expireTime": expire_time,
+        }
+        impersonation_headers = {
+            "Content-Type": "application/json",
+            "authorization": "Bearer {}".format(token_response["access_token"]),
+        }
+        impersonation_request_data = {
+            "delegates": None,
+            "scope": self.SCOPES,
+            "lifetime": "3600s",
+        }
+        # Initialize mock request to handle token exchange and service account
+        # impersonation request.
+        request = self.make_mock_request(
+            status=http_client.OK,
+            data=token_response,
+            impersonation_status=http_client.OK,
+            impersonation_data=impersonation_response,
+        )
+        # Initialize credentials with service account impersonation and basic auth.
+        credentials = self.make_credentials(
+            client_id=CLIENT_ID,
+            client_secret=CLIENT_SECRET,
+            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
+            scopes=None,
+            # Default scopes will be used since user specified scopes are none.
+            default_scopes=self.SCOPES,
+        )
+
+        credentials.refresh(request)
+
+        # Only 2 requests should be processed.
+        assert len(request.call_args_list) == 2
+        # Verify token exchange request parameters.
+        self.assert_token_request_kwargs(
+            request.call_args_list[0].kwargs, token_headers, token_request_data
+        )
+        # Verify service account impersonation request parameters.
+        self.assert_impersonation_request_kwargs(
+            request.call_args_list[1].kwargs,
+            impersonation_headers,
+            impersonation_request_data,
+        )
+        assert credentials.valid
+        assert credentials.expiry == expected_expiry
+        assert not credentials.expired
+        assert credentials.token == impersonation_response["accessToken"]
+
+    def test_apply_without_quota_project_id(self):
+        """apply() sets only the bearer header when no quota project is configured."""
+        headers = {}
+        request = self.make_mock_request(
+            status=http_client.OK, data=self.SUCCESS_RESPONSE
+        )
+        credentials = self.make_credentials()
+
+        credentials.refresh(request)
+        credentials.apply(headers)
+
+        assert headers == {
+            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"])
+        }
+
+    def test_apply_impersonation_without_quota_project_id(self):
+        """apply() uses the impersonated access token rather than the STS token."""
+        expire_time = (
+            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600)
+        ).isoformat("T") + "Z"
+        # Service account impersonation response.
+        impersonation_response = {
+            "accessToken": "SA_ACCESS_TOKEN",
+            "expireTime": expire_time,
+        }
+        # Initialize mock request to handle token exchange and service account
+        # impersonation request.
+        request = self.make_mock_request(
+            status=http_client.OK,
+            data=self.SUCCESS_RESPONSE.copy(),
+            impersonation_status=http_client.OK,
+            impersonation_data=impersonation_response,
+        )
+        # Initialize credentials with service account impersonation.
+        credentials = self.make_credentials(
+            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
+            scopes=self.SCOPES,
+        )
+        headers = {}
+
+        credentials.refresh(request)
+        credentials.apply(headers)
+
+        assert headers == {
+            "authorization": "Bearer {}".format(impersonation_response["accessToken"])
+        }
+
+    def test_apply_with_quota_project_id(self):
+        """apply() adds x-goog-user-project and keeps pre-existing headers."""
+        headers = {"other": "header-value"}
+        request = self.make_mock_request(
+            status=http_client.OK, data=self.SUCCESS_RESPONSE
+        )
+        credentials = self.make_credentials(quota_project_id=self.QUOTA_PROJECT_ID)
+
+        credentials.refresh(request)
+        credentials.apply(headers)
+
+        assert headers == {
+            "other": "header-value",
+            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]),
+            "x-goog-user-project": self.QUOTA_PROJECT_ID,
+        }
+
+    def test_apply_impersonation_with_quota_project_id(self):
+        """apply() with impersonation sets the impersonated token and quota project."""
+        expire_time = (
+            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600)
+        ).isoformat("T") + "Z"
+        # Service account impersonation response.
+        impersonation_response = {
+            "accessToken": "SA_ACCESS_TOKEN",
+            "expireTime": expire_time,
+        }
+        # Initialize mock request to handle token exchange and service account
+        # impersonation request.
+        request = self.make_mock_request(
+            status=http_client.OK,
+            data=self.SUCCESS_RESPONSE.copy(),
+            impersonation_status=http_client.OK,
+            impersonation_data=impersonation_response,
+        )
+        # Initialize credentials with service account impersonation.
+        credentials = self.make_credentials(
+            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
+            scopes=self.SCOPES,
+            quota_project_id=self.QUOTA_PROJECT_ID,
+        )
+        headers = {"other": "header-value"}
+
+        credentials.refresh(request)
+        credentials.apply(headers)
+
+        assert headers == {
+            "other": "header-value",
+            "authorization": "Bearer {}".format(impersonation_response["accessToken"]),
+            "x-goog-user-project": self.QUOTA_PROJECT_ID,
+        }
+
+    def test_before_request(self):
+        """before_request refreshes on first use and reuses the token afterwards."""
+        headers = {"other": "header-value"}
+        request = self.make_mock_request(
+            status=http_client.OK, data=self.SUCCESS_RESPONSE
+        )
+        credentials = self.make_credentials()
+
+        # First call should call refresh, setting the token.
+        credentials.before_request(request, "POST", "https://example.com/api", headers)
+
+        assert headers == {
+            "other": "header-value",
+            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]),
+        }
+
+        # Second call shouldn't call refresh.
+        credentials.before_request(request, "POST", "https://example.com/api", headers)
+
+        assert headers == {
+            "other": "header-value",
+            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]),
+        }
+
+    def test_before_request_impersonation(self):
+        """before_request with impersonation applies the impersonated token."""
+        expire_time = (
+            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600)
+        ).isoformat("T") + "Z"
+        # Service account impersonation response.
+        impersonation_response = {
+            "accessToken": "SA_ACCESS_TOKEN",
+            "expireTime": expire_time,
+        }
+        # Initialize mock request to handle token exchange and service account
+        # impersonation request.
+        request = self.make_mock_request(
+            status=http_client.OK,
+            data=self.SUCCESS_RESPONSE.copy(),
+            impersonation_status=http_client.OK,
+            impersonation_data=impersonation_response,
+        )
+        headers = {"other": "header-value"}
+        credentials = self.make_credentials(
+            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL
+        )
+
+        # First call should call refresh, setting the token.
+        credentials.before_request(request, "POST", "https://example.com/api", headers)
+
+        assert headers == {
+            "other": "header-value",
+            "authorization": "Bearer {}".format(impersonation_response["accessToken"]),
+        }
+
+        # Second call shouldn't call refresh.
+        credentials.before_request(request, "POST", "https://example.com/api", headers)
+
+        assert headers == {
+            "other": "header-value",
+            "authorization": "Bearer {}".format(impersonation_response["accessToken"]),
+        }
+
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_before_request_expired(self, utcnow):
+        """A cached token past the clock-skew threshold is refreshed by before_request."""
+        headers = {}
+        request = self.make_mock_request(
+            status=http_client.OK, data=self.SUCCESS_RESPONSE
+        )
+        credentials = self.make_credentials()
+        credentials.token = "token"
+        utcnow.return_value = datetime.datetime.min
+        # Set the expiration to one second more than now plus the clock skew
+        # accommodation. These credentials should be valid.
+        credentials.expiry = (
+            datetime.datetime.min + _helpers.CLOCK_SKEW + datetime.timedelta(seconds=1)
+        )
+
+        assert credentials.valid
+        assert not credentials.expired
+
+        credentials.before_request(request, "POST", "https://example.com/api", headers)
+
+        # Cached token should be used.
+        assert headers == {"authorization": "Bearer token"}
+
+        # Next call should simulate 1 second passed.
+        utcnow.return_value = datetime.datetime.min + datetime.timedelta(seconds=1)
+
+        assert not credentials.valid
+        assert credentials.expired
+
+        credentials.before_request(request, "POST", "https://example.com/api", headers)
+
+        # New token should be retrieved.
+        assert headers == {
+            "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"])
+        }
+
+    @mock.patch("google.auth._helpers.utcnow")
+    def test_before_request_impersonation_expired(self, utcnow):
+        """An expired cached token is replaced by a fresh impersonated token."""
+        headers = {}
+        expire_time = (
+            datetime.datetime.min + datetime.timedelta(seconds=3601)
+        ).isoformat("T") + "Z"
+        # Service account impersonation response.
+        impersonation_response = {
+            "accessToken": "SA_ACCESS_TOKEN",
+            "expireTime": expire_time,
+        }
+        # Initialize mock request to handle token exchange and service account
+        # impersonation request.
+        request = self.make_mock_request(
+            status=http_client.OK,
+            data=self.SUCCESS_RESPONSE.copy(),
+            impersonation_status=http_client.OK,
+            impersonation_data=impersonation_response,
+        )
+        credentials = self.make_credentials(
+            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL
+        )
+        credentials.token = "token"
+        utcnow.return_value = datetime.datetime.min
+        # Set the expiration to one second more than now plus the clock skew
+        # accommodation. These credentials should be valid.
+        credentials.expiry = (
+            datetime.datetime.min + _helpers.CLOCK_SKEW + datetime.timedelta(seconds=1)
+        )
+
+        assert credentials.valid
+        assert not credentials.expired
+
+        credentials.before_request(request, "POST", "https://example.com/api", headers)
+
+        # Cached token should be used.
+        assert headers == {"authorization": "Bearer token"}
+
+        # Next call should simulate 1 second passed. This will trigger the expiration
+        # threshold.
+        utcnow.return_value = datetime.datetime.min + datetime.timedelta(seconds=1)
+
+        assert not credentials.valid
+        assert credentials.expired
+
+        credentials.before_request(request, "POST", "https://example.com/api", headers)
+
+        # New token should be retrieved.
+        assert headers == {
+            "authorization": "Bearer {}".format(impersonation_response["accessToken"])
+        }
+
+    @pytest.mark.parametrize(
+        "audience",
+        [
+            # Legacy K8s audience format.
+            "identitynamespace:1f12345:my_provider",
+            # Unrealistic audiences.
+            "//iam.googleapis.com/projects",
+            "//iam.googleapis.com/projects/",
+            "//iam.googleapis.com/project/123456",
+            "//iam.googleapis.com/projects//123456",
+            "//iam.googleapis.com/prefix_projects/123456",
+            "//iam.googleapis.com/projects_suffix/123456",
+        ],
+    )
+    def test_project_number_indeterminable(self, audience):
+        """Audiences without a parsable project number yield no project number or ID."""
+        credentials = CredentialsImpl(
+            audience=audience,
+            subject_token_type=self.SUBJECT_TOKEN_TYPE,
+            token_url=self.TOKEN_URL,
+            credential_source=self.CREDENTIAL_SOURCE,
+        )
+
+        assert credentials.project_number is None
+        # A None request is passed: no HTTP lookup is expected to be attempted.
+        assert credentials.get_project_id(None) is None
+
+    def test_project_number_determinable(self):
+        """The project number is parsed from a well-formed workload pool audience."""
+        credentials = CredentialsImpl(
+            audience=self.AUDIENCE,
+            subject_token_type=self.SUBJECT_TOKEN_TYPE,
+            token_url=self.TOKEN_URL,
+            credential_source=self.CREDENTIAL_SOURCE,
+        )
+
+        assert credentials.project_number == self.PROJECT_NUMBER
+
+    def test_project_id_without_scopes(self):
+        """get_project_id returns None when the credentials have no scopes."""
+        # Initialize credentials with no scopes.
+        credentials = CredentialsImpl(
+            audience=self.AUDIENCE,
+            subject_token_type=self.SUBJECT_TOKEN_TYPE,
+            token_url=self.TOKEN_URL,
+            credential_source=self.CREDENTIAL_SOURCE,
+        )
+
+        # A None request is passed: no HTTP lookup is expected to be attempted.
+        assert credentials.get_project_id(None) is None
+
+    def test_get_project_id_cloud_resource_manager_success(self):
+        """get_project_id refreshes, queries Cloud Resource Manager and caches the ID.
+
+        Expected call order: STS exchange, impersonation, then the resource
+        manager lookup; a second call must not issue any further requests.
+        """
+        # STS token exchange request/response.
+        token_response = self.SUCCESS_RESPONSE.copy()
+        token_headers = {"Content-Type": "application/x-www-form-urlencoded"}
+        token_request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": self.AUDIENCE,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "subject_token": "subject_token_0",
+            "subject_token_type": self.SUBJECT_TOKEN_TYPE,
+            "scope": "https://www.googleapis.com/auth/iam",
+        }
+        # Service account impersonation request/response.
+        expire_time = (
+            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600)
+        ).isoformat("T") + "Z"
+        expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ")
+        impersonation_response = {
+            "accessToken": "SA_ACCESS_TOKEN",
+            "expireTime": expire_time,
+        }
+        impersonation_headers = {
+            "Content-Type": "application/json",
+            "x-goog-user-project": self.QUOTA_PROJECT_ID,
+            "authorization": "Bearer {}".format(token_response["access_token"]),
+        }
+        impersonation_request_data = {
+            "delegates": None,
+            "scope": self.SCOPES,
+            "lifetime": "3600s",
+        }
+        # Initialize mock request to handle token exchange, service account
+        # impersonation and cloud resource manager request.
+        request = self.make_mock_request(
+            status=http_client.OK,
+            data=self.SUCCESS_RESPONSE.copy(),
+            impersonation_status=http_client.OK,
+            impersonation_data=impersonation_response,
+            cloud_resource_manager_status=http_client.OK,
+            cloud_resource_manager_data=self.CLOUD_RESOURCE_MANAGER_SUCCESS_RESPONSE,
+        )
+        credentials = self.make_credentials(
+            service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL,
+            scopes=self.SCOPES,
+            quota_project_id=self.QUOTA_PROJECT_ID,
+        )
+
+        # Expected project ID from cloud resource manager response should be returned.
+        project_id = credentials.get_project_id(request)
+
+        assert project_id == self.PROJECT_ID
+        # 3 requests should be processed.
+        assert len(request.call_args_list) == 3
+        # Verify token exchange request parameters.
+        self.assert_token_request_kwargs(
+            request.call_args_list[0].kwargs, token_headers, token_request_data
+        )
+        # Verify service account impersonation request parameters.
+        self.assert_impersonation_request_kwargs(
+            request.call_args_list[1].kwargs,
+            impersonation_headers,
+            impersonation_request_data,
+        )
+        # In the process of getting project ID, an access token should be
+        # retrieved.
+        assert credentials.valid
+        assert credentials.expiry == expected_expiry
+        assert not credentials.expired
+        assert credentials.token == impersonation_response["accessToken"]
+        # Verify cloud resource manager request parameters.
+        self.assert_resource_manager_request_kwargs(
+            request.call_args_list[2].kwargs,
+            self.PROJECT_NUMBER,
+            {
+                "x-goog-user-project": self.QUOTA_PROJECT_ID,
+                "authorization": "Bearer {}".format(
+                    impersonation_response["accessToken"]
+                ),
+            },
+        )
+
+        # Calling get_project_id again should return the cached project_id.
+        project_id = credentials.get_project_id(request)
+
+        assert project_id == self.PROJECT_ID
+        # No additional requests.
+        assert len(request.call_args_list) == 3
+
+    def test_get_project_id_cloud_resource_manager_error(self):
+        """A Cloud Resource Manager failure makes get_project_id return None."""
+        # Simulate resource doesn't have sufficient permissions to access
+        # cloud resource manager.
+        request = self.make_mock_request(
+            status=http_client.OK,
+            data=self.SUCCESS_RESPONSE.copy(),
+            cloud_resource_manager_status=http_client.UNAUTHORIZED,
+        )
+        credentials = self.make_credentials(scopes=self.SCOPES)
+
+        project_id = credentials.get_project_id(request)
+
+        assert project_id is None
+        # Only 2 requests to STS and cloud resource manager should be sent.
+        assert len(request.call_args_list) == 2
diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py
new file mode 100644
index 0000000..c017ab5
--- /dev/null
+++ b/tests/test_identity_pool.py
@@ -0,0 +1,873 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+import os
+
+import mock
+import pytest
+from six.moves import http_client
+from six.moves import urllib
+
+from google.auth import _helpers
+from google.auth import exceptions
+from google.auth import identity_pool
+from google.auth import transport
+
+
+CLIENT_ID = "username"
+CLIENT_SECRET = "password"
+# Base64 encoding of "username:password".
+BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ="
+SERVICE_ACCOUNT_EMAIL = "service-1234@service-name.iam.gserviceaccount.com"
+SERVICE_ACCOUNT_IMPERSONATION_URL = (
+    "https://us-east1-iamcredentials.googleapis.com/v1/projects/-"
+    + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL)
+)
+QUOTA_PROJECT_ID = "QUOTA_PROJECT_ID"
+SCOPES = ["scope1", "scope2"]
+# Subject token fixtures live in the "data" directory next to this module.
+DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
+SUBJECT_TOKEN_TEXT_FILE = os.path.join(DATA_DIR, "external_subject_token.txt")
+SUBJECT_TOKEN_JSON_FILE = os.path.join(DATA_DIR, "external_subject_token.json")
+SUBJECT_TOKEN_FIELD_NAME = "access_token"
+
+# Expected subject tokens, read from the fixtures once at import time. The text
+# fixture has no trailing newline, so read() yields the bare token.
+with open(SUBJECT_TOKEN_TEXT_FILE) as fh:
+    TEXT_FILE_SUBJECT_TOKEN = fh.read()
+
+with open(SUBJECT_TOKEN_JSON_FILE) as fh:
+    JSON_FILE_CONTENT = json.load(fh)
+    JSON_FILE_SUBJECT_TOKEN = JSON_FILE_CONTENT.get(SUBJECT_TOKEN_FIELD_NAME)
+
+TOKEN_URL = "https://sts.googleapis.com/v1/token"
+SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt"
+# Workload identity pool provider resource name (project number 123456).
+AUDIENCE = "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID"
+
+
+class TestCredentials(object):
+    # File-based credential sources: the subject token is either the raw file
+    # text or the "access_token" field of a JSON document.
+    CREDENTIAL_SOURCE_TEXT = {"file": SUBJECT_TOKEN_TEXT_FILE}
+    CREDENTIAL_SOURCE_JSON = {
+        "file": SUBJECT_TOKEN_JSON_FILE,
+        "format": {"type": "json", "subject_token_field_name": "access_token"},
+    }
+    # URL-based equivalents of the file credential sources above.
+    CREDENTIAL_URL = "http://fakeurl.com"
+    CREDENTIAL_SOURCE_TEXT_URL = {"url": CREDENTIAL_URL}
+    CREDENTIAL_SOURCE_JSON_URL = {
+        "url": CREDENTIAL_URL,
+        "format": {"type": "json", "subject_token_field_name": "access_token"},
+    }
+    # Canned successful STS token exchange response body.
+    SUCCESS_RESPONSE = {
+        "access_token": "ACCESS_TOKEN",
+        "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
+        "token_type": "Bearer",
+        "expires_in": 3600,
+        "scope": " ".join(SCOPES),
+    }
+
+    @classmethod
+    def make_mock_response(cls, status, data):
+        """Create a mock transport response carrying ``status`` and ``data``.
+
+        A dict body is serialized to UTF-8 encoded JSON; any other value is
+        used as the response body unchanged.
+        """
+        body = json.dumps(data).encode("utf-8") if isinstance(data, dict) else data
+        mock_response = mock.create_autospec(transport.Response, instance=True)
+        mock_response.status = status
+        mock_response.data = body
+        return mock_response
+
+    @classmethod
+    def make_mock_request(
+        cls, token_status=http_client.OK, token_data=None, *extra_requests
+    ):
+        """Build a mock transport request that replays responses in call order.
+
+        Args:
+            token_status: HTTP status of the first response.
+            token_data: Body of the first response (dict bodies become JSON).
+            *extra_requests: Flattened ``status, data`` pairs, one pair per
+                further response (e.g. credential source, token exchange or
+                service account impersonation).
+
+        Returns:
+            A mock ``transport.Request`` whose successive calls return the
+            responses in the order given.
+
+        Raises:
+            ValueError: If ``extra_requests`` ends with a status lacking data.
+        """
+        if len(extra_requests) % 2:
+            raise ValueError(
+                "extra_requests must be (status, data) pairs, got {} values".format(
+                    len(extra_requests)
+                )
+            )
+
+        responses = [cls.make_mock_response(token_status, token_data)]
+        # Even positions hold statuses, odd positions hold the matching data.
+        statuses = extra_requests[0::2]
+        payloads = extra_requests[1::2]
+        for status, data in zip(statuses, payloads):
+            responses.append(cls.make_mock_response(status, data))
+
+        request = mock.create_autospec(transport.Request)
+        # A list side_effect hands out one response per call, in order.
+        request.side_effect = responses
+
+        return request
+
+    @classmethod
+    def assert_credential_request_kwargs(
+        cls, request_kwargs, headers, url=CREDENTIAL_URL
+    ):
+        """Assert the subject token was fetched by a body-less GET to ``url``."""
+        expected_fields = (("url", url), ("method", "GET"), ("headers", headers))
+        for field, expected in expected_fields:
+            assert request_kwargs[field] == expected
+        # The retrieval request must not carry a payload.
+        assert request_kwargs.get("body") is None
+
+    @classmethod
+    def assert_token_request_kwargs(
+        cls, request_kwargs, headers, request_data, token_url=TOKEN_URL
+    ):
+        """Assert a form-encoded POST was sent to the STS token endpoint.
+
+        Args:
+            request_kwargs: Keyword arguments the mock transport was called with.
+            headers: The exact headers expected on the request.
+            request_data: Mapping of every expected form field to its value.
+            token_url: The expected token endpoint URL.
+        """
+        assert request_kwargs["url"] == token_url
+        assert request_kwargs["method"] == "POST"
+        assert request_kwargs["headers"] == headers
+        assert request_kwargs["body"] is not None
+        body_tuples = urllib.parse.parse_qsl(request_kwargs["body"])
+        # Reject repeated fields, which the dict comparison below would collapse.
+        assert len(body_tuples) == len(request_data)
+        # Compare as a whole so a missing or unexpected field fails as an
+        # assertion with a readable diff rather than a KeyError, and a repeated
+        # field cannot mask a missing one.
+        body_fields = {
+            key.decode("utf-8"): value.decode("utf-8") for key, value in body_tuples
+        }
+        assert body_fields == request_data
+
+    @classmethod
+    def assert_impersonation_request_kwargs(
+        cls,
+        request_kwargs,
+        headers,
+        request_data,
+        service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+    ):
+        """Assert a JSON POST was sent to the service account impersonation URL.
+
+        Args:
+            request_kwargs: Keyword arguments the mock transport was called with.
+            headers: The exact headers expected on the request.
+            request_data: The expected decoded JSON request body.
+            service_account_impersonation_url: The expected endpoint URL.
+        """
+        expected_fields = (
+            ("url", service_account_impersonation_url),
+            ("method", "POST"),
+            ("headers", headers),
+        )
+        for field, expected in expected_fields:
+            assert request_kwargs[field] == expected
+        raw_body = request_kwargs["body"]
+        assert raw_body is not None
+        assert json.loads(raw_body.decode("utf-8")) == request_data
+
+    @classmethod
+    def assert_underlying_credentials_refresh(
+        cls,
+        credentials,
+        audience,
+        subject_token,
+        subject_token_type,
+        token_url,
+        service_account_impersonation_url=None,
+        basic_auth_encoding=None,
+        quota_project_id=None,
+        used_scopes=None,
+        credential_data=None,
+        scopes=None,
+        default_scopes=None,
+    ):
+        """Assert that credentials refresh with the expected requests and state.
+
+        Calls ``credentials.refresh`` against a mock transport that replays, in
+        order, an optional credential-source response, the STS token exchange
+        response and an optional service account impersonation response. It then
+        verifies each request, the resulting access token, and the credentials'
+        quota project and scopes.
+
+        Args:
+            credentials: The credentials under test.
+            audience: Expected STS ``audience`` form field.
+            subject_token: Expected STS ``subject_token`` form field.
+            subject_token_type: Expected STS ``subject_token_type`` form field.
+            token_url: Expected STS token endpoint URL.
+            service_account_impersonation_url: When set, an impersonation request
+                to this URL is expected after the token exchange.
+            basic_auth_encoding: When set, the STS request is expected to carry
+                ``Authorization: Basic <basic_auth_encoding>``.
+            quota_project_id: Expected ``credentials.quota_project_id``.
+            used_scopes: Scopes expected on the wire: on the STS request when not
+                impersonating, otherwise on the impersonation request.
+            credential_data: When set, the subject token is expected to be
+                fetched first by a GET whose response body is this value.
+            scopes: Expected ``credentials.scopes``.
+            default_scopes: Expected ``credentials.default_scopes``.
+        """
+        # STS token exchange request/response.
+        token_response = cls.SUCCESS_RESPONSE.copy()
+        token_headers = {"Content-Type": "application/x-www-form-urlencoded"}
+        if basic_auth_encoding:
+            token_headers["Authorization"] = "Basic " + basic_auth_encoding
+
+        if service_account_impersonation_url:
+            # The STS token is only used to call IAM, so it requests the IAM scope.
+            token_scopes = "https://www.googleapis.com/auth/iam"
+        else:
+            token_scopes = " ".join(used_scopes or [])
+
+        token_request_data = {
+            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
+            "audience": audience,
+            "requested_token_type": "urn:ietf:params:oauth:token-type:access_token",
+            "subject_token": subject_token,
+            "subject_token_type": subject_token_type,
+        }
+        # An empty scope can never appear in the parsed body (parse_qsl drops
+        # blank values by default), so only expect the field when non-empty.
+        if token_scopes:
+            token_request_data["scope"] = token_scopes
+
+        if service_account_impersonation_url:
+            # Service account impersonation request/response.
+            expire_time = (
+                _helpers.utcnow().replace(microsecond=0)
+                + datetime.timedelta(seconds=3600)
+            ).isoformat("T") + "Z"
+            impersonation_response = {
+                "accessToken": "SA_ACCESS_TOKEN",
+                "expireTime": expire_time,
+            }
+            impersonation_headers = {
+                "Content-Type": "application/json",
+                "authorization": "Bearer {}".format(token_response["access_token"]),
+            }
+            impersonation_request_data = {
+                "delegates": None,
+                "scope": used_scopes,
+                "lifetime": "3600s",
+            }
+
+        # Queue (status, data) responses in the order the credentials issue
+        # requests: optional subject token retrieval, token exchange, then
+        # optional service account impersonation.
+        mock_responses = []
+        if credential_data:
+            mock_responses.append((http_client.OK, credential_data))
+
+        token_request_index = len(mock_responses)
+        mock_responses.append((http_client.OK, token_response))
+
+        if service_account_impersonation_url:
+            impersonation_request_index = len(mock_responses)
+            mock_responses.append((http_client.OK, impersonation_response))
+
+        # make_mock_request takes the (status, data) pairs flattened positionally.
+        request = cls.make_mock_request(
+            *[item for response in mock_responses for item in response]
+        )
+
+        credentials.refresh(request)
+
+        assert len(request.call_args_list) == len(mock_responses)
+        if credential_data:
+            # The subject token retrieval is expected to send no headers.
+            cls.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None)
+        # Verify token exchange request parameters.
+        cls.assert_token_request_kwargs(
+            request.call_args_list[token_request_index].kwargs,
+            token_headers,
+            token_request_data,
+            token_url,
+        )
+        # Verify service account impersonation request parameters if the request
+        # is processed.
+        if service_account_impersonation_url:
+            cls.assert_impersonation_request_kwargs(
+                request.call_args_list[impersonation_request_index].kwargs,
+                impersonation_headers,
+                impersonation_request_data,
+                service_account_impersonation_url,
+            )
+            assert credentials.token == impersonation_response["accessToken"]
+        else:
+            assert credentials.token == token_response["access_token"]
+        assert credentials.quota_project_id == quota_project_id
+        assert credentials.scopes == scopes
+        assert credentials.default_scopes == default_scopes
+
+ @classmethod
+ def make_credentials(
+ cls,
+ client_id=None,
+ client_secret=None,
+ quota_project_id=None,
+ scopes=None,
+ default_scopes=None,
+ service_account_impersonation_url=None,
+ credential_source=None,
+ ):
+ return identity_pool.Credentials(
+ audience=AUDIENCE,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=service_account_impersonation_url,
+ credential_source=credential_source,
+ client_id=client_id,
+ client_secret=client_secret,
+ quota_project_id=quota_project_id,
+ scopes=scopes,
+ default_scopes=default_scopes,
+ )
+
+ @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None)
+ def test_from_info_full_options(self, mock_init):
+ credentials = identity_pool.Credentials.from_info(
+ {
+ "audience": AUDIENCE,
+ "subject_token_type": SUBJECT_TOKEN_TYPE,
+ "token_url": TOKEN_URL,
+ "service_account_impersonation_url": SERVICE_ACCOUNT_IMPERSONATION_URL,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "quota_project_id": QUOTA_PROJECT_ID,
+ "credential_source": self.CREDENTIAL_SOURCE_TEXT,
+ }
+ )
+
+ # Confirm identity_pool.Credentials instantiated with expected attributes.
+ assert isinstance(credentials, identity_pool.Credentials)
+ mock_init.assert_called_once_with(
+ audience=AUDIENCE,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ client_id=CLIENT_ID,
+ client_secret=CLIENT_SECRET,
+ credential_source=self.CREDENTIAL_SOURCE_TEXT,
+ quota_project_id=QUOTA_PROJECT_ID,
+ )
+
+ @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None)
+ def test_from_info_required_options_only(self, mock_init):
+ credentials = identity_pool.Credentials.from_info(
+ {
+ "audience": AUDIENCE,
+ "subject_token_type": SUBJECT_TOKEN_TYPE,
+ "token_url": TOKEN_URL,
+ "credential_source": self.CREDENTIAL_SOURCE_TEXT,
+ }
+ )
+
+ # Confirm identity_pool.Credentials instantiated with expected attributes.
+ assert isinstance(credentials, identity_pool.Credentials)
+ mock_init.assert_called_once_with(
+ audience=AUDIENCE,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=None,
+ client_id=None,
+ client_secret=None,
+ credential_source=self.CREDENTIAL_SOURCE_TEXT,
+ quota_project_id=None,
+ )
+
+ @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None)
+ def test_from_file_full_options(self, mock_init, tmpdir):
+ info = {
+ "audience": AUDIENCE,
+ "subject_token_type": SUBJECT_TOKEN_TYPE,
+ "token_url": TOKEN_URL,
+ "service_account_impersonation_url": SERVICE_ACCOUNT_IMPERSONATION_URL,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "quota_project_id": QUOTA_PROJECT_ID,
+ "credential_source": self.CREDENTIAL_SOURCE_TEXT,
+ }
+ config_file = tmpdir.join("config.json")
+ config_file.write(json.dumps(info))
+ credentials = identity_pool.Credentials.from_file(str(config_file))
+
+ # Confirm identity_pool.Credentials instantiated with expected attributes.
+ assert isinstance(credentials, identity_pool.Credentials)
+ mock_init.assert_called_once_with(
+ audience=AUDIENCE,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ client_id=CLIENT_ID,
+ client_secret=CLIENT_SECRET,
+ credential_source=self.CREDENTIAL_SOURCE_TEXT,
+ quota_project_id=QUOTA_PROJECT_ID,
+ )
+
+ @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None)
+ def test_from_file_required_options_only(self, mock_init, tmpdir):
+ info = {
+ "audience": AUDIENCE,
+ "subject_token_type": SUBJECT_TOKEN_TYPE,
+ "token_url": TOKEN_URL,
+ "credential_source": self.CREDENTIAL_SOURCE_TEXT,
+ }
+ config_file = tmpdir.join("config.json")
+ config_file.write(json.dumps(info))
+ credentials = identity_pool.Credentials.from_file(str(config_file))
+
+ # Confirm identity_pool.Credentials instantiated with expected attributes.
+ assert isinstance(credentials, identity_pool.Credentials)
+ mock_init.assert_called_once_with(
+ audience=AUDIENCE,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=None,
+ client_id=None,
+ client_secret=None,
+ credential_source=self.CREDENTIAL_SOURCE_TEXT,
+ quota_project_id=None,
+ )
+
+ def test_constructor_invalid_options(self):
+ credential_source = {"unsupported": "value"}
+
+ with pytest.raises(ValueError) as excinfo:
+ self.make_credentials(credential_source=credential_source)
+
+ assert excinfo.match(r"Missing credential_source")
+
+ def test_constructor_invalid_options_url_and_file(self):
+ credential_source = {
+ "url": self.CREDENTIAL_URL,
+ "file": SUBJECT_TOKEN_TEXT_FILE,
+ }
+
+ with pytest.raises(ValueError) as excinfo:
+ self.make_credentials(credential_source=credential_source)
+
+ assert excinfo.match(r"Ambiguous credential_source")
+
+ def test_constructor_invalid_options_environment_id(self):
+ credential_source = {"url": self.CREDENTIAL_URL, "environment_id": "aws1"}
+
+ with pytest.raises(ValueError) as excinfo:
+ self.make_credentials(credential_source=credential_source)
+
+ assert excinfo.match(
+ r"Invalid Identity Pool credential_source field 'environment_id'"
+ )
+
+ def test_constructor_invalid_credential_source(self):
+ with pytest.raises(ValueError) as excinfo:
+ self.make_credentials(credential_source="non-dict")
+
+ assert excinfo.match(r"Missing credential_source")
+
+ def test_constructor_invalid_credential_source_format_type(self):
+ credential_source = {"format": {"type": "xml"}}
+
+ with pytest.raises(ValueError) as excinfo:
+ self.make_credentials(credential_source=credential_source)
+
+ assert excinfo.match(r"Invalid credential_source format 'xml'")
+
+ def test_constructor_missing_subject_token_field_name(self):
+ credential_source = {"format": {"type": "json"}}
+
+ with pytest.raises(ValueError) as excinfo:
+ self.make_credentials(credential_source=credential_source)
+
+ assert excinfo.match(
+ r"Missing subject_token_field_name for JSON credential_source format"
+ )
+
+ def test_retrieve_subject_token_missing_subject_token(self, tmpdir):
+ # Provide empty text file.
+ empty_file = tmpdir.join("empty.txt")
+ empty_file.write("")
+ credential_source = {"file": str(empty_file)}
+ credentials = self.make_credentials(credential_source=credential_source)
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.retrieve_subject_token(None)
+
+ assert excinfo.match(r"Missing subject_token in the credential_source file")
+
+ def test_retrieve_subject_token_text_file(self):
+ credentials = self.make_credentials(
+ credential_source=self.CREDENTIAL_SOURCE_TEXT
+ )
+
+ subject_token = credentials.retrieve_subject_token(None)
+
+ assert subject_token == TEXT_FILE_SUBJECT_TOKEN
+
+ def test_retrieve_subject_token_json_file(self):
+ credentials = self.make_credentials(
+ credential_source=self.CREDENTIAL_SOURCE_JSON
+ )
+
+ subject_token = credentials.retrieve_subject_token(None)
+
+ assert subject_token == JSON_FILE_SUBJECT_TOKEN
+
+ def test_retrieve_subject_token_json_file_invalid_field_name(self):
+ credential_source = {
+ "file": SUBJECT_TOKEN_JSON_FILE,
+ "format": {"type": "json", "subject_token_field_name": "not_found"},
+ }
+ credentials = self.make_credentials(credential_source=credential_source)
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.retrieve_subject_token(None)
+
+ assert excinfo.match(
+ "Unable to parse subject_token from JSON file '{}' using key '{}'".format(
+ SUBJECT_TOKEN_JSON_FILE, "not_found"
+ )
+ )
+
+ def test_retrieve_subject_token_invalid_json(self, tmpdir):
+ # Provide JSON file. This should result in JSON parsing error.
+ invalid_json_file = tmpdir.join("invalid.json")
+ invalid_json_file.write("{")
+ credential_source = {
+ "file": str(invalid_json_file),
+ "format": {"type": "json", "subject_token_field_name": "access_token"},
+ }
+ credentials = self.make_credentials(credential_source=credential_source)
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.retrieve_subject_token(None)
+
+ assert excinfo.match(
+ "Unable to parse subject_token from JSON file '{}' using key '{}'".format(
+ str(invalid_json_file), "access_token"
+ )
+ )
+
+ def test_retrieve_subject_token_file_not_found(self):
+ credential_source = {"file": "./not_found.txt"}
+ credentials = self.make_credentials(credential_source=credential_source)
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.retrieve_subject_token(None)
+
+ assert excinfo.match(r"File './not_found.txt' was not found")
+
+ def test_refresh_text_file_success_without_impersonation_ignore_default_scopes(
+ self
+ ):
+ credentials = self.make_credentials(
+ client_id=CLIENT_ID,
+ client_secret=CLIENT_SECRET,
+ # Test with text format type.
+ credential_source=self.CREDENTIAL_SOURCE_TEXT,
+ scopes=SCOPES,
+ # Default scopes should be ignored.
+ default_scopes=["ignored"],
+ )
+
+ self.assert_underlying_credentials_refresh(
+ credentials=credentials,
+ audience=AUDIENCE,
+ subject_token=TEXT_FILE_SUBJECT_TOKEN,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=None,
+ basic_auth_encoding=BASIC_AUTH_ENCODING,
+ quota_project_id=None,
+ used_scopes=SCOPES,
+ scopes=SCOPES,
+ default_scopes=["ignored"],
+ )
+
+ def test_refresh_text_file_success_without_impersonation_use_default_scopes(self):
+ credentials = self.make_credentials(
+ client_id=CLIENT_ID,
+ client_secret=CLIENT_SECRET,
+ # Test with text format type.
+ credential_source=self.CREDENTIAL_SOURCE_TEXT,
+ scopes=None,
+ # Default scopes should be used since user specified scopes are none.
+ default_scopes=SCOPES,
+ )
+
+ self.assert_underlying_credentials_refresh(
+ credentials=credentials,
+ audience=AUDIENCE,
+ subject_token=TEXT_FILE_SUBJECT_TOKEN,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=None,
+ basic_auth_encoding=BASIC_AUTH_ENCODING,
+ quota_project_id=None,
+ used_scopes=SCOPES,
+ scopes=None,
+ default_scopes=SCOPES,
+ )
+
+ def test_refresh_text_file_success_with_impersonation_ignore_default_scopes(self):
+ # Initialize credentials with service account impersonation and basic auth.
+ credentials = self.make_credentials(
+ # Test with text format type.
+ credential_source=self.CREDENTIAL_SOURCE_TEXT,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ scopes=SCOPES,
+ # Default scopes should be ignored.
+ default_scopes=["ignored"],
+ )
+
+ self.assert_underlying_credentials_refresh(
+ credentials=credentials,
+ audience=AUDIENCE,
+ subject_token=TEXT_FILE_SUBJECT_TOKEN,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ basic_auth_encoding=None,
+ quota_project_id=None,
+ used_scopes=SCOPES,
+ scopes=SCOPES,
+ default_scopes=["ignored"],
+ )
+
+ def test_refresh_text_file_success_with_impersonation_use_default_scopes(self):
+ # Initialize credentials with service account impersonation, basic auth
+ # and default scopes (no user scopes).
+ credentials = self.make_credentials(
+ # Test with text format type.
+ credential_source=self.CREDENTIAL_SOURCE_TEXT,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ scopes=None,
+ # Default scopes should be used since user specified scopes are none.
+ default_scopes=SCOPES,
+ )
+
+ self.assert_underlying_credentials_refresh(
+ credentials=credentials,
+ audience=AUDIENCE,
+ subject_token=TEXT_FILE_SUBJECT_TOKEN,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ basic_auth_encoding=None,
+ quota_project_id=None,
+ used_scopes=SCOPES,
+ scopes=None,
+ default_scopes=SCOPES,
+ )
+
+ def test_refresh_json_file_success_without_impersonation(self):
+ credentials = self.make_credentials(
+ client_id=CLIENT_ID,
+ client_secret=CLIENT_SECRET,
+ # Test with JSON format type.
+ credential_source=self.CREDENTIAL_SOURCE_JSON,
+ scopes=SCOPES,
+ )
+
+ self.assert_underlying_credentials_refresh(
+ credentials=credentials,
+ audience=AUDIENCE,
+ subject_token=JSON_FILE_SUBJECT_TOKEN,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=None,
+ basic_auth_encoding=BASIC_AUTH_ENCODING,
+ quota_project_id=None,
+ used_scopes=SCOPES,
+ scopes=SCOPES,
+ default_scopes=None,
+ )
+
+ def test_refresh_json_file_success_with_impersonation(self):
+ # Initialize credentials with service account impersonation and basic auth.
+ credentials = self.make_credentials(
+ # Test with JSON format type.
+ credential_source=self.CREDENTIAL_SOURCE_JSON,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ scopes=SCOPES,
+ )
+
+ self.assert_underlying_credentials_refresh(
+ credentials=credentials,
+ audience=AUDIENCE,
+ subject_token=JSON_FILE_SUBJECT_TOKEN,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ basic_auth_encoding=None,
+ quota_project_id=None,
+ used_scopes=SCOPES,
+ scopes=SCOPES,
+ default_scopes=None,
+ )
+
+ def test_refresh_with_retrieve_subject_token_error(self):
+ credential_source = {
+ "file": SUBJECT_TOKEN_JSON_FILE,
+ "format": {"type": "json", "subject_token_field_name": "not_found"},
+ }
+ credentials = self.make_credentials(credential_source=credential_source)
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.refresh(None)
+
+ assert excinfo.match(
+ "Unable to parse subject_token from JSON file '{}' using key '{}'".format(
+ SUBJECT_TOKEN_JSON_FILE, "not_found"
+ )
+ )
+
+ def test_retrieve_subject_token_from_url(self):
+ credentials = self.make_credentials(
+ credential_source=self.CREDENTIAL_SOURCE_TEXT_URL
+ )
+ request = self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN)
+ subject_token = credentials.retrieve_subject_token(request)
+
+ assert subject_token == TEXT_FILE_SUBJECT_TOKEN
+ self.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None)
+
+ def test_retrieve_subject_token_from_url_with_headers(self):
+ credentials = self.make_credentials(
+ credential_source={"url": self.CREDENTIAL_URL, "headers": {"foo": "bar"}}
+ )
+ request = self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN)
+ subject_token = credentials.retrieve_subject_token(request)
+
+ assert subject_token == TEXT_FILE_SUBJECT_TOKEN
+ self.assert_credential_request_kwargs(
+ request.call_args_list[0].kwargs, {"foo": "bar"}
+ )
+
+ def test_retrieve_subject_token_from_url_json(self):
+ credentials = self.make_credentials(
+ credential_source=self.CREDENTIAL_SOURCE_JSON_URL
+ )
+ request = self.make_mock_request(token_data=JSON_FILE_CONTENT)
+ subject_token = credentials.retrieve_subject_token(request)
+
+ assert subject_token == JSON_FILE_SUBJECT_TOKEN
+ self.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None)
+
+ def test_retrieve_subject_token_from_url_json_with_headers(self):
+ credentials = self.make_credentials(
+ credential_source={
+ "url": self.CREDENTIAL_URL,
+ "format": {"type": "json", "subject_token_field_name": "access_token"},
+ "headers": {"foo": "bar"},
+ }
+ )
+ request = self.make_mock_request(token_data=JSON_FILE_CONTENT)
+ subject_token = credentials.retrieve_subject_token(request)
+
+ assert subject_token == JSON_FILE_SUBJECT_TOKEN
+ self.assert_credential_request_kwargs(
+ request.call_args_list[0].kwargs, {"foo": "bar"}
+ )
+
+ def test_retrieve_subject_token_from_url_not_found(self):
+ credentials = self.make_credentials(
+ credential_source=self.CREDENTIAL_SOURCE_TEXT_URL
+ )
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.retrieve_subject_token(
+ self.make_mock_request(token_status=404, token_data=JSON_FILE_CONTENT)
+ )
+
+ assert excinfo.match("Unable to retrieve Identity Pool subject token")
+
+ def test_retrieve_subject_token_from_url_json_invalid_field(self):
+ credential_source = {
+ "url": self.CREDENTIAL_URL,
+ "format": {"type": "json", "subject_token_field_name": "not_found"},
+ }
+ credentials = self.make_credentials(credential_source=credential_source)
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.retrieve_subject_token(
+ self.make_mock_request(token_data=JSON_FILE_CONTENT)
+ )
+
+ assert excinfo.match(
+ "Unable to parse subject_token from JSON file '{}' using key '{}'".format(
+ self.CREDENTIAL_URL, "not_found"
+ )
+ )
+
+ def test_retrieve_subject_token_from_url_json_invalid_format(self):
+ credentials = self.make_credentials(
+ credential_source=self.CREDENTIAL_SOURCE_JSON_URL
+ )
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.retrieve_subject_token(self.make_mock_request(token_data="{"))
+
+ assert excinfo.match(
+ "Unable to parse subject_token from JSON file '{}' using key '{}'".format(
+ self.CREDENTIAL_URL, "access_token"
+ )
+ )
+
+ def test_refresh_text_file_success_without_impersonation_url(self):
+ credentials = self.make_credentials(
+ client_id=CLIENT_ID,
+ client_secret=CLIENT_SECRET,
+ # Test with text format type.
+ credential_source=self.CREDENTIAL_SOURCE_TEXT_URL,
+ scopes=SCOPES,
+ )
+
+ self.assert_underlying_credentials_refresh(
+ credentials=credentials,
+ audience=AUDIENCE,
+ subject_token=TEXT_FILE_SUBJECT_TOKEN,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=None,
+ basic_auth_encoding=BASIC_AUTH_ENCODING,
+ quota_project_id=None,
+ used_scopes=SCOPES,
+ scopes=SCOPES,
+ default_scopes=None,
+ credential_data=TEXT_FILE_SUBJECT_TOKEN,
+ )
+
+ def test_refresh_text_file_success_with_impersonation_url(self):
+ # Initialize credentials with service account impersonation and basic auth.
+ credentials = self.make_credentials(
+ # Test with text format type.
+ credential_source=self.CREDENTIAL_SOURCE_TEXT_URL,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ scopes=SCOPES,
+ )
+
+ self.assert_underlying_credentials_refresh(
+ credentials=credentials,
+ audience=AUDIENCE,
+ subject_token=TEXT_FILE_SUBJECT_TOKEN,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ basic_auth_encoding=None,
+ quota_project_id=None,
+ used_scopes=SCOPES,
+ scopes=SCOPES,
+ default_scopes=None,
+ credential_data=TEXT_FILE_SUBJECT_TOKEN,
+ )
+
+ def test_refresh_json_file_success_without_impersonation_url(self):
+ credentials = self.make_credentials(
+ client_id=CLIENT_ID,
+ client_secret=CLIENT_SECRET,
+ # Test with JSON format type.
+ credential_source=self.CREDENTIAL_SOURCE_JSON_URL,
+ scopes=SCOPES,
+ )
+
+ self.assert_underlying_credentials_refresh(
+ credentials=credentials,
+ audience=AUDIENCE,
+ subject_token=JSON_FILE_SUBJECT_TOKEN,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=None,
+ basic_auth_encoding=BASIC_AUTH_ENCODING,
+ quota_project_id=None,
+ used_scopes=SCOPES,
+ scopes=SCOPES,
+ default_scopes=None,
+ credential_data=JSON_FILE_CONTENT,
+ )
+
+ def test_refresh_json_file_success_with_impersonation_url(self):
+ # Initialize credentials with service account impersonation and basic auth.
+ credentials = self.make_credentials(
+ # Test with JSON format type.
+ credential_source=self.CREDENTIAL_SOURCE_JSON_URL,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ scopes=SCOPES,
+ )
+
+ self.assert_underlying_credentials_refresh(
+ credentials=credentials,
+ audience=AUDIENCE,
+ subject_token=JSON_FILE_SUBJECT_TOKEN,
+ subject_token_type=SUBJECT_TOKEN_TYPE,
+ token_url=TOKEN_URL,
+ service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL,
+ basic_auth_encoding=None,
+ quota_project_id=None,
+ used_scopes=SCOPES,
+ scopes=SCOPES,
+ default_scopes=None,
+ credential_data=JSON_FILE_CONTENT,
+ )
+
+ def test_refresh_with_retrieve_subject_token_error_url(self):
+ credential_source = {
+ "url": self.CREDENTIAL_URL,
+ "format": {"type": "json", "subject_token_field_name": "not_found"},
+ }
+ credentials = self.make_credentials(credential_source=credential_source)
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ credentials.refresh(self.make_mock_request(token_data=JSON_FILE_CONTENT))
+
+ assert excinfo.match(
+ "Unable to parse subject_token from JSON file '{}' using key '{}'".format(
+ self.CREDENTIAL_URL, "not_found"
+ )
+ )
diff --git a/tests/test_impersonated_credentials.py b/tests/test_impersonated_credentials.py
index 305f939..430c770 100644
--- a/tests/test_impersonated_credentials.py
+++ b/tests/test_impersonated_credentials.py
@@ -104,12 +104,17 @@
SIGNER, SERVICE_ACCOUNT_EMAIL, TOKEN_URI
)
USER_SOURCE_CREDENTIALS = credentials.Credentials(token="ABCDE")
+ IAM_ENDPOINT_OVERRIDE = (
+ "https://us-east1-iamcredentials.googleapis.com/v1/projects/-"
+ + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL)
+ )
def make_credentials(
self,
source_credentials=SOURCE_CREDENTIALS,
lifetime=LIFETIME,
target_principal=TARGET_PRINCIPAL,
+ iam_endpoint_override=None,
):
return Credentials(
@@ -118,6 +123,7 @@
target_scopes=self.TARGET_SCOPES,
delegates=self.DELEGATES,
lifetime=lifetime,
+ iam_endpoint_override=iam_endpoint_override,
)
def test_make_from_user_credentials(self):
@@ -172,6 +178,34 @@
assert credentials.valid
assert not credentials.expired
+ @pytest.mark.parametrize("use_data_bytes", [True, False])
+ def test_refresh_success_iam_endpoint_override(
+ self, use_data_bytes, mock_donor_credentials
+ ):
+ credentials = self.make_credentials(
+ lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE
+ )
+ token = "token"
+
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ response_body = {"accessToken": token, "expireTime": expire_time}
+
+ request = self.make_request(
+ data=json.dumps(response_body),
+ status=http_client.OK,
+ use_data_bytes=use_data_bytes,
+ )
+
+ credentials.refresh(request)
+
+ assert credentials.valid
+ assert not credentials.expired
+ # Confirm override endpoint used.
+ request_kwargs = request.call_args.kwargs
+ assert request_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE
+
@pytest.mark.parametrize("time_skew", [100, -100])
def test_refresh_source_credentials(self, time_skew):
credentials = self.make_credentials(lifetime=None)
@@ -317,6 +351,36 @@
quota_project_creds = credentials.with_quota_project("project-foo")
assert quota_project_creds._quota_project_id == "project-foo"
+ @pytest.mark.parametrize("use_data_bytes", [True, False])
+ def test_with_quota_project_iam_endpoint_override(
+ self, use_data_bytes, mock_donor_credentials
+ ):
+ credentials = self.make_credentials(
+ lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE
+ )
+ token = "token"
+ # iam_endpoint_override should be copied to created credentials.
+ quota_project_creds = credentials.with_quota_project("project-foo")
+
+ expire_time = (
+ _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
+ ).isoformat("T") + "Z"
+ response_body = {"accessToken": token, "expireTime": expire_time}
+
+ request = self.make_request(
+ data=json.dumps(response_body),
+ status=http_client.OK,
+ use_data_bytes=use_data_bytes,
+ )
+
+ quota_project_creds.refresh(request)
+
+ assert quota_project_creds.valid
+ assert not quota_project_creds.expired
+ # Confirm override endpoint used.
+ request_kwargs = request.call_args.kwargs
+ assert request_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE
+
def test_id_token_success(
self, mock_donor_credentials, mock_authorizedsession_idtoken
):