feat: add asyncio based auth flow (#612)
* feat: asyncio http request logic and asynchronous credentials logic (#572)
Co-authored-by: Anirudh Baddepudi <43104821+anibadde@users.noreply.github.com>
diff --git a/tests_async/__init__.py b/tests_async/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests_async/__init__.py
diff --git a/tests_async/conftest.py b/tests_async/conftest.py
new file mode 100644
index 0000000..b4e90f0
--- /dev/null
+++ b/tests_async/conftest.py
@@ -0,0 +1,51 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+import mock
+import pytest
+
+
+def pytest_configure():
+ """Load public certificate and private key."""
+ pytest.data_dir = os.path.join(
+ os.path.abspath(os.path.join(__file__, "../..")), "tests/data"
+ )
+
+ with open(os.path.join(pytest.data_dir, "privatekey.pem"), "rb") as fh:
+ pytest.private_key_bytes = fh.read()
+
+ with open(os.path.join(pytest.data_dir, "public_cert.pem"), "rb") as fh:
+ pytest.public_cert_bytes = fh.read()
+
+
+@pytest.fixture
+def mock_non_existent_module(monkeypatch):
+ """Mocks a non-existing module in sys.modules.
+
+ Additionally mocks any non-existing modules specified in the dotted path.
+ """
+
+ def _mock_non_existent_module(path):
+ parts = path.split(".")
+ partial = []
+ for part in parts:
+ partial.append(part)
+ current_module = ".".join(partial)
+ if current_module not in sys.modules:
+ monkeypatch.setitem(sys.modules, current_module, mock.MagicMock())
+
+ return _mock_non_existent_module
diff --git a/tests_async/oauth2/test__client_async.py b/tests_async/oauth2/test__client_async.py
new file mode 100644
index 0000000..458937a
--- /dev/null
+++ b/tests_async/oauth2/test__client_async.py
@@ -0,0 +1,297 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+
+import mock
+import pytest
+import six
+from six.moves import http_client
+from six.moves import urllib
+
+from google.auth import _helpers
+from google.auth import _jwt_async as jwt
+from google.auth import exceptions
+from google.oauth2 import _client as sync_client
+from google.oauth2 import _client_async as _client
+from tests.oauth2 import test__client as test_client
+
+
+def test__handle_error_response():
+ response_data = json.dumps({"error": "help", "error_description": "I'm alive"})
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ _client._handle_error_response(response_data)
+
+ assert excinfo.match(r"help: I\'m alive")
+
+
+def test__handle_error_response_non_json():
+ response_data = "Help, I'm alive"
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ _client._handle_error_response(response_data)
+
+ assert excinfo.match(r"Help, I\'m alive")
+
+
+@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+def test__parse_expiry(unused_utcnow):
+ result = _client._parse_expiry({"expires_in": 500})
+ assert result == datetime.datetime.min + datetime.timedelta(seconds=500)
+
+
+def test__parse_expiry_none():
+ assert _client._parse_expiry({}) is None
+
+
+def make_request(response_data, status=http_client.OK):
+ response = mock.AsyncMock(spec=["transport.Response"])
+ response.status = status
+ data = json.dumps(response_data).encode("utf-8")
+ response.data = mock.AsyncMock(spec=["__call__", "read"])
+ response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data)
+ response.content = mock.AsyncMock(spec=["__call__"], return_value=data)
+ request = mock.AsyncMock(spec=["transport.Request"])
+ request.return_value = response
+ return request
+
+
+@pytest.mark.asyncio
+async def test__token_endpoint_request():
+
+ request = make_request({"test": "response"})
+
+ result = await _client._token_endpoint_request(
+ request, "http://example.com", {"test": "params"}
+ )
+
+ # Check request call
+ request.assert_called_with(
+ method="POST",
+ url="http://example.com",
+ headers={"content-type": "application/x-www-form-urlencoded"},
+ body="test=params".encode("utf-8"),
+ )
+
+ # Check result
+ assert result == {"test": "response"}
+
+
+@pytest.mark.asyncio
+async def test__token_endpoint_request_error():
+ request = make_request({}, status=http_client.BAD_REQUEST)
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client._token_endpoint_request(request, "http://example.com", {})
+
+
+@pytest.mark.asyncio
+async def test__token_endpoint_request_internal_failure_error():
+ request = make_request(
+ {"error_description": "internal_failure"}, status=http_client.BAD_REQUEST
+ )
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client._token_endpoint_request(
+ request, "http://example.com", {"error_description": "internal_failure"}
+ )
+
+ request = make_request(
+ {"error": "internal_failure"}, status=http_client.BAD_REQUEST
+ )
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client._token_endpoint_request(
+ request, "http://example.com", {"error": "internal_failure"}
+ )
+
+
+def verify_request_params(request, params):
+ request_body = request.call_args[1]["body"].decode("utf-8")
+ request_params = urllib.parse.parse_qs(request_body)
+
+ for key, value in six.iteritems(params):
+ assert request_params[key][0] == value
+
+
+@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+@pytest.mark.asyncio
+async def test_jwt_grant(utcnow):
+ request = make_request(
+ {"access_token": "token", "expires_in": 500, "extra": "data"}
+ )
+
+ token, expiry, extra_data = await _client.jwt_grant(
+ request, "http://example.com", "assertion_value"
+ )
+
+ # Check request call
+ verify_request_params(
+ request,
+ {"grant_type": sync_client._JWT_GRANT_TYPE, "assertion": "assertion_value"},
+ )
+
+ # Check result
+ assert token == "token"
+ assert expiry == utcnow() + datetime.timedelta(seconds=500)
+ assert extra_data["extra"] == "data"
+
+
+@pytest.mark.asyncio
+async def test_jwt_grant_no_access_token():
+ request = make_request(
+ {
+ # No access token.
+ "expires_in": 500,
+ "extra": "data",
+ }
+ )
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client.jwt_grant(request, "http://example.com", "assertion_value")
+
+
+@pytest.mark.asyncio
+async def test_id_token_jwt_grant():
+ now = _helpers.utcnow()
+ id_token_expiry = _helpers.datetime_to_secs(now)
+ id_token = jwt.encode(test_client.SIGNER, {"exp": id_token_expiry}).decode("utf-8")
+ request = make_request({"id_token": id_token, "extra": "data"})
+
+ token, expiry, extra_data = await _client.id_token_jwt_grant(
+ request, "http://example.com", "assertion_value"
+ )
+
+ # Check request call
+ verify_request_params(
+ request,
+ {"grant_type": sync_client._JWT_GRANT_TYPE, "assertion": "assertion_value"},
+ )
+
+ # Check result
+ assert token == id_token
+ # JWT does not store microseconds
+ now = now.replace(microsecond=0)
+ assert expiry == now
+ assert extra_data["extra"] == "data"
+
+
+@pytest.mark.asyncio
+async def test_id_token_jwt_grant_no_access_token():
+ request = make_request(
+ {
+ # No ID token.
+ "expires_in": 500,
+ "extra": "data",
+ }
+ )
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client.id_token_jwt_grant(
+ request, "http://example.com", "assertion_value"
+ )
+
+
+@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+@pytest.mark.asyncio
+async def test_refresh_grant(unused_utcnow):
+ request = make_request(
+ {
+ "access_token": "token",
+ "refresh_token": "new_refresh_token",
+ "expires_in": 500,
+ "extra": "data",
+ }
+ )
+
+ token, refresh_token, expiry, extra_data = await _client.refresh_grant(
+ request, "http://example.com", "refresh_token", "client_id", "client_secret"
+ )
+
+ # Check request call
+ verify_request_params(
+ request,
+ {
+ "grant_type": sync_client._REFRESH_GRANT_TYPE,
+ "refresh_token": "refresh_token",
+ "client_id": "client_id",
+ "client_secret": "client_secret",
+ },
+ )
+
+ # Check result
+ assert token == "token"
+ assert refresh_token == "new_refresh_token"
+ assert expiry == datetime.datetime.min + datetime.timedelta(seconds=500)
+ assert extra_data["extra"] == "data"
+
+
+@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+@pytest.mark.asyncio
+async def test_refresh_grant_with_scopes(unused_utcnow):
+ request = make_request(
+ {
+ "access_token": "token",
+ "refresh_token": "new_refresh_token",
+ "expires_in": 500,
+ "extra": "data",
+ "scope": test_client.SCOPES_AS_STRING,
+ }
+ )
+
+ token, refresh_token, expiry, extra_data = await _client.refresh_grant(
+ request,
+ "http://example.com",
+ "refresh_token",
+ "client_id",
+ "client_secret",
+ test_client.SCOPES_AS_LIST,
+ )
+
+ # Check request call.
+ verify_request_params(
+ request,
+ {
+ "grant_type": sync_client._REFRESH_GRANT_TYPE,
+ "refresh_token": "refresh_token",
+ "client_id": "client_id",
+ "client_secret": "client_secret",
+ "scope": test_client.SCOPES_AS_STRING,
+ },
+ )
+
+ # Check result.
+ assert token == "token"
+ assert refresh_token == "new_refresh_token"
+ assert expiry == datetime.datetime.min + datetime.timedelta(seconds=500)
+ assert extra_data["extra"] == "data"
+
+
+@pytest.mark.asyncio
+async def test_refresh_grant_no_access_token():
+ request = make_request(
+ {
+ # No access token.
+ "refresh_token": "new_refresh_token",
+ "expires_in": 500,
+ "extra": "data",
+ }
+ )
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client.refresh_grant(
+ request, "http://example.com", "refresh_token", "client_id", "client_secret"
+ )
diff --git a/tests_async/oauth2/test_credentials_async.py b/tests_async/oauth2/test_credentials_async.py
new file mode 100644
index 0000000..5c883d6
--- /dev/null
+++ b/tests_async/oauth2/test_credentials_async.py
@@ -0,0 +1,478 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+import os
+import pickle
+import sys
+
+import mock
+import pytest
+
+from google.auth import _helpers
+from google.auth import exceptions
+from google.oauth2 import _credentials_async as _credentials_async
+from google.oauth2 import credentials
+from tests.oauth2 import test_credentials
+
+
+class TestCredentials:
+
+ TOKEN_URI = "https://example.com/oauth2/token"
+ REFRESH_TOKEN = "refresh_token"
+ CLIENT_ID = "client_id"
+ CLIENT_SECRET = "client_secret"
+
+ @classmethod
+ def make_credentials(cls):
+ return _credentials_async.Credentials(
+ token=None,
+ refresh_token=cls.REFRESH_TOKEN,
+ token_uri=cls.TOKEN_URI,
+ client_id=cls.CLIENT_ID,
+ client_secret=cls.CLIENT_SECRET,
+ )
+
+ def test_default_state(self):
+ credentials = self.make_credentials()
+ assert not credentials.valid
+ # Expiration hasn't been set yet
+ assert not credentials.expired
+ # Scopes aren't required for these credentials
+ assert not credentials.requires_scopes
+ # Test properties
+ assert credentials.refresh_token == self.REFRESH_TOKEN
+ assert credentials.token_uri == self.TOKEN_URI
+ assert credentials.client_id == self.CLIENT_ID
+ assert credentials.client_secret == self.CLIENT_SECRET
+
+ @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+ @mock.patch(
+ "google.auth._helpers.utcnow",
+ return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
+ )
+ @pytest.mark.asyncio
+ async def test_refresh_success(self, unused_utcnow, refresh_grant):
+ token = "token"
+ expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+ grant_response = {"id_token": mock.sentinel.id_token}
+ refresh_grant.return_value = (
+ # Access token
+ token,
+ # New refresh token
+ None,
+ # Expiry,
+ expiry,
+ # Extra data
+ grant_response,
+ )
+
+ request = mock.AsyncMock(spec=["transport.Request"])
+ creds = self.make_credentials()
+
+ # Refresh credentials
+ await creds.refresh(request)
+
+ # Check refresh grant call.
+ refresh_grant.assert_called_with(
+ request,
+ self.TOKEN_URI,
+ self.REFRESH_TOKEN,
+ self.CLIENT_ID,
+ self.CLIENT_SECRET,
+ None,
+ )
+
+ # Check that the credentials have the token and expiry
+ assert creds.token == token
+ assert creds.expiry == expiry
+ assert creds.id_token == mock.sentinel.id_token
+
+ # Check that the credentials are valid (have a token and are not
+ # expired)
+ assert creds.valid
+
+ @pytest.mark.asyncio
+ async def test_refresh_no_refresh_token(self):
+ request = mock.AsyncMock(spec=["transport.Request"])
+ credentials_ = _credentials_async.Credentials(token=None, refresh_token=None)
+
+ with pytest.raises(exceptions.RefreshError, match="necessary fields"):
+ await credentials_.refresh(request)
+
+ request.assert_not_called()
+
+ @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+ @mock.patch(
+ "google.auth._helpers.utcnow",
+ return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
+ )
+ @pytest.mark.asyncio
+ async def test_credentials_with_scopes_requested_refresh_success(
+ self, unused_utcnow, refresh_grant
+ ):
+ scopes = ["email", "profile"]
+ token = "token"
+ expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+ grant_response = {"id_token": mock.sentinel.id_token}
+ refresh_grant.return_value = (
+ # Access token
+ token,
+ # New refresh token
+ None,
+ # Expiry,
+ expiry,
+ # Extra data
+ grant_response,
+ )
+
+ request = mock.AsyncMock(spec=["transport.Request"])
+ creds = _credentials_async.Credentials(
+ token=None,
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ scopes=scopes,
+ )
+
+ # Refresh credentials
+ await creds.refresh(request)
+
+ # Check refresh grant call.
+ refresh_grant.assert_called_with(
+ request,
+ self.TOKEN_URI,
+ self.REFRESH_TOKEN,
+ self.CLIENT_ID,
+ self.CLIENT_SECRET,
+ scopes,
+ )
+
+ # Check that the credentials have the token and expiry
+ assert creds.token == token
+ assert creds.expiry == expiry
+ assert creds.id_token == mock.sentinel.id_token
+ assert creds.has_scopes(scopes)
+
+ # Check that the credentials are valid (have a token and are not
+ # expired.)
+ assert creds.valid
+
+ @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+ @mock.patch(
+ "google.auth._helpers.utcnow",
+ return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
+ )
+ @pytest.mark.asyncio
+ async def test_credentials_with_scopes_returned_refresh_success(
+ self, unused_utcnow, refresh_grant
+ ):
+ scopes = ["email", "profile"]
+ token = "token"
+ expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+ grant_response = {
+ "id_token": mock.sentinel.id_token,
+ "scopes": " ".join(scopes),
+ }
+ refresh_grant.return_value = (
+ # Access token
+ token,
+ # New refresh token
+ None,
+ # Expiry,
+ expiry,
+ # Extra data
+ grant_response,
+ )
+
+ request = mock.AsyncMock(spec=["transport.Request"])
+ creds = _credentials_async.Credentials(
+ token=None,
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ scopes=scopes,
+ )
+
+ # Refresh credentials
+ await creds.refresh(request)
+
+ # Check refresh grant call.
+ refresh_grant.assert_called_with(
+ request,
+ self.TOKEN_URI,
+ self.REFRESH_TOKEN,
+ self.CLIENT_ID,
+ self.CLIENT_SECRET,
+ scopes,
+ )
+
+ # Check that the credentials have the token and expiry
+ assert creds.token == token
+ assert creds.expiry == expiry
+ assert creds.id_token == mock.sentinel.id_token
+ assert creds.has_scopes(scopes)
+
+ # Check that the credentials are valid (have a token and are not
+ # expired.)
+ assert creds.valid
+
+ @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+ @mock.patch(
+ "google.auth._helpers.utcnow",
+ return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
+ )
+ @pytest.mark.asyncio
+ async def test_credentials_with_scopes_refresh_failure_raises_refresh_error(
+ self, unused_utcnow, refresh_grant
+ ):
+ scopes = ["email", "profile"]
+ scopes_returned = ["email"]
+ token = "token"
+ expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+ grant_response = {
+ "id_token": mock.sentinel.id_token,
+ "scopes": " ".join(scopes_returned),
+ }
+ refresh_grant.return_value = (
+ # Access token
+ token,
+ # New refresh token
+ None,
+ # Expiry,
+ expiry,
+ # Extra data
+ grant_response,
+ )
+
+ request = mock.AsyncMock(spec=["transport.Request"])
+ creds = _credentials_async.Credentials(
+ token=None,
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ scopes=scopes,
+ )
+
+ # Refresh credentials
+ with pytest.raises(
+ exceptions.RefreshError, match="Not all requested scopes were granted"
+ ):
+ await creds.refresh(request)
+
+ # Check refresh grant call.
+ refresh_grant.assert_called_with(
+ request,
+ self.TOKEN_URI,
+ self.REFRESH_TOKEN,
+ self.CLIENT_ID,
+ self.CLIENT_SECRET,
+ scopes,
+ )
+
+ # Check that the credentials have the token and expiry
+ assert creds.token == token
+ assert creds.expiry == expiry
+ assert creds.id_token == mock.sentinel.id_token
+ assert creds.has_scopes(scopes)
+
+ # Check that the credentials are valid (have a token and are not
+ # expired.)
+ assert creds.valid
+
+ def test_apply_with_quota_project_id(self):
+ creds = _credentials_async.Credentials(
+ token="token",
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ quota_project_id="quota-project-123",
+ )
+
+ headers = {}
+ creds.apply(headers)
+ assert headers["x-goog-user-project"] == "quota-project-123"
+
+ def test_apply_with_no_quota_project_id(self):
+ creds = _credentials_async.Credentials(
+ token="token",
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ )
+
+ headers = {}
+ creds.apply(headers)
+ assert "x-goog-user-project" not in headers
+
+ def test_with_quota_project(self):
+ creds = _credentials_async.Credentials(
+ token="token",
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ quota_project_id="quota-project-123",
+ )
+ # Verify the copy returned by with_quota_project, not the original credentials.
+ new_creds = creds.with_quota_project("new-project-456")
+ assert new_creds.quota_project_id == "new-project-456"
+ headers = {}
+ new_creds.apply(headers)
+ assert headers["x-goog-user-project"] == "new-project-456"
+
+ def test_from_authorized_user_info(self):
+ info = test_credentials.AUTH_USER_INFO.copy()
+
+ creds = _credentials_async.Credentials.from_authorized_user_info(info)
+ assert creds.client_secret == info["client_secret"]
+ assert creds.client_id == info["client_id"]
+ assert creds.refresh_token == info["refresh_token"]
+ assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
+ assert creds.scopes is None
+
+ scopes = ["email", "profile"]
+ creds = _credentials_async.Credentials.from_authorized_user_info(info, scopes)
+ assert creds.client_secret == info["client_secret"]
+ assert creds.client_id == info["client_id"]
+ assert creds.refresh_token == info["refresh_token"]
+ assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
+ assert creds.scopes == scopes
+
+ def test_from_authorized_user_file(self):
+ info = test_credentials.AUTH_USER_INFO.copy()
+
+ creds = _credentials_async.Credentials.from_authorized_user_file(
+ test_credentials.AUTH_USER_JSON_FILE
+ )
+ assert creds.client_secret == info["client_secret"]
+ assert creds.client_id == info["client_id"]
+ assert creds.refresh_token == info["refresh_token"]
+ assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
+ assert creds.scopes is None
+
+ scopes = ["email", "profile"]
+ creds = _credentials_async.Credentials.from_authorized_user_file(
+ test_credentials.AUTH_USER_JSON_FILE, scopes
+ )
+ assert creds.client_secret == info["client_secret"]
+ assert creds.client_id == info["client_id"]
+ assert creds.refresh_token == info["refresh_token"]
+ assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
+ assert creds.scopes == scopes
+
+ def test_to_json(self):
+ info = test_credentials.AUTH_USER_INFO.copy()
+ creds = _credentials_async.Credentials.from_authorized_user_info(info)
+
+ # Test with no `strip` arg
+ json_output = creds.to_json()
+ json_asdict = json.loads(json_output)
+ assert json_asdict.get("token") == creds.token
+ assert json_asdict.get("refresh_token") == creds.refresh_token
+ assert json_asdict.get("token_uri") == creds.token_uri
+ assert json_asdict.get("client_id") == creds.client_id
+ assert json_asdict.get("scopes") == creds.scopes
+ assert json_asdict.get("client_secret") == creds.client_secret
+
+ # Test with a `strip` arg
+ json_output = creds.to_json(strip=["client_secret"])
+ json_asdict = json.loads(json_output)
+ assert json_asdict.get("token") == creds.token
+ assert json_asdict.get("refresh_token") == creds.refresh_token
+ assert json_asdict.get("token_uri") == creds.token_uri
+ assert json_asdict.get("client_id") == creds.client_id
+ assert json_asdict.get("scopes") == creds.scopes
+ assert json_asdict.get("client_secret") is None
+
+ def test_pickle_and_unpickle(self):
+ creds = self.make_credentials()
+ unpickled = pickle.loads(pickle.dumps(creds))
+
+ # make sure attributes aren't lost during pickling (sorted() returns the names; list.sort() returns None)
+ assert sorted(creds.__dict__) == sorted(unpickled.__dict__)
+ # every attribute must also keep its value across the round trip
+ for attr in list(creds.__dict__):
+ assert getattr(creds, attr) == getattr(unpickled, attr)
+
+ def test_pickle_with_missing_attribute(self):
+ creds = self.make_credentials()
+
+ # remove an optional attribute before pickling
+ # this mimics a pickle created with a previous class definition with
+ # fewer attributes
+ del creds.__dict__["_quota_project_id"]
+
+ unpickled = pickle.loads(pickle.dumps(creds))
+
+ # Attribute should be initialized by `__setstate__`
+ assert unpickled.quota_project_id is None
+
+ # pickles are not compatible across versions
+ @pytest.mark.skipif(
+ sys.version_info < (3, 5),
+ reason="pickle file can only be loaded with Python >= 3.5",
+ )
+ def test_unpickle_old_credentials_pickle(self):
+ # make sure a credentials file pickled with an older
+ # library version (google-auth==1.5.1) can be unpickled
+ with open(
+ os.path.join(test_credentials.DATA_DIR, "old_oauth_credentials_py3.pickle"),
+ "rb",
+ ) as f:
+ credentials = pickle.load(f)
+ assert credentials.quota_project_id is None
+
+
+class TestUserAccessTokenCredentials(object):
+ def test_instance(self):
+ cred = _credentials_async.UserAccessTokenCredentials()
+ assert cred._account is None
+
+ cred = cred.with_account("account")
+ assert cred._account == "account"
+
+ @mock.patch("google.auth._cloud_sdk.get_auth_access_token", autospec=True)
+ def test_refresh(self, get_auth_access_token):
+ get_auth_access_token.return_value = "access_token"
+ cred = _credentials_async.UserAccessTokenCredentials()
+ cred.refresh(None)
+ assert cred.token == "access_token"
+
+ def test_with_quota_project(self):
+ cred = _credentials_async.UserAccessTokenCredentials()
+ quota_project_cred = cred.with_quota_project("project-foo")
+
+ assert quota_project_cred._quota_project_id == "project-foo"
+ assert quota_project_cred._account == cred._account
+
+ @mock.patch(
+ "google.oauth2._credentials_async.UserAccessTokenCredentials.apply",
+ autospec=True,
+ )
+ @mock.patch(
+ "google.oauth2._credentials_async.UserAccessTokenCredentials.refresh",
+ autospec=True,
+ )
+ def test_before_request(self, refresh, apply):
+ cred = _credentials_async.UserAccessTokenCredentials()
+ cred.before_request(mock.Mock(), "GET", "https://example.com", {})
+ refresh.assert_called()
+ apply.assert_called()
diff --git a/tests_async/oauth2/test_id_token.py b/tests_async/oauth2/test_id_token.py
new file mode 100644
index 0000000..a46bd61
--- /dev/null
+++ b/tests_async/oauth2/test_id_token.py
@@ -0,0 +1,205 @@
+# Copyright 2020 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import mock
+import pytest
+
+from google.auth import environment_vars
+from google.auth import exceptions
+import google.auth.compute_engine._metadata
+from google.oauth2 import _id_token_async as id_token
+from google.oauth2 import id_token as sync_id_token
+from tests.oauth2 import test_id_token
+
+
+def make_request(status, data=None):
+ response = mock.AsyncMock(spec=["transport.Response"])
+ response.status = status
+
+ if data is not None:
+ response.data = mock.AsyncMock(spec=["__call__", "read"])
+ response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data)
+
+ request = mock.AsyncMock(spec=["transport.Request"])
+ request.return_value = response
+ return request
+
+
+@pytest.mark.asyncio
+async def test__fetch_certs_success():
+ certs = {"1": "cert"}
+ request = make_request(200, certs)
+
+ returned_certs = await id_token._fetch_certs(request, mock.sentinel.cert_url)
+
+ request.assert_called_once_with(mock.sentinel.cert_url, method="GET")
+ assert returned_certs == certs
+
+
+@pytest.mark.asyncio
+async def test__fetch_certs_failure():
+ request = make_request(404)
+
+ with pytest.raises(exceptions.TransportError):
+ await id_token._fetch_certs(request, mock.sentinel.cert_url)
+
+ request.assert_called_once_with(mock.sentinel.cert_url, method="GET")
+
+
+@mock.patch("google.auth.jwt.decode", autospec=True)
+@mock.patch("google.oauth2._id_token_async._fetch_certs", autospec=True)
+@pytest.mark.asyncio
+async def test_verify_token(_fetch_certs, decode):
+ result = await id_token.verify_token(mock.sentinel.token, mock.sentinel.request)
+
+ assert result == decode.return_value
+ _fetch_certs.assert_called_once_with(
+ mock.sentinel.request, sync_id_token._GOOGLE_OAUTH2_CERTS_URL
+ )
+ decode.assert_called_once_with(
+ mock.sentinel.token, certs=_fetch_certs.return_value, audience=None
+ )
+
+
+@mock.patch("google.auth.jwt.decode", autospec=True)
+@mock.patch("google.oauth2._id_token_async._fetch_certs", autospec=True)
+@pytest.mark.asyncio
+async def test_verify_token_args(_fetch_certs, decode):
+ result = await id_token.verify_token(
+ mock.sentinel.token,
+ mock.sentinel.request,
+ audience=mock.sentinel.audience,
+ certs_url=mock.sentinel.certs_url,
+ )
+
+ assert result == decode.return_value
+ _fetch_certs.assert_called_once_with(mock.sentinel.request, mock.sentinel.certs_url)
+ decode.assert_called_once_with(
+ mock.sentinel.token,
+ certs=_fetch_certs.return_value,
+ audience=mock.sentinel.audience,
+ )
+
+
+@mock.patch("google.oauth2._id_token_async.verify_token", autospec=True)
+@pytest.mark.asyncio
+async def test_verify_oauth2_token(verify_token):
+ verify_token.return_value = {"iss": "accounts.google.com"}
+ result = await id_token.verify_oauth2_token(
+ mock.sentinel.token, mock.sentinel.request, audience=mock.sentinel.audience
+ )
+
+ assert result == verify_token.return_value
+ verify_token.assert_called_once_with(
+ mock.sentinel.token,
+ mock.sentinel.request,
+ audience=mock.sentinel.audience,
+ certs_url=sync_id_token._GOOGLE_OAUTH2_CERTS_URL,
+ )
+
+
+@mock.patch("google.oauth2._id_token_async.verify_token", autospec=True)
+@pytest.mark.asyncio
+async def test_verify_oauth2_token_invalid_iss(verify_token):
+ verify_token.return_value = {"iss": "invalid_issuer"}
+
+ with pytest.raises(exceptions.GoogleAuthError):
+ await id_token.verify_oauth2_token(
+ mock.sentinel.token, mock.sentinel.request, audience=mock.sentinel.audience
+ )
+
+
+@mock.patch("google.oauth2._id_token_async.verify_token", autospec=True)
+@pytest.mark.asyncio
+async def test_verify_firebase_token(verify_token):
+ result = await id_token.verify_firebase_token(
+ mock.sentinel.token, mock.sentinel.request, audience=mock.sentinel.audience
+ )
+
+ assert result == verify_token.return_value
+ verify_token.assert_called_once_with(
+ mock.sentinel.token,
+ mock.sentinel.request,
+ audience=mock.sentinel.audience,
+ certs_url=sync_id_token._GOOGLE_APIS_CERTS_URL,
+ )
+
+
+@pytest.mark.asyncio
+async def test_fetch_id_token_from_metadata_server():
+ def mock_init(self, request, audience, use_metadata_identity_endpoint):
+ assert use_metadata_identity_endpoint
+ self.token = "id_token"
+
+ with mock.patch.multiple(
+ google.auth.compute_engine.IDTokenCredentials,
+ __init__=mock_init,
+ refresh=mock.Mock(),
+ ):
+ request = mock.AsyncMock()
+ token = await id_token.fetch_id_token(request, "https://pubsub.googleapis.com")
+ assert token == "id_token"
+
+
+@mock.patch.object(
+ google.auth.compute_engine.IDTokenCredentials,
+ "__init__",
+ side_effect=exceptions.TransportError(),
+)
+@pytest.mark.asyncio
+async def test_fetch_id_token_from_explicit_cred_json_file(mock_init, monkeypatch):
+ monkeypatch.setenv(environment_vars.CREDENTIALS, test_id_token.SERVICE_ACCOUNT_FILE)
+
+ async def mock_refresh(self, request):
+ self.token = "id_token"
+
+ with mock.patch.object(
+ google.oauth2._service_account_async.IDTokenCredentials, "refresh", mock_refresh
+ ):
+ request = mock.AsyncMock()
+ token = await id_token.fetch_id_token(request, "https://pubsub.googleapis.com")
+ assert token == "id_token"
+
+
+@mock.patch.object(
+ google.auth.compute_engine.IDTokenCredentials,
+ "__init__",
+ side_effect=exceptions.TransportError(),
+)
+@pytest.mark.asyncio
+async def test_fetch_id_token_no_cred_json_file(mock_init, monkeypatch):
+ monkeypatch.delenv(environment_vars.CREDENTIALS, raising=False)
+
+ with pytest.raises(exceptions.DefaultCredentialsError):
+ request = mock.AsyncMock()
+ await id_token.fetch_id_token(request, "https://pubsub.googleapis.com")
+
+
+@mock.patch.object(
+ google.auth.compute_engine.IDTokenCredentials,
+ "__init__",
+ side_effect=exceptions.TransportError(),
+)
+@pytest.mark.asyncio
+async def test_fetch_id_token_invalid_cred_file(mock_init, monkeypatch):
+ not_json_file = os.path.join(
+ os.path.dirname(__file__), "../../tests/data/public_cert.pem"
+ )
+ monkeypatch.setenv(environment_vars.CREDENTIALS, not_json_file)
+
+ with pytest.raises(exceptions.DefaultCredentialsError):
+ request = mock.AsyncMock()
+ await id_token.fetch_id_token(request, "https://pubsub.googleapis.com")
diff --git a/tests_async/oauth2/test_service_account_async.py b/tests_async/oauth2/test_service_account_async.py
new file mode 100644
index 0000000..4079453
--- /dev/null
+++ b/tests_async/oauth2/test_service_account_async.py
@@ -0,0 +1,372 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+
+import mock
+import pytest
+
+from google.auth import _helpers
+from google.auth import crypt
+from google.auth import jwt
+from google.auth import transport
+from google.oauth2 import _service_account_async as service_account
+from tests.oauth2 import test_service_account
+
+
+class TestCredentials(object):
+    """Tests for the async ``service_account.Credentials`` class."""
+
+    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
+    TOKEN_URI = "https://example.com/oauth2/token"
+
+    @classmethod
+    def make_credentials(cls):
+        """Build credentials from the shared test signer, email and token URI."""
+        signer = test_service_account.SIGNER
+        return service_account.Credentials(
+            signer, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI
+        )
+
+    def test_from_service_account_info(self):
+        """Key id, email and token URI come from the info mapping."""
+        info = test_service_account.SERVICE_ACCOUNT_INFO
+        creds = service_account.Credentials.from_service_account_info(info)
+
+        assert creds._signer.key_id == info["private_key_id"]
+        assert creds.service_account_email == info["client_email"]
+        assert creds._token_uri == info["token_uri"]
+
+    def test_from_service_account_info_args(self):
+        """Optional scopes, subject and additional claims are stored as given."""
+        info = test_service_account.SERVICE_ACCOUNT_INFO.copy()
+        wanted_scopes = ["email", "profile"]
+        wanted_subject = "subject"
+        wanted_claims = {"meta": "data"}
+
+        creds = service_account.Credentials.from_service_account_info(
+            info,
+            scopes=wanted_scopes,
+            subject=wanted_subject,
+            additional_claims=wanted_claims,
+        )
+
+        assert creds.service_account_email == info["client_email"]
+        assert creds.project_id == info["project_id"]
+        assert creds._signer.key_id == info["private_key_id"]
+        assert creds._token_uri == info["token_uri"]
+        assert creds._scopes == wanted_scopes
+        assert creds._subject == wanted_subject
+        assert creds._additional_claims == wanted_claims
+
+    def test_from_service_account_file(self):
+        """Loading from the JSON file yields the same fields as the info mapping."""
+        expected = test_service_account.SERVICE_ACCOUNT_INFO.copy()
+        json_path = test_service_account.SERVICE_ACCOUNT_JSON_FILE
+
+        creds = service_account.Credentials.from_service_account_file(json_path)
+
+        assert creds.service_account_email == expected["client_email"]
+        assert creds.project_id == expected["project_id"]
+        assert creds._signer.key_id == expected["private_key_id"]
+        assert creds._token_uri == expected["token_uri"]
+
+    def test_from_service_account_file_args(self):
+        """Optional keyword arguments are honoured when loading from a file."""
+        expected = test_service_account.SERVICE_ACCOUNT_INFO.copy()
+        wanted_scopes = ["email", "profile"]
+        wanted_subject = "subject"
+        wanted_claims = {"meta": "data"}
+
+        creds = service_account.Credentials.from_service_account_file(
+            test_service_account.SERVICE_ACCOUNT_JSON_FILE,
+            subject=wanted_subject,
+            scopes=wanted_scopes,
+            additional_claims=wanted_claims,
+        )
+
+        assert creds.service_account_email == expected["client_email"]
+        assert creds.project_id == expected["project_id"]
+        assert creds._signer.key_id == expected["private_key_id"]
+        assert creds._token_uri == expected["token_uri"]
+        assert creds._scopes == wanted_scopes
+        assert creds._subject == wanted_subject
+        assert creds._additional_claims == wanted_claims
+
+    def test_default_state(self):
+        """Fresh credentials are neither valid nor expired and still need scopes."""
+        creds = self.make_credentials()
+        assert not creds.valid
+        # No expiry has been set yet, so they cannot be expired.
+        assert not creds.expired
+        # No scopes were passed to make_credentials.
+        assert creds.requires_scopes
+
+    def test_sign_bytes(self):
+        """A signature from sign_bytes verifies against the test public cert."""
+        creds = self.make_credentials()
+        message = b"123"
+        signed = creds.sign_bytes(message)
+        cert = test_service_account.PUBLIC_CERT_BYTES
+        assert crypt.verify_signature(message, signed, cert)
+
+    def test_signer(self):
+        """The signer property exposes a crypt.Signer."""
+        creds = self.make_credentials()
+        assert isinstance(creds.signer, crypt.Signer)
+
+    def test_signer_email(self):
+        """signer_email equals the service account email."""
+        creds = self.make_credentials()
+        assert creds.signer_email == self.SERVICE_ACCOUNT_EMAIL
+
+    def test_create_scoped(self):
+        """with_scopes returns credentials carrying the requested scopes."""
+        wanted_scopes = ["email", "profile"]
+        scoped = self.make_credentials().with_scopes(wanted_scopes)
+        assert scoped._scopes == wanted_scopes
+
+    def test_with_claims(self):
+        """with_claims stores the additional claims on the new credentials."""
+        extra = {"meep": "moop"}
+        derived = self.make_credentials().with_claims(extra)
+        assert derived._additional_claims == {"meep": "moop"}
+
+    def test_with_quota_project(self):
+        """with_quota_project sets the id and apply() emits the quota header."""
+        derived = self.make_credentials().with_quota_project("new-project-456")
+        assert derived.quota_project_id == "new-project-456"
+        headers = {}
+        derived.apply(headers, token="tok")
+        assert "x-goog-user-project" in headers
+
+    def test__make_authorization_grant_assertion(self):
+        """The grant assertion's iss/aud are the account email and token URI."""
+        creds = self.make_credentials()
+        assertion = creds._make_authorization_grant_assertion()
+        claims = jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES)
+        assert claims["iss"] == self.SERVICE_ACCOUNT_EMAIL
+        assert claims["aud"] == self.TOKEN_URI
+
+    def test__make_authorization_grant_assertion_scoped(self):
+        """Scopes are serialised space-separated in the ``scope`` claim."""
+        scoped = self.make_credentials().with_scopes(["email", "profile"])
+        assertion = scoped._make_authorization_grant_assertion()
+        claims = jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES)
+        assert claims["scope"] == "email profile"
+
+    def test__make_authorization_grant_assertion_subject(self):
+        """The subject set via with_subject appears as the ``sub`` claim."""
+        wanted_subject = "user@example.com"
+        delegated = self.make_credentials().with_subject(wanted_subject)
+        assertion = delegated._make_authorization_grant_assertion()
+        claims = jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES)
+        assert claims["sub"] == wanted_subject
+
+    @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True)
+    @pytest.mark.asyncio
+    async def test_refresh_success(self, jwt_grant):
+        """refresh() calls jwt_grant and stores the returned token."""
+        creds = self.make_credentials()
+        expected_token = "token"
+        expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+        jwt_grant.return_value = (expected_token, expiry, {})
+        request = mock.create_autospec(transport.Request, instance=True)
+
+        await creds.refresh(request)
+
+        # The grant endpoint must have been hit with (request, uri, assertion).
+        assert jwt_grant.called
+        positional = jwt_grant.call_args[0]
+        called_request, token_uri, assertion = positional
+        assert called_request == request
+        assert token_uri == creds._token_uri
+        # Only decodability is checked here; the assertion's contents are
+        # covered by the _make_authorization_grant_assertion tests.
+        assert jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES)
+
+        # Token stored, and credentials valid (token present, not expired).
+        assert creds.token == expected_token
+        assert creds.valid
+
+    @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True)
+    @pytest.mark.asyncio
+    async def test_before_request_refreshes(self, jwt_grant):
+        """before_request on invalid credentials triggers a refresh."""
+        creds = self.make_credentials()
+        expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+        jwt_grant.return_value = ("token", expiry, None)
+        request = mock.create_autospec(transport.Request, instance=True)
+
+        # Nothing has been fetched yet.
+        assert not creds.valid
+
+        await creds.before_request(request, "GET", "http://example.com?a=1#3", {})
+
+        # The refresh endpoint was called and the credentials became valid.
+        assert jwt_grant.called
+        assert creds.valid
+
+
+class TestIDTokenCredentials(object):
+    """Tests for the async ``service_account.IDTokenCredentials`` class."""
+
+    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
+    TOKEN_URI = "https://example.com/oauth2/token"
+    TARGET_AUDIENCE = "https://example.com"
+
+    @classmethod
+    def make_credentials(cls):
+        """Build ID-token credentials from the shared test signer and class constants."""
+        return service_account.IDTokenCredentials(
+            test_service_account.SIGNER,
+            cls.SERVICE_ACCOUNT_EMAIL,
+            cls.TOKEN_URI,
+            cls.TARGET_AUDIENCE,
+        )
+
+    def test_from_service_account_info(self):
+        """Fields come from the info mapping; the audience from the keyword."""
+        credentials = service_account.IDTokenCredentials.from_service_account_info(
+            test_service_account.SERVICE_ACCOUNT_INFO,
+            target_audience=self.TARGET_AUDIENCE,
+        )
+
+        assert (
+            credentials._signer.key_id
+            == test_service_account.SERVICE_ACCOUNT_INFO["private_key_id"]
+        )
+        assert (
+            credentials.service_account_email
+            == test_service_account.SERVICE_ACCOUNT_INFO["client_email"]
+        )
+        assert (
+            credentials._token_uri
+            == test_service_account.SERVICE_ACCOUNT_INFO["token_uri"]
+        )
+        assert credentials._target_audience == self.TARGET_AUDIENCE
+
+    def test_from_service_account_file(self):
+        """Loading from the JSON file yields the same fields as the info mapping."""
+        info = test_service_account.SERVICE_ACCOUNT_INFO.copy()
+
+        credentials = service_account.IDTokenCredentials.from_service_account_file(
+            test_service_account.SERVICE_ACCOUNT_JSON_FILE,
+            target_audience=self.TARGET_AUDIENCE,
+        )
+
+        assert credentials.service_account_email == info["client_email"]
+        assert credentials._signer.key_id == info["private_key_id"]
+        assert credentials._token_uri == info["token_uri"]
+        assert credentials._target_audience == self.TARGET_AUDIENCE
+
+    def test_default_state(self):
+        """Fresh credentials are neither valid nor expired."""
+        credentials = self.make_credentials()
+        assert not credentials.valid
+        # Expiration hasn't been set yet
+        assert not credentials.expired
+
+    def test_sign_bytes(self):
+        """A signature from sign_bytes verifies against the test public cert."""
+        credentials = self.make_credentials()
+        to_sign = b"123"
+        signature = credentials.sign_bytes(to_sign)
+        assert crypt.verify_signature(
+            to_sign, signature, test_service_account.PUBLIC_CERT_BYTES
+        )
+
+    def test_signer(self):
+        """The signer property exposes a crypt.Signer."""
+        credentials = self.make_credentials()
+        assert isinstance(credentials.signer, crypt.Signer)
+
+    def test_signer_email(self):
+        """signer_email equals the service account email."""
+        credentials = self.make_credentials()
+        assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL
+
+    def test_with_target_audience(self):
+        """with_target_audience returns credentials with the new audience."""
+        credentials = self.make_credentials()
+        new_credentials = credentials.with_target_audience("https://new.example.com")
+        assert new_credentials._target_audience == "https://new.example.com"
+
+    def test_with_quota_project(self):
+        """with_quota_project stores the quota project id on the new credentials."""
+        credentials = self.make_credentials()
+        new_credentials = credentials.with_quota_project("project-foo")
+        assert new_credentials._quota_project_id == "project-foo"
+
+    def test__make_authorization_grant_assertion(self):
+        """The assertion carries iss, aud and target_audience claims."""
+        credentials = self.make_credentials()
+        token = credentials._make_authorization_grant_assertion()
+        payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES)
+        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
+        assert payload["aud"] == self.TOKEN_URI
+        assert payload["target_audience"] == self.TARGET_AUDIENCE
+
+    @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True)
+    @pytest.mark.asyncio
+    async def test_refresh_success(self, id_token_jwt_grant):
+        """refresh() calls id_token_jwt_grant and stores the returned token."""
+        credentials = self.make_credentials()
+        token = "token"
+        id_token_jwt_grant.return_value = (
+            token,
+            _helpers.utcnow() + datetime.timedelta(seconds=500),
+            {},
+        )
+
+        # Spec against the real Request class. The previous
+        # spec=["transport.Request"] was a list holding a string, which only
+        # permitted an attribute literally named "transport.Request".
+        request = mock.AsyncMock(spec=transport.Request)
+
+        # Refresh credentials
+        await credentials.refresh(request)
+
+        # Check jwt grant call.
+        assert id_token_jwt_grant.called
+
+        called_request, token_uri, assertion = id_token_jwt_grant.call_args[0]
+        assert called_request == request
+        assert token_uri == credentials._token_uri
+        assert jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES)
+        # No further assertion done on the token, as there are separate tests
+        # for checking the authorization grant assertion.
+
+        # Check that the credentials have the token.
+        assert credentials.token == token
+
+        # Check that the credentials are valid (have a token and are not
+        # expired)
+        assert credentials.valid
+
+    @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True)
+    @pytest.mark.asyncio
+    async def test_before_request_refreshes(self, id_token_jwt_grant):
+        """before_request on invalid credentials triggers a refresh."""
+        credentials = self.make_credentials()
+        token = "token"
+        id_token_jwt_grant.return_value = (
+            token,
+            _helpers.utcnow() + datetime.timedelta(seconds=500),
+            None,
+        )
+        # Spec against the real Request class, not a list containing its name.
+        request = mock.AsyncMock(spec=transport.Request)
+
+        # Credentials should start as invalid
+        assert not credentials.valid
+
+        # before_request should cause a refresh
+        await credentials.before_request(request, "GET", "http://example.com?a=1#3", {})
+
+        # The refresh endpoint should've been called.
+        assert id_token_jwt_grant.called
+
+        # Credentials should now be valid.
+        assert credentials.valid
diff --git a/tests_async/test__default_async.py b/tests_async/test__default_async.py
new file mode 100644
index 0000000..bca396a
--- /dev/null
+++ b/tests_async/test__default_async.py
@@ -0,0 +1,468 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+
+import mock
+import pytest
+
+from google.auth import _credentials_async as credentials
+from google.auth import _default_async as _default
+from google.auth import app_engine
+from google.auth import compute_engine
+from google.auth import environment_vars
+from google.auth import exceptions
+from google.oauth2 import _service_account_async as service_account
+import google.oauth2.credentials
+from tests import test__default as test_default
+
+# Shared stand-in for loaded credentials. with_quota_project returns the same
+# mock, so code that re-wraps the credentials still hands back MOCK_CREDENTIALS.
+MOCK_CREDENTIALS = mock.Mock(spec=credentials.CredentialsWithQuotaProject)
+MOCK_CREDENTIALS.with_quota_project.return_value = MOCK_CREDENTIALS
+
+# Reusable patch (applied as a decorator below) that replaces
+# load_credentials_from_file in google.auth._default_async with a mock
+# returning (MOCK_CREDENTIALS, mock.sentinel.project_id).
+LOAD_FILE_PATCH = mock.patch(
+    "google.auth._default_async.load_credentials_from_file",
+    return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+    autospec=True,
+)
+
+
+def test_load_credentials_from_missing_file():
+    """An empty path raises DefaultCredentialsError mentioning "not found"."""
+    missing_path = ""
+    with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+        _default.load_credentials_from_file(missing_path)
+
+    assert excinfo.match(r"not found")
+
+
+def test_load_credentials_from_file_invalid_json(tmpdir):
+    """A file containing malformed JSON is rejected as not a valid json file."""
+    broken = tmpdir.join("invalid.json")
+    broken.write("{")
+    broken_path = str(broken)
+
+    with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+        _default.load_credentials_from_file(broken_path)
+
+    assert excinfo.match(r"not a valid json file")
+
+
+def test_load_credentials_from_file_invalid_type(tmpdir):
+    """A JSON file with an unknown ``type`` value is rejected."""
+    payload = json.dumps({"type": "not-a-real-type"})
+    target = tmpdir.join("invalid.json")
+    target.write(payload)
+
+    with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+        _default.load_credentials_from_file(str(target))
+
+    assert excinfo.match(r"does not have a valid type")
+
+
+def test_load_credentials_from_file_authorized_user():
+    """An authorized-user file loads as async OAuth2 credentials with no project id."""
+    loaded = _default.load_credentials_from_file(test_default.AUTHORIZED_USER_FILE)
+    creds, project = loaded
+
+    assert isinstance(creds, google.oauth2._credentials_async.Credentials)
+    assert project is None
+
+
+def test_load_credentials_from_file_no_type(tmpdir):
+    """A JSON file without a ``type`` key is rejected and reports ``Type is None``."""
+    # client_secrets.json is valid JSON but not a loadable credentials type.
+    secrets_path = test_default.CLIENT_SECRETS_FILE
+
+    with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+        _default.load_credentials_from_file(secrets_path)
+
+    for pattern in (r"does not have a valid type", r"Type is None"):
+        assert excinfo.match(pattern)
+
+
+def test_load_credentials_from_file_authorized_user_bad_format(tmpdir):
+    """An authorized_user file lacking required fields fails with a clear message."""
+    incomplete = tmpdir.join("authorized_user_bad.json")
+    incomplete.write(json.dumps({"type": "authorized_user"}))
+
+    with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+        _default.load_credentials_from_file(str(incomplete))
+
+    for pattern in (r"Failed to load authorized user", r"missing fields"):
+        assert excinfo.match(pattern)
+
+
+def test_load_credentials_from_file_authorized_user_cloud_sdk():
+    """Cloud SDK user credentials load, warning when the file is the plain SDK one."""
+    expected_type = google.oauth2._credentials_async.Credentials
+
+    with pytest.warns(UserWarning, match="Cloud SDK"):
+        creds, project = _default.load_credentials_from_file(
+            test_default.AUTHORIZED_USER_CLOUD_SDK_FILE
+        )
+    assert isinstance(creds, expected_type)
+    assert project is None
+
+    # Second load uses the variant of the file that carries a quota project id;
+    # per the original test's note no warning is expected for it.
+    creds, project = _default.load_credentials_from_file(
+        test_default.AUTHORIZED_USER_CLOUD_SDK_WITH_QUOTA_PROJECT_ID_FILE
+    )
+    assert isinstance(creds, expected_type)
+    assert project is None
+
+
+def test_load_credentials_from_file_authorized_user_cloud_sdk_with_scopes():
+    """Scopes passed to the loader end up on the Cloud SDK user credentials."""
+    with pytest.warns(UserWarning, match="Cloud SDK"):
+        loaded = _default.load_credentials_from_file(
+            test_default.AUTHORIZED_USER_CLOUD_SDK_FILE,
+            scopes=["https://www.google.com/calendar/feeds"],
+        )
+    creds, project = loaded
+
+    assert isinstance(creds, google.oauth2._credentials_async.Credentials)
+    assert project is None
+    assert creds.scopes == ["https://www.google.com/calendar/feeds"]
+
+
+def test_load_credentials_from_file_authorized_user_cloud_sdk_with_quota_project():
+    """An explicit quota_project_id is applied to the loaded user credentials."""
+    sdk_file = test_default.AUTHORIZED_USER_CLOUD_SDK_FILE
+    creds, project = _default.load_credentials_from_file(
+        sdk_file, quota_project_id="project-foo"
+    )
+
+    assert isinstance(creds, google.oauth2._credentials_async.Credentials)
+    assert project is None
+    assert creds.quota_project_id == "project-foo"
+
+
+def test_load_credentials_from_file_service_account():
+    """A service-account file loads as service-account credentials plus its project id."""
+    loaded = _default.load_credentials_from_file(test_default.SERVICE_ACCOUNT_FILE)
+    creds, project = loaded
+
+    assert isinstance(creds, service_account.Credentials)
+    assert project == test_default.SERVICE_ACCOUNT_FILE_DATA["project_id"]
+
+
+def test_load_credentials_from_file_service_account_with_scopes():
+    """Scopes passed to the loader end up on the service-account credentials."""
+    creds, project = _default.load_credentials_from_file(
+        test_default.SERVICE_ACCOUNT_FILE,
+        scopes=["https://www.google.com/calendar/feeds"],
+    )
+
+    assert isinstance(creds, service_account.Credentials)
+    assert project == test_default.SERVICE_ACCOUNT_FILE_DATA["project_id"]
+    assert creds.scopes == ["https://www.google.com/calendar/feeds"]
+
+
+def test_load_credentials_from_file_service_account_bad_format(tmpdir):
+    """A service_account file lacking required fields fails with a clear message."""
+    incomplete = tmpdir.join("serivce_account_bad.json")
+    incomplete.write(json.dumps({"type": "service_account"}))
+
+    with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+        _default.load_credentials_from_file(str(incomplete))
+
+    for pattern in (r"Failed to load service account", r"missing fields"):
+        assert excinfo.match(pattern)
+
+
+@mock.patch.dict(os.environ, {}, clear=True)
+def test__get_explicit_environ_credentials_no_env():
+    """With an emptied environment the helper returns (None, None)."""
+    outcome = _default._get_explicit_environ_credentials()
+    assert outcome == (None, None)
+
+
+@LOAD_FILE_PATCH
+def test__get_explicit_environ_credentials(load, monkeypatch):
+    """The path in the credentials env var is passed to the file loader."""
+    monkeypatch.setenv(environment_vars.CREDENTIALS, "filename")
+
+    creds, project = _default._get_explicit_environ_credentials()
+
+    load.assert_called_with("filename")
+    assert creds is MOCK_CREDENTIALS
+    assert project is mock.sentinel.project_id
+
+
+@LOAD_FILE_PATCH
+def test__get_explicit_environ_credentials_no_project_id(load, monkeypatch):
+    """A None project id from the loader is returned unchanged."""
+    load.return_value = (MOCK_CREDENTIALS, None)
+    monkeypatch.setenv(environment_vars.CREDENTIALS, "filename")
+
+    creds, project = _default._get_explicit_environ_credentials()
+
+    assert creds is MOCK_CREDENTIALS
+    assert project is None
+
+
+@LOAD_FILE_PATCH
+@mock.patch(
+    "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True
+)
+def test__get_gcloud_sdk_credentials(get_adc_path, load):
+    """The Cloud SDK ADC path is handed to the file loader and its result returned."""
+    adc_file = test_default.SERVICE_ACCOUNT_FILE
+    get_adc_path.return_value = adc_file
+
+    creds, project = _default._get_gcloud_sdk_credentials()
+
+    load.assert_called_with(adc_file)
+    assert creds is MOCK_CREDENTIALS
+    assert project is mock.sentinel.project_id
+
+
+@mock.patch(
+    "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True
+)
+def test__get_gcloud_sdk_credentials_non_existent(get_adc_path, tmpdir):
+    """A Cloud SDK ADC path that does not exist yields (None, None)."""
+    get_adc_path.return_value = str(tmpdir.join("non-existent"))
+
+    creds, project = _default._get_gcloud_sdk_credentials()
+
+    assert creds is None
+    assert project is None
+
+
+@mock.patch(
+    "google.auth._cloud_sdk.get_project_id",
+    autospec=True,
+    return_value=mock.sentinel.project_id,
+)
+@mock.patch("os.path.isfile", return_value=True, autospec=True)
+@LOAD_FILE_PATCH
+def test__get_gcloud_sdk_credentials_project_id(load, unused_isfile, get_project_id):
+    """When the file has no project id, the Cloud SDK project id is used."""
+    # The loader reports no project id, so the function has to ask the
+    # Cloud SDK for one.
+    load.return_value = (MOCK_CREDENTIALS, None)
+
+    creds, project = _default._get_gcloud_sdk_credentials()
+
+    assert creds == MOCK_CREDENTIALS
+    assert project == mock.sentinel.project_id
+    assert get_project_id.called
+
+
+@mock.patch("google.auth._cloud_sdk.get_project_id", return_value=None, autospec=True)
+@mock.patch("os.path.isfile", return_value=True)
+@LOAD_FILE_PATCH
+def test__get_gcloud_sdk_credentials_no_project_id(load, unused_isfile, get_project_id):
+    """If neither the file nor the Cloud SDK has a project id, None is returned."""
+    # The loader reports no project id, so the function has to ask the
+    # Cloud SDK, which here has none either.
+    load.return_value = (MOCK_CREDENTIALS, None)
+
+    creds, project = _default._get_gcloud_sdk_credentials()
+
+    assert creds == MOCK_CREDENTIALS
+    assert project is None
+    assert get_project_id.called
+
+
+class _AppIdentityModule(object):
+    """The interface of the App Identity app engine module.
+
+    Only used as an autospec template; the method is never implemented here.
+    See https://cloud.google.com/appengine/docs/standard/python/refdocs\
+    /google.appengine.api.app_identity.app_identity
+    """
+
+    def get_application_id(self):
+        raise NotImplementedError()
+
+
+@pytest.fixture
+def app_identity(monkeypatch):
+    """Install an autospecced app_identity module on google.auth.app_engine."""
+    fake_module = mock.create_autospec(_AppIdentityModule, instance=True)
+    monkeypatch.setattr(app_engine, "app_identity", fake_module)
+    yield fake_module
+
+
+def test__get_gae_credentials(app_identity):
+    """With app_identity available, App Engine credentials and app id are returned."""
+    expected_project = mock.sentinel.project
+    app_identity.get_application_id.return_value = expected_project
+
+    creds, project = _default._get_gae_credentials()
+
+    assert isinstance(creds, app_engine.Credentials)
+    assert project == expected_project
+
+
+def test__get_gae_credentials_no_app_engine():
+    """If google.auth.app_engine cannot be imported, (None, None) is returned."""
+    # A None entry in sys.modules makes importing that module fail; patch.dict
+    # restores the original mapping on exit.
+    blocked = {"google.auth.app_engine": None}
+    with mock.patch.dict("sys.modules", blocked):
+        creds, project = _default._get_gae_credentials()
+        assert creds is None
+        assert project is None
+
+
+def test__get_gae_credentials_no_apis():
+    """Without the app_identity fixture installed, the helper returns (None, None)."""
+    outcome = _default._get_gae_credentials()
+    assert outcome == (None, None)
+
+
+@mock.patch(
+    "google.auth.compute_engine._metadata.ping", autospec=True, return_value=True
+)
+@mock.patch(
+    "google.auth.compute_engine._metadata.get_project_id",
+    autospec=True,
+    return_value="example-project",
+)
+def test__get_gce_credentials(unused_get, unused_ping):
+    """A successful metadata ping yields GCE credentials and the metadata project id."""
+    creds, project = _default._get_gce_credentials()
+
+    assert isinstance(creds, compute_engine.Credentials)
+    assert project == "example-project"
+
+
+@mock.patch(
+    "google.auth.compute_engine._metadata.ping", autospec=True, return_value=False
+)
+def test__get_gce_credentials_no_ping(unused_ping):
+    """A failed metadata ping yields (None, None)."""
+    creds, project = _default._get_gce_credentials()
+
+    assert creds is None
+    assert project is None
+
+
+@mock.patch(
+    "google.auth.compute_engine._metadata.ping", autospec=True, return_value=True
+)
+@mock.patch(
+    "google.auth.compute_engine._metadata.get_project_id",
+    autospec=True,
+    side_effect=exceptions.TransportError(),
+)
+def test__get_gce_credentials_no_project_id(unused_get, unused_ping):
+    """A TransportError while fetching the project id leaves the project id None."""
+    creds, project = _default._get_gce_credentials()
+
+    assert isinstance(creds, compute_engine.Credentials)
+    assert project is None
+
+
+def test__get_gce_credentials_no_compute_engine():
+    """If google.auth.compute_engine cannot be imported, (None, None) is returned."""
+    # A None entry in sys.modules makes importing that module fail; patch.dict
+    # restores the original mapping on exit.
+    blocked = {"google.auth.compute_engine": None}
+    with mock.patch.dict("sys.modules", blocked):
+        creds, project = _default._get_gce_credentials()
+        assert creds is None
+        assert project is None
+
+
+@mock.patch(
+    "google.auth.compute_engine._metadata.ping", autospec=True, return_value=False
+)
+def test__get_gce_credentials_explicit_request(ping):
+    """A request passed explicitly is forwarded to the metadata ping."""
+    explicit_request = mock.sentinel.request
+    _default._get_gce_credentials(explicit_request)
+    ping.assert_called_with(request=explicit_request)
+
+
+@mock.patch(
+    "google.auth._default_async._get_explicit_environ_credentials",
+    autospec=True,
+    return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+)
+def test_default_early_out(unused_get):
+    """default_async returns what the explicit-environment checker found."""
+    outcome = _default.default_async()
+    assert outcome == (MOCK_CREDENTIALS, mock.sentinel.project_id)
+
+
+@mock.patch(
+    "google.auth._default_async._get_explicit_environ_credentials",
+    autospec=True,
+    return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+)
+def test_default_explict_project_id(unused_get, monkeypatch):
+    """The PROJECT env var overrides the project id found with the credentials."""
+    monkeypatch.setenv(environment_vars.PROJECT, "explicit-env")
+    outcome = _default.default_async()
+    assert outcome == (MOCK_CREDENTIALS, "explicit-env")
+
+
+@mock.patch(
+    "google.auth._default_async._get_explicit_environ_credentials",
+    autospec=True,
+    return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+)
+def test_default_explict_legacy_project_id(unused_get, monkeypatch):
+    """The LEGACY_PROJECT env var also overrides the discovered project id."""
+    monkeypatch.setenv(environment_vars.LEGACY_PROJECT, "explicit-env")
+    outcome = _default.default_async()
+    assert outcome == (MOCK_CREDENTIALS, "explicit-env")
+
+
+@mock.patch("logging.Logger.warning", autospec=True)
+@mock.patch(
+    "google.auth._default_async._get_explicit_environ_credentials",
+    autospec=True,
+    return_value=(MOCK_CREDENTIALS, None),
+)
+@mock.patch(
+    "google.auth._default_async._get_gcloud_sdk_credentials",
+    autospec=True,
+    return_value=(MOCK_CREDENTIALS, None),
+)
+@mock.patch(
+    "google.auth._default_async._get_gae_credentials",
+    autospec=True,
+    return_value=(MOCK_CREDENTIALS, None),
+)
+@mock.patch(
+    "google.auth._default_async._get_gce_credentials",
+    autospec=True,
+    return_value=(MOCK_CREDENTIALS, None),
+)
+def test_default_without_project_id(
+    unused_gce, unused_gae, unused_sdk, unused_explicit, logger_warning
+):
+    """With no project id from any checker, credentials are returned and a warning logged."""
+    outcome = _default.default_async()
+
+    assert outcome == (MOCK_CREDENTIALS, None)
+    # autospec on the unbound method means the first argument is the logger.
+    logger_warning.assert_called_with(mock.ANY, mock.ANY, mock.ANY)
+
+
+@mock.patch(
+    "google.auth._default_async._get_explicit_environ_credentials",
+    autospec=True,
+    return_value=(None, None),
+)
+@mock.patch(
+    "google.auth._default_async._get_gcloud_sdk_credentials",
+    autospec=True,
+    return_value=(None, None),
+)
+@mock.patch(
+    "google.auth._default_async._get_gae_credentials",
+    autospec=True,
+    return_value=(None, None),
+)
+@mock.patch(
+    "google.auth._default_async._get_gce_credentials",
+    autospec=True,
+    return_value=(None, None),
+)
+def test_default_fail(unused_gce, unused_gae, unused_sdk, unused_explicit):
+    """When every checker finds nothing, default_async raises DefaultCredentialsError."""
+    expected_error = exceptions.DefaultCredentialsError
+    with pytest.raises(expected_error):
+        assert _default.default_async()
+
+
+@mock.patch(
+    "google.auth._default_async._get_explicit_environ_credentials",
+    autospec=True,
+    return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+)
+@mock.patch(
+    "google.auth._credentials_async.with_scopes_if_required",
+    autospec=True,
+    return_value=MOCK_CREDENTIALS,
+)
+def test_default_scoped(with_scopes, unused_get):
+    """Requested scopes are applied through with_scopes_if_required."""
+    wanted_scopes = ["one", "two"]
+
+    creds, project = _default.default_async(scopes=wanted_scopes)
+
+    with_scopes.assert_called_once_with(MOCK_CREDENTIALS, wanted_scopes)
+    assert creds == with_scopes.return_value
+    assert project == mock.sentinel.project_id
+
+
+@mock.patch(
+    "google.auth._default_async._get_explicit_environ_credentials",
+    autospec=True,
+    return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+)
+def test_default_no_app_engine_compute_engine_module(unused_get):
+    """
+    google.auth.compute_engine and google.auth.app_engine are both optional,
+    so the package can be used without them. This verifies that
+    default_async still returns the explicit-environment credentials when
+    both modules are unimportable.
+    """
+    # None entries in sys.modules make those imports fail; patch.dict restores
+    # the original mapping on exit.
+    blocked = {
+        "google.auth.compute_engine": None,
+        "google.auth.app_engine": None,
+    }
+    with mock.patch.dict("sys.modules", blocked):
+        outcome = _default.default_async()
+        assert outcome == (MOCK_CREDENTIALS, mock.sentinel.project_id)
diff --git a/tests_async/test_credentials_async.py b/tests_async/test_credentials_async.py
new file mode 100644
index 0000000..0a48908
--- /dev/null
+++ b/tests_async/test_credentials_async.py
@@ -0,0 +1,177 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+
+import pytest
+
+from google.auth import _credentials_async as credentials
+from google.auth import _helpers
+
+
+class CredentialsImpl(credentials.Credentials):
+    """Minimal concrete credentials used by the tests in this module."""
+
+    def with_quota_project(self, quota_project_id):
+        raise NotImplementedError()
+
+    def refresh(self, request):
+        # The request object itself becomes the token, so a test can tell
+        # which call triggered a refresh.
+        self.token = request
+
+
+def test_credentials_constructor():
+    """Fresh credentials have no token or expiry and are neither expired nor valid."""
+    fresh = CredentialsImpl()
+    for value in (fresh.token, fresh.expiry, fresh.expired, fresh.valid):
+        assert not value
+
+
+def test_expired_and_valid():
+    """valid/expired track the token and the expiry relative to the clock skew."""
+    creds = CredentialsImpl()
+    creds.token = "token"
+
+    # A token with no expiry set counts as valid.
+    assert creds.valid
+    assert not creds.expired
+
+    # Expiry one second beyond now plus the clock skew accommodation:
+    # these credentials should still be valid.
+    one_second = datetime.timedelta(seconds=1)
+    creds.expiry = datetime.datetime.utcnow() + _helpers.CLOCK_SKEW + one_second
+
+    assert creds.valid
+    assert not creds.expired
+
+    # Expiry exactly now: because of the clock skew accommodation these
+    # credentials should report as expired.
+    creds.expiry = datetime.datetime.utcnow()
+
+    assert not creds.valid
+    assert creds.expired
+
+
+@pytest.mark.asyncio
+async def test_before_request():
+    """before_request refreshes only while the credentials are invalid.
+
+    ``CredentialsImpl.refresh`` stores the request object as the token, so the
+    token value shows which call (if any) triggered a refresh.
+    """
+    credentials = CredentialsImpl()
+    request = "token"
+    headers = {}
+
+    # First call should call refresh, setting the token.
+    await credentials.before_request(request, "GET", "http://example.com", headers)
+    assert credentials.valid
+    assert credentials.token == "token"
+    assert headers["authorization"] == "Bearer token"
+
+    request = "token2"
+    headers = {}
+
+    # Second call shouldn't call refresh. It must be awaited: an un-awaited
+    # coroutine never runs, which would make the assertions below vacuous.
+    await credentials.before_request(request, "GET", "http://example.com", headers)
+
+    assert credentials.valid
+    assert credentials.token == "token"
+
+
+def test_anonymous_credentials_ctor():
+    """Anonymous credentials carry no token or expiry yet report as valid."""
+    anonymous = credentials.AnonymousCredentials()
+
+    assert anonymous.token is None
+    assert anonymous.expiry is None
+    assert anonymous.valid
+    assert not anonymous.expired
+
+
+def test_anonymous_credentials_refresh():
+    """Refreshing anonymous credentials raises ValueError."""
+    anonymous = credentials.AnonymousCredentials()
+    placeholder_request = object()
+
+    with pytest.raises(ValueError):
+        anonymous.refresh(placeholder_request)
+
+
+def test_anonymous_credentials_apply_default():
+    """apply() leaves headers untouched and rejects an explicit token."""
+    anonymous = credentials.AnonymousCredentials()
+    request_headers = {}
+
+    anonymous.apply(request_headers)
+    assert request_headers == {}
+
+    with pytest.raises(ValueError):
+        anonymous.apply(request_headers, token="TOKEN")
+
+
+def test_anonymous_credentials_before_request():
+    """before_request on anonymous credentials adds no headers."""
+    anonymous = credentials.AnonymousCredentials()
+    request_headers = {}
+
+    # NOTE(review): called synchronously. If AnonymousCredentials.before_request
+    # is a coroutine function in the async module, this call would need to be
+    # awaited for the assertion to mean anything - confirm.
+    anonymous.before_request(
+        object(), "GET", "https://example.com/api/endpoint", request_headers
+    )
+
+    assert request_headers == {}
+
+
+class ReadOnlyScopedCredentialsImpl(credentials.ReadOnlyScoped, CredentialsImpl):
+    """Read-only-scoped test credentials that defer requires_scopes to the base class."""
+
+    @property
+    def requires_scopes(self):
+        return super().requires_scopes
+
+
+def test_readonly_scoped_credentials_constructor():
+    """Read-only scoped credentials start with no scopes."""
+    scoped = ReadOnlyScopedCredentialsImpl()
+    assert scoped._scopes is None
+
+
+def test_readonly_scoped_credentials_scopes():
+    """scopes mirrors _scopes, and has_scopes accepts any subset of them."""
+    scoped = ReadOnlyScopedCredentialsImpl()
+    scoped._scopes = ["one", "two"]
+
+    assert scoped.scopes == ["one", "two"]
+    for subset in (["one"], ["two"], ["one", "two"]):
+        assert scoped.has_scopes(subset)
+    assert not scoped.has_scopes(["three"])
+
+
+def test_readonly_scoped_credentials_requires_scopes():
+    """Read-only scoped credentials do not require scopes."""
+    scoped = ReadOnlyScopedCredentialsImpl()
+    assert not scoped.requires_scopes
+
+
+class RequiresScopedCredentialsImpl(credentials.Scoped, CredentialsImpl):
+    """Scoped test credentials that require scopes until some are supplied."""
+
+    def __init__(self, scopes=None):
+        super().__init__()
+        self._scopes = scopes
+
+    def with_scopes(self, scopes):
+        # Always a new instance; the receiver is left unchanged.
+        return RequiresScopedCredentialsImpl(scopes=scopes)
+
+    @property
+    def requires_scopes(self):
+        return not self.scopes
+
+
+def test_create_scoped_if_required_scoped():
+    """Credentials that require scopes are replaced by a scoped copy."""
+    original = RequiresScopedCredentialsImpl()
+    wanted_scopes = ["one", "two"]
+
+    result = credentials.with_scopes_if_required(original, wanted_scopes)
+
+    assert result is not original
+    assert not result.requires_scopes
+    assert result.has_scopes(["one", "two"])
+
+
+def test_create_scoped_if_required_not_scopes():
+    """Credentials without scope support are returned unchanged."""
+    original = CredentialsImpl()
+
+    result = credentials.with_scopes_if_required(original, ["one", "two"])
+
+    assert result is original
diff --git a/tests_async/test_jwt_async.py b/tests_async/test_jwt_async.py
new file mode 100644
index 0000000..a35b837
--- /dev/null
+++ b/tests_async/test_jwt_async.py
@@ -0,0 +1,356 @@
+# Copyright 2020 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+
+import mock
+import pytest
+
+from google.auth import _jwt_async as jwt_async
+from google.auth import crypt
+from google.auth import exceptions
+from tests import test_jwt
+
+
+@pytest.fixture
+def signer():  # RSA signer built from tests.test_jwt.PRIVATE_KEY_BYTES
+ return crypt.RSASigner.from_string(test_jwt.PRIVATE_KEY_BYTES, "1")
+
+
+class TestCredentials(object):  # tests for jwt_async.Credentials constructed with a fixed audience
+ SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
+ SUBJECT = "subject"
+ AUDIENCE = "audience"
+ ADDITIONAL_CLAIMS = {"meta": "data"}
+ credentials = None
+
+ @pytest.fixture(autouse=True)
+ def credentials_fixture(self, signer):  # autouse: builds a fresh self.credentials for each test
+ self.credentials = jwt_async.Credentials(
+ signer,
+ self.SERVICE_ACCOUNT_EMAIL,
+ self.SERVICE_ACCOUNT_EMAIL,
+ self.AUDIENCE,
+ )
+
+ def test_from_service_account_info(self):
+ with open(test_jwt.SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
+ info = json.load(fh)
+
+ credentials = jwt_async.Credentials.from_service_account_info(
+ info, audience=self.AUDIENCE
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == info["client_email"]
+ assert credentials._audience == self.AUDIENCE
+
+ def test_from_service_account_info_args(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.Credentials.from_service_account_info(
+ info,
+ subject=self.SUBJECT,
+ audience=self.AUDIENCE,
+ additional_claims=self.ADDITIONAL_CLAIMS,
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == self.SUBJECT
+ assert credentials._audience == self.AUDIENCE
+ assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+ def test_from_service_account_file(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.Credentials.from_service_account_file(
+ test_jwt.SERVICE_ACCOUNT_JSON_FILE, audience=self.AUDIENCE
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == info["client_email"]
+ assert credentials._audience == self.AUDIENCE
+
+ def test_from_service_account_file_args(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.Credentials.from_service_account_file(
+ test_jwt.SERVICE_ACCOUNT_JSON_FILE,
+ subject=self.SUBJECT,
+ audience=self.AUDIENCE,
+ additional_claims=self.ADDITIONAL_CLAIMS,
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == self.SUBJECT
+ assert credentials._audience == self.AUDIENCE
+ assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+ def test_from_signing_credentials(self):  # compares against credentials built from SERVICE_ACCOUNT_INFO
+ jwt_from_signing = self.credentials.from_signing_credentials(
+ self.credentials, audience=mock.sentinel.new_audience
+ )
+ jwt_from_info = jwt_async.Credentials.from_service_account_info(
+ test_jwt.SERVICE_ACCOUNT_INFO, audience=mock.sentinel.new_audience
+ )
+
+ assert isinstance(jwt_from_signing, jwt_async.Credentials)
+ assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id
+ assert jwt_from_signing._issuer == jwt_from_info._issuer
+ assert jwt_from_signing._subject == jwt_from_info._subject
+ assert jwt_from_signing._audience == jwt_from_info._audience
+
+ def test_default_state(self):
+ assert not self.credentials.valid
+ # Expiration hasn't been set yet
+ assert not self.credentials.expired
+
+ def test_with_claims(self):  # only the audience should differ on the copy
+ new_audience = "new_audience"
+ new_credentials = self.credentials.with_claims(audience=new_audience)
+
+ assert new_credentials._signer == self.credentials._signer
+ assert new_credentials._issuer == self.credentials._issuer
+ assert new_credentials._subject == self.credentials._subject
+ assert new_credentials._audience == new_audience
+ assert new_credentials._additional_claims == self.credentials._additional_claims
+ assert new_credentials._quota_project_id == self.credentials._quota_project_id
+
+ def test_with_quota_project(self):  # only the quota project should differ on the copy
+ quota_project_id = "project-foo"
+
+ new_credentials = self.credentials.with_quota_project(quota_project_id)
+ assert new_credentials._signer == self.credentials._signer
+ assert new_credentials._issuer == self.credentials._issuer
+ assert new_credentials._subject == self.credentials._subject
+ assert new_credentials._audience == self.credentials._audience
+ assert new_credentials._additional_claims == self.credentials._additional_claims
+ assert new_credentials._quota_project_id == quota_project_id
+
+ def test_sign_bytes(self):  # signature must verify against the matching public cert
+ to_sign = b"123"
+ signature = self.credentials.sign_bytes(to_sign)
+ assert crypt.verify_signature(to_sign, signature, test_jwt.PUBLIC_CERT_BYTES)
+
+ def test_signer(self):
+ assert isinstance(self.credentials.signer, crypt.RSASigner)
+
+ def test_signer_email(self):
+ assert (
+ self.credentials.signer_email
+ == test_jwt.SERVICE_ACCOUNT_INFO["client_email"]
+ )
+
+ def _verify_token(self, token):  # helper: decodes with the public cert, checks "iss", returns the payload
+ payload = jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES)
+ assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
+ return payload
+
+ def test_refresh(self):
+ self.credentials.refresh(None)  # None is passed in place of a request object
+ assert self.credentials.valid
+ assert not self.credentials.expired
+
+ def test_expired(self):
+ assert not self.credentials.expired
+
+ self.credentials.refresh(None)
+ assert not self.credentials.expired
+
+ with mock.patch("google.auth._helpers.utcnow") as now:  # move "now" to one day past the expiry
+ one_day = datetime.timedelta(days=1)
+ now.return_value = self.credentials.expiry + one_day
+ assert self.credentials.expired
+
+ @pytest.mark.asyncio
+ async def test_before_request(self):
+ headers = {}
+
+ self.credentials.refresh(None)
+ await self.credentials.before_request(
+ None, "GET", "http://example.com?a=1#3", headers
+ )
+
+ header_value = headers["authorization"]
+ _, token = header_value.split(" ")  # the header value is two space-separated parts; keep the second
+
+ # Since the audience is set, it should use the existing token.
+ assert token.encode("utf-8") == self.credentials.token
+
+ payload = self._verify_token(token)
+ assert payload["aud"] == self.AUDIENCE
+
+ @pytest.mark.asyncio
+ async def test_before_request_refreshes(self):  # credentials start invalid and must be valid afterwards
+ assert not self.credentials.valid
+ await self.credentials.before_request(
+ None, "GET", "http://example.com?a=1#3", {}
+ )
+ assert self.credentials.valid
+
+
+class TestOnDemandCredentials(object):  # tests for jwt_async.OnDemandCredentials (no audience passed at construction)
+ SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
+ SUBJECT = "subject"
+ ADDITIONAL_CLAIMS = {"meta": "data"}
+ credentials = None
+
+ @pytest.fixture(autouse=True)
+ def credentials_fixture(self, signer):  # autouse: builds a fresh self.credentials for each test
+ self.credentials = jwt_async.OnDemandCredentials(
+ signer,
+ self.SERVICE_ACCOUNT_EMAIL,
+ self.SERVICE_ACCOUNT_EMAIL,
+ max_cache_size=2,
+ )
+
+ def test_from_service_account_info(self):
+ with open(test_jwt.SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
+ info = json.load(fh)
+
+ credentials = jwt_async.OnDemandCredentials.from_service_account_info(info)
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == info["client_email"]
+
+ def test_from_service_account_info_args(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.OnDemandCredentials.from_service_account_info(
+ info, subject=self.SUBJECT, additional_claims=self.ADDITIONAL_CLAIMS
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == self.SUBJECT
+ assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+ def test_from_service_account_file(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.OnDemandCredentials.from_service_account_file(
+ test_jwt.SERVICE_ACCOUNT_JSON_FILE
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == info["client_email"]
+
+ def test_from_service_account_file_args(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.OnDemandCredentials.from_service_account_file(
+ test_jwt.SERVICE_ACCOUNT_JSON_FILE,
+ subject=self.SUBJECT,
+ additional_claims=self.ADDITIONAL_CLAIMS,
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == self.SUBJECT
+ assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+ def test_from_signing_credentials(self):  # compares against credentials built from SERVICE_ACCOUNT_INFO
+ jwt_from_signing = self.credentials.from_signing_credentials(self.credentials)
+ jwt_from_info = jwt_async.OnDemandCredentials.from_service_account_info(
+ test_jwt.SERVICE_ACCOUNT_INFO
+ )
+
+ assert isinstance(jwt_from_signing, jwt_async.OnDemandCredentials)
+ assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id
+ assert jwt_from_signing._issuer == jwt_from_info._issuer
+ assert jwt_from_signing._subject == jwt_from_info._subject
+
+ def test_default_state(self):
+ # Credentials are *always* valid.
+ assert self.credentials.valid
+ # Credentials *never* expire.
+ assert not self.credentials.expired
+
+ def test_with_claims(self):  # only the additional claims should differ on the copy
+ new_claims = {"meep": "moop"}
+ new_credentials = self.credentials.with_claims(additional_claims=new_claims)
+
+ assert new_credentials._signer == self.credentials._signer
+ assert new_credentials._issuer == self.credentials._issuer
+ assert new_credentials._subject == self.credentials._subject
+ assert new_credentials._additional_claims == new_claims
+
+ def test_with_quota_project(self):  # only the quota project should differ on the copy
+ quota_project_id = "project-foo"
+ new_credentials = self.credentials.with_quota_project(quota_project_id)
+
+ assert new_credentials._signer == self.credentials._signer
+ assert new_credentials._issuer == self.credentials._issuer
+ assert new_credentials._subject == self.credentials._subject
+ assert new_credentials._additional_claims == self.credentials._additional_claims
+ assert new_credentials._quota_project_id == quota_project_id
+
+ def test_sign_bytes(self):  # signature must verify against the matching public cert
+ to_sign = b"123"
+ signature = self.credentials.sign_bytes(to_sign)
+ assert crypt.verify_signature(to_sign, signature, test_jwt.PUBLIC_CERT_BYTES)
+
+ def test_signer(self):
+ assert isinstance(self.credentials.signer, crypt.RSASigner)
+
+ def test_signer_email(self):
+ assert (
+ self.credentials.signer_email
+ == test_jwt.SERVICE_ACCOUNT_INFO["client_email"]
+ )
+
+ def _verify_token(self, token):  # helper: decodes with the public cert, checks "iss", returns the payload
+ payload = jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES)
+ assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
+ return payload
+
+ def test_refresh(self):  # refresh() is expected to raise RefreshError for this credential type
+ with pytest.raises(exceptions.RefreshError):
+ self.credentials.refresh(None)
+
+ def test_before_request(self):  # NOTE(review): before_request is not awaited here, unlike TestCredentials; presumably synchronous for this class, confirm
+ headers = {}
+
+ self.credentials.before_request(
+ None, "GET", "http://example.com?a=1#3", headers
+ )
+
+ _, token = headers["authorization"].split(" ")
+ payload = self._verify_token(token)
+
+ assert payload["aud"] == "http://example.com"  # query string and fragment are absent from the audience
+
+ # Making another request should re-use the same token.
+ self.credentials.before_request(None, "GET", "http://example.com?b=2", headers)
+
+ _, new_token = headers["authorization"].split(" ")
+
+ assert new_token == token
+
+ def test_expired_token(self):  # a stale cache entry must not be returned
+ self.credentials._cache["audience"] = (
+ mock.sentinel.token,
+ datetime.datetime.min,  # presumably the cached expiry; the minimum datetime is always in the past
+ )
+
+ token = self.credentials._get_jwt_for_audience("audience")
+
+ assert token != mock.sentinel.token
diff --git a/tests_async/transport/__init__.py b/tests_async/transport/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests_async/transport/__init__.py
diff --git a/tests_async/transport/async_compliance.py b/tests_async/transport/async_compliance.py
new file mode 100644
index 0000000..9c4b173
--- /dev/null
+++ b/tests_async/transport/async_compliance.py
@@ -0,0 +1,133 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+import flask
+import pytest
+from pytest_localserver.http import WSGIServer
+from six.moves import http_client
+
+from google.auth import exceptions
+from tests.transport import compliance
+
+
+class RequestResponseTests(object):  # shared test mixin: subclasses provide make_request() and make_with_parameter_request()
+ @pytest.fixture(scope="module")  # module scope: one server instance is shared by the tests in a module
+ def server(self):
+ """Provides a test HTTP server.
+
+ The test server is automatically created before
+ a test and destroyed at the end. The server is serving a test
+ application that can be used to verify requests.
+ """
+ app = flask.Flask(__name__)
+ app.debug = True
+
+ # pylint: disable=unused-variable
+ # (pylint thinks the flask routes are unused.)
+ @app.route("/basic")
+ def index():  # echoes the x-test-header request header, defaulting to "value"
+ header_value = flask.request.headers.get("x-test-header", "value")
+ headers = {"X-Test-Header": header_value}
+ return "Basic Content", http_client.OK, headers
+
+ @app.route("/server_error")
+ def server_error():
+ return "Error", http_client.INTERNAL_SERVER_ERROR
+
+ @app.route("/wait")
+ def wait():
+ time.sleep(3)  # longer than the 1-second timeout used by test_request_with_timeout_failure
+ return "Waited"
+
+ # pylint: enable=unused-variable
+
+ server = WSGIServer(application=app.wsgi_app)
+ server.start()
+ yield server
+ server.stop()  # teardown: runs when the fixture is finalized
+
+ @pytest.mark.asyncio
+ async def test_request_basic(self, server):
+ request = self.make_request()
+ response = await request(url=server.url + "/basic", method="GET")
+ assert response.status == http_client.OK
+ assert response.headers["x-test-header"] == "value"
+
+ # Use 13 as this is the length of the data written into the stream.
+
+ data = await response.data.read(13)
+ assert data == b"Basic Content"
+
+ @pytest.mark.asyncio
+ async def test_request_basic_with_http(self, server):  # same as test_request_basic, via make_with_parameter_request()
+ request = self.make_with_parameter_request()
+ response = await request(url=server.url + "/basic", method="GET")
+ assert response.status == http_client.OK
+ assert response.headers["x-test-header"] == "value"
+
+ # Use 13 as this is the length of the data written into the stream.
+
+ data = await response.data.read(13)
+ assert data == b"Basic Content"
+
+ @pytest.mark.asyncio
+ async def test_request_with_timeout_success(self, server):
+ request = self.make_request()
+ response = await request(url=server.url + "/basic", method="GET", timeout=2)
+
+ assert response.status == http_client.OK
+ assert response.headers["x-test-header"] == "value"
+
+ data = await response.data.read(13)
+ assert data == b"Basic Content"
+
+ @pytest.mark.asyncio
+ async def test_request_with_timeout_failure(self, server):  # /wait sleeps 3 s, so a 1 s timeout must raise
+ request = self.make_request()
+
+ with pytest.raises(exceptions.TransportError):
+ await request(url=server.url + "/wait", method="GET", timeout=1)
+
+ @pytest.mark.asyncio
+ async def test_request_headers(self, server):  # the custom header must be echoed back by /basic
+ request = self.make_request()
+ response = await request(
+ url=server.url + "/basic",
+ method="GET",
+ headers={"x-test-header": "hello world"},
+ )
+
+ assert response.status == http_client.OK
+ assert response.headers["x-test-header"] == "hello world"
+
+ data = await response.data.read(13)
+ assert data == b"Basic Content"
+
+ @pytest.mark.asyncio
+ async def test_request_error(self, server):  # a 500 reply is returned as a response, not raised
+ request = self.make_request()
+
+ response = await request(url=server.url + "/server_error", method="GET")
+ assert response.status == http_client.INTERNAL_SERVER_ERROR
+ data = await response.data.read(5)  # 5 is the length of b"Error"
+ assert data == b"Error"
+
+ @pytest.mark.asyncio
+ async def test_connection_error(self):  # uses compliance.NXDOMAIN as the host; expects TransportError
+ request = self.make_request()
+
+ with pytest.raises(exceptions.TransportError):
+ await request(url="http://{}".format(compliance.NXDOMAIN), method="GET")
diff --git a/tests_async/transport/test_aiohttp_requests.py b/tests_async/transport/test_aiohttp_requests.py
new file mode 100644
index 0000000..10c31db
--- /dev/null
+++ b/tests_async/transport/test_aiohttp_requests.py
@@ -0,0 +1,245 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import aiohttp
+from aioresponses import aioresponses, core
+import mock
+import pytest
+from tests_async.transport import async_compliance
+
+import google.auth._credentials_async
+from google.auth.transport import _aiohttp_requests as aiohttp_requests
+import google.auth.transport._mtls_helper
+
+
+class TestCombinedResponse:  # tests for aiohttp_requests._CombinedResponse
+ @pytest.mark.asyncio
+ async def test__is_compressed(self):  # NOTE(review): marked asyncio but contains no await; the sibling test below is a plain def
+ response = core.CallbackResult(headers={"Content-Encoding": "gzip"})
+ combined_response = aiohttp_requests._CombinedResponse(response)
+ compressed = combined_response._is_compressed()
+ assert compressed
+
+ def test__is_compressed_not(self):  # a Content-Encoding of "not" must not count as compressed
+ response = core.CallbackResult(headers={"Content-Encoding": "not"})
+ combined_response = aiohttp_requests._CombinedResponse(response)
+ compressed = combined_response._is_compressed()
+ assert not compressed
+
+ @pytest.mark.asyncio
+ async def test_raw_content(self):
+
+ mock_response = mock.AsyncMock()
+ mock_response.content.read.return_value = mock.sentinel.read
+ combined_response = aiohttp_requests._CombinedResponse(response=mock_response)
+ raw_content = await combined_response.raw_content()  # first call: value comes from content.read()
+ assert raw_content == mock.sentinel.read
+
+ # Second call to validate the preconfigured path.
+ combined_response._raw_content = mock.sentinel.stored_raw
+ raw_content = await combined_response.raw_content()
+ assert raw_content == mock.sentinel.stored_raw
+
+ @pytest.mark.asyncio
+ async def test_content(self):  # with a mocked response, content() yields the value returned by content.read()
+ mock_response = mock.AsyncMock()
+ mock_response.content.read.return_value = mock.sentinel.read
+ combined_response = aiohttp_requests._CombinedResponse(response=mock_response)
+ content = await combined_response.content()
+ assert content == mock.sentinel.read
+
+ @mock.patch(
+ "google.auth.transport._aiohttp_requests.urllib3.response.MultiDecoder.decompress",
+ return_value="decompressed",
+ autospec=True,
+ )
+ @pytest.mark.asyncio
+ async def test_content_compressed(self, urllib3_mock):  # urllib3_mock is the patched MultiDecoder.decompress
+ rm = core.RequestMatch(
+ "url", headers={"Content-Encoding": "gzip"}, payload="compressed"
+ )
+ response = await rm.build_response(core.URL("url"))
+
+ combined_response = aiohttp_requests._CombinedResponse(response=response)
+ content = await combined_response.content()
+
+ urllib3_mock.assert_called_once()
+ assert content == "decompressed"
+
+
+class TestResponse:  # tests for aiohttp_requests._Response
+ def test_ctor(self):  # the wrapped object is stored on _response
+ response = aiohttp_requests._Response(mock.sentinel.response)
+ assert response._response == mock.sentinel.response
+
+ @pytest.mark.asyncio
+ async def test_headers_prop(self):
+ rm = core.RequestMatch("url", headers={"Content-Encoding": "header prop"})
+ mock_response = await rm.build_response(core.URL("url"))
+
+ response = aiohttp_requests._Response(mock_response)
+ assert response.headers["Content-Encoding"] == "header prop"
+
+ @pytest.mark.asyncio
+ async def test_status_prop(self):
+ rm = core.RequestMatch("url", status=123)
+ mock_response = await rm.build_response(core.URL("url"))
+ response = aiohttp_requests._Response(mock_response)
+ assert response.status == 123
+
+ @pytest.mark.asyncio
+ async def test_data_prop(self):  # data.read() must yield what the wrapped content.read() returns
+ mock_response = mock.AsyncMock()
+ mock_response.content.read.return_value = mock.sentinel.read
+ response = aiohttp_requests._Response(mock_response)
+ data = await response.data.read()
+ assert data == mock.sentinel.read
+
+
+class TestRequestResponse(async_compliance.RequestResponseTests):  # runs the shared compliance tests against aiohttp_requests.Request
+ def make_request(self):  # Request with no session argument
+ return aiohttp_requests.Request()
+
+ def make_with_parameter_request(self):  # Request given an autospecced aiohttp.ClientSession
+ http = mock.create_autospec(aiohttp.ClientSession, instance=True)
+ return aiohttp_requests.Request(http)
+
+ def test_timeout(self):  # NOTE(review): the call below is never awaited and nothing is asserted, while the compliance tests await the same call; confirm intent
+ http = mock.create_autospec(aiohttp.ClientSession, instance=True)
+ request = aiohttp_requests.Request(http)
+ request(url="http://example.com", method="GET", timeout=5)
+
+
+class CredentialsStub(google.auth._credentials_async.Credentials):  # minimal credentials used by the AuthorizedSession tests
+ def __init__(self, token="token"):
+ super(CredentialsStub, self).__init__()
+ self.token = token
+
+ def apply(self, headers, token=None):  # writes self.token into "authorization"; the token argument is ignored
+ headers["authorization"] = self.token
+
+ def refresh(self, request):  # appends "1" to the token; the request argument is unused
+ self.token += "1"
+
+
+class TestAuthorizedSession(object):  # tests for aiohttp_requests.AuthorizedSession with HTTP mocked by aioresponses
+ TEST_URL = "http://example.com/"
+ method = "GET"
+
+ def test_constructor(self):
+ authed_session = aiohttp_requests.AuthorizedSession(mock.sentinel.credentials)
+ assert authed_session.credentials == mock.sentinel.credentials
+
+ def test_constructor_with_auth_request(self):  # a supplied auth_request is stored on _auth_request
+ http = mock.create_autospec(aiohttp.ClientSession)
+ auth_request = aiohttp_requests.Request(http)
+
+ authed_session = aiohttp_requests.AuthorizedSession(
+ mock.sentinel.credentials, auth_request=auth_request
+ )
+
+ assert authed_session._auth_request == auth_request
+
+ @pytest.mark.asyncio
+ async def test_request(self):
+ with aioresponses() as mocked:
+ credentials = mock.Mock(wraps=CredentialsStub())  # Mock records calls and delegates to the stub
+
+ mocked.get(self.TEST_URL, status=200, body="test")
+ session = aiohttp_requests.AuthorizedSession(credentials)
+ resp = await session.request(
+ "GET",
+ "http://example.com/",
+ headers={"Keep-Alive": "timeout=5, max=1000", "fake": b"bytes"},
+ )
+
+ assert resp.status == 200
+ assert "test" == await resp.text()
+
+ await session.close()
+
+ @pytest.mark.asyncio
+ async def test_ctx(self):  # a JSON payload registered on the mock must round-trip through resp.json()
+ with aioresponses() as mocked:
+ credentials = mock.Mock(wraps=CredentialsStub())
+ mocked.get("http://test.example.com", payload=dict(foo="bar"))
+ session = aiohttp_requests.AuthorizedSession(credentials)
+ resp = await session.request("GET", "http://test.example.com")
+ data = await resp.json()
+
+ assert dict(foo="bar") == data
+
+ await session.close()
+
+ @pytest.mark.asyncio
+ async def test_http_headers(self):
+ with aioresponses() as mocked:
+ credentials = mock.Mock(wraps=CredentialsStub())
+ mocked.post(
+ "http://example.com",
+ payload=dict(),
+ headers=dict(connection="keep-alive"),
+ )
+
+ session = aiohttp_requests.AuthorizedSession(credentials)
+ resp = await session.request("POST", "http://example.com")
+
+ assert resp.headers["Connection"] == "keep-alive"  # registered as "connection", read back as "Connection"
+
+ await session.close()
+
+ @pytest.mark.asyncio
+ async def test_regexp_example(self):  # two replies registered for one URL are expected in registration order
+ with aioresponses() as mocked:
+ credentials = mock.Mock(wraps=CredentialsStub())
+ mocked.get("http://example.com", status=500)
+ mocked.get("http://example.com", status=200)
+
+ session1 = aiohttp_requests.AuthorizedSession(credentials)
+
+ resp1 = await session1.request("GET", "http://example.com")
+ session2 = aiohttp_requests.AuthorizedSession(credentials)
+ resp2 = await session2.request("GET", "http://example.com")
+
+ assert resp1.status == 500
+ assert resp2.status == 200
+
+ await session1.close()
+ await session2.close()
+
+ @pytest.mark.asyncio
+ async def test_request_no_refresh(self):  # a 200 reply must not trigger credentials.refresh
+ credentials = mock.Mock(wraps=CredentialsStub())
+ with aioresponses() as mocked:
+ mocked.get("http://example.com", status=200)
+ authed_session = aiohttp_requests.AuthorizedSession(credentials)
+ response = await authed_session.request("GET", "http://example.com")
+ assert response.status == 200
+ assert credentials.before_request.called
+ assert not credentials.refresh.called
+
+ await authed_session.close()
+
+ @pytest.mark.asyncio
+ async def test_request_refresh(self):  # a 401 followed by a 200: refresh must be called and the 200 returned
+ credentials = mock.Mock(wraps=CredentialsStub())
+ with aioresponses() as mocked:
+ mocked.get("http://example.com", status=401)
+ mocked.get("http://example.com", status=200)
+ authed_session = aiohttp_requests.AuthorizedSession(credentials)
+ response = await authed_session.request("GET", "http://example.com")
+ assert credentials.refresh.called
+ assert response.status == 200
+
+ await authed_session.close()