Add `google.oauth2.service_account.IDTokenCredentials`. (#234)
diff --git a/tests/oauth2/test__client.py b/tests/oauth2/test__client.py
index 6aeb3d1..3ec7fc6 100644
--- a/tests/oauth2/test__client.py
+++ b/tests/oauth2/test__client.py
@@ -14,6 +14,7 @@
import datetime
import json
+import os
import mock
import pytest
@@ -21,11 +22,22 @@
from six.moves import http_client
from six.moves import urllib
+from google.auth import _helpers
+from google.auth import crypt
from google.auth import exceptions
+from google.auth import jwt
from google.auth import transport
from google.oauth2 import _client
+# Test data (keys, certs, service-account JSON) lives in ../data relative
+# to this test module.
+DATA_DIR = os.path.join(
+    os.path.dirname(__file__), '..', 'data')
+
+# PEM-encoded RSA private key used to sign test JWTs, read once on import.
+with open(os.path.join(DATA_DIR, 'privatekey.pem'), 'rb') as pem_file:
+    PRIVATE_KEY_BYTES = pem_file.read()
+
+# Signer built from the key above; '1' is the key id handed to the signer.
+SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, '1')
+
+
def test__handle_error_response():
response_data = json.dumps({
'error': 'help',
@@ -129,6 +141,42 @@
_client.jwt_grant(request, 'http://example.com', 'assertion_value')
+def test_id_token_jwt_grant():
+    """id_token_jwt_grant returns the ID token, its expiry and the response."""
+    issued_at = _helpers.utcnow()
+    # The JWT 'exp' claim is stored in whole seconds, so the expiry decoded
+    # back out of the token loses the microseconds.
+    expected_expiry = issued_at.replace(microsecond=0)
+    id_token = jwt.encode(
+        SIGNER, {'exp': _helpers.datetime_to_secs(issued_at)}).decode('utf-8')
+    request = make_request({'id_token': id_token, 'extra': 'data'})
+
+    token, expiry, extra_data = _client.id_token_jwt_grant(
+        request, 'http://example.com', 'assertion_value')
+
+    # The request must use the JWT grant type and carry our assertion.
+    verify_request_params(request, {
+        'grant_type': _client._JWT_GRANT_TYPE,
+        'assertion': 'assertion_value',
+    })
+
+    # The token handed back is the raw ID token; other response fields are
+    # surfaced through extra_data.
+    assert token == id_token
+    assert expiry == expected_expiry
+    assert extra_data['extra'] == 'data'
+
+def test_id_token_jwt_grant_no_access_token():
+    """id_token_jwt_grant raises RefreshError when the response has no ID token.
+
+    The name is kept for stability, but the field missing from the mocked
+    response is ``id_token`` (the one this grant reads), not an access token.
+    """
+    request = make_request({
+        # No 'id_token' field in the response.
+        'expires_in': 500,
+        'extra': 'data'})
+
+    with pytest.raises(exceptions.RefreshError):
+        _client.id_token_jwt_grant(
+            request, 'http://example.com', 'assertion_value')
+
@mock.patch('google.auth._helpers.utcnow', return_value=datetime.datetime.min)
def test_refresh_grant(unused_utcnow):
request = make_request({
diff --git a/tests/oauth2/test_service_account.py b/tests/oauth2/test_service_account.py
index 9c235db..54ac0f5 100644
--- a/tests/oauth2/test_service_account.py
+++ b/tests/oauth2/test_service_account.py
@@ -216,3 +216,126 @@
# Credentials should now be valid.
assert credentials.valid
+
+
+class TestIDTokenCredentials(object):
+    """Tests for service_account.IDTokenCredentials."""
+
+    SERVICE_ACCOUNT_EMAIL = 'service-account@example.com'
+    TOKEN_URI = 'https://example.com/oauth2/token'
+    TARGET_AUDIENCE = 'https://example.com'
+
+    @classmethod
+    def make_credentials(cls):
+        """Build credentials from the shared test signer and class constants."""
+        return service_account.IDTokenCredentials(
+            SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI,
+            cls.TARGET_AUDIENCE)
+
+    def test_from_service_account_info(self):
+        credentials = (
+            service_account.IDTokenCredentials.from_service_account_info(
+                SERVICE_ACCOUNT_INFO,
+                target_audience=self.TARGET_AUDIENCE))
+
+        assert (credentials._signer.key_id ==
+                SERVICE_ACCOUNT_INFO['private_key_id'])
+        assert (credentials.service_account_email ==
+                SERVICE_ACCOUNT_INFO['client_email'])
+        assert credentials._token_uri == SERVICE_ACCOUNT_INFO['token_uri']
+        assert credentials._target_audience == self.TARGET_AUDIENCE
+
+    def test_from_service_account_file(self):
+        # SERVICE_ACCOUNT_INFO is only read here, so no copy is needed.
+        info = SERVICE_ACCOUNT_INFO
+
+        credentials = (
+            service_account.IDTokenCredentials.from_service_account_file(
+                SERVICE_ACCOUNT_JSON_FILE,
+                target_audience=self.TARGET_AUDIENCE))
+
+        assert credentials.service_account_email == info['client_email']
+        assert credentials._signer.key_id == info['private_key_id']
+        assert credentials._token_uri == info['token_uri']
+        assert credentials._target_audience == self.TARGET_AUDIENCE
+
+    def test_default_state(self):
+        credentials = self.make_credentials()
+        assert not credentials.valid
+        # Expiration hasn't been set yet
+        assert not credentials.expired
+
+    def test_sign_bytes(self):
+        credentials = self.make_credentials()
+        to_sign = b'123'
+        signature = credentials.sign_bytes(to_sign)
+        assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES)
+
+    def test_signer(self):
+        credentials = self.make_credentials()
+        assert isinstance(credentials.signer, crypt.Signer)
+
+    def test_signer_email(self):
+        credentials = self.make_credentials()
+        assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL
+
+    def test_with_target_audience(self):
+        credentials = self.make_credentials()
+        new_credentials = credentials.with_target_audience(
+            'https://new.example.com')
+
+        assert new_credentials._target_audience == 'https://new.example.com'
+        # with_target_audience returns a new object: the original keeps its
+        # audience, and the copy keeps the remaining settings.
+        assert new_credentials is not credentials
+        assert credentials._target_audience == self.TARGET_AUDIENCE
+        assert (new_credentials.service_account_email ==
+                self.SERVICE_ACCOUNT_EMAIL)
+        assert new_credentials._token_uri == self.TOKEN_URI
+
+    def test__make_authorization_grant_assertion(self):
+        credentials = self.make_credentials()
+        token = credentials._make_authorization_grant_assertion()
+        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
+        assert payload['iss'] == self.SERVICE_ACCOUNT_EMAIL
+        assert payload['aud'] == self.TOKEN_URI
+        assert payload['target_audience'] == self.TARGET_AUDIENCE
+
+    @mock.patch('google.oauth2._client.id_token_jwt_grant', autospec=True)
+    def test_refresh_success(self, id_token_jwt_grant):
+        credentials = self.make_credentials()
+        token = 'token'
+        id_token_jwt_grant.return_value = (
+            token,
+            _helpers.utcnow() + datetime.timedelta(seconds=500),
+            {})
+        request = mock.create_autospec(transport.Request, instance=True)
+
+        # Refresh credentials
+        credentials.refresh(request)
+
+        # Check jwt grant call.
+        assert id_token_jwt_grant.called
+
+        called_request, token_uri, assertion = id_token_jwt_grant.call_args[0]
+        assert called_request == request
+        assert token_uri == credentials._token_uri
+        assert jwt.decode(assertion, PUBLIC_CERT_BYTES)
+        # No further assertion done on the token, as there are separate tests
+        # for checking the authorization grant assertion.
+
+        # Check that the credentials have the token.
+        assert credentials.token == token
+
+        # Check that the credentials are valid (have a token and are not
+        # expired)
+        assert credentials.valid
+
+    @mock.patch('google.oauth2._client.id_token_jwt_grant', autospec=True)
+    def test_before_request_refreshes(self, id_token_jwt_grant):
+        credentials = self.make_credentials()
+        token = 'token'
+        # Mirror the (token, expiry, response_data) shape used in
+        # test_refresh_success; the third element is a dict, not None.
+        id_token_jwt_grant.return_value = (
+            token, _helpers.utcnow() + datetime.timedelta(seconds=500), {})
+        request = mock.create_autospec(transport.Request, instance=True)
+
+        # Credentials should start as invalid
+        assert not credentials.valid
+
+        # before_request should cause a refresh
+        credentials.before_request(
+            request, 'GET', 'http://example.com?a=1#3', {})
+
+        # The refresh endpoint should've been called.
+        assert id_token_jwt_grant.called
+
+        # Credentials should now be valid.
+        assert credentials.valid