feat: add access token credentials (#476)
feat: add access token credentials
diff --git a/google/auth/_cloud_sdk.py b/google/auth/_cloud_sdk.py
index 61ffd4f..e772fe9 100644
--- a/google/auth/_cloud_sdk.py
+++ b/google/auth/_cloud_sdk.py
@@ -18,8 +18,10 @@
import os
import subprocess
+import six
+
from google.auth import environment_vars
-import google.oauth2.credentials
+from google.auth import exceptions
# The ~/.config subdirectory containing gcloud credentials.
@@ -34,6 +36,8 @@
_CLOUD_SDK_WINDOWS_COMMAND = "gcloud.cmd"
# The command to get the Cloud SDK configuration
_CLOUD_SDK_CONFIG_COMMAND = ("config", "config-helper", "--format", "json")
+# The command to get google user access token
+_CLOUD_SDK_USER_ACCESS_TOKEN_COMMAND = ("auth", "print-access-token")
# Cloud SDK's application-default client ID
CLOUD_SDK_CLIENT_ID = (
"764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com"
@@ -80,21 +84,6 @@
return os.path.join(config_path, _CREDENTIALS_FILENAME)
-def load_authorized_user_credentials(info):
- """Loads an authorized user credential.
-
- Args:
- info (Mapping[str, str]): The loaded file's data.
-
- Returns:
- google.oauth2.credentials.Credentials: The constructed credentials.
-
- Raises:
- ValueError: if the info is in the wrong format or missing data.
- """
- return google.oauth2.credentials.Credentials.from_authorized_user_info(info)
-
-
def get_project_id():
"""Gets the project ID from the Cloud SDK.
@@ -122,3 +111,42 @@
return configuration["configuration"]["properties"]["core"]["project"]
except KeyError:
return None
+
+
+def get_auth_access_token(account=None):
+ """Load user access token with the ``gcloud auth print-access-token`` command.
+
+ Args:
+ account (Optional[str]): Account to get the access token for. If not
+ specified, the current active account will be used.
+
+ Returns:
+ str: The user access token.
+
+ Raises:
+ google.auth.exceptions.UserAccessTokenError: if failed to get access
+ token from gcloud.
+ """
+ if os.name == "nt":
+ command = _CLOUD_SDK_WINDOWS_COMMAND
+ else:
+ command = _CLOUD_SDK_POSIX_COMMAND
+
+ try:
+ if account:
+ command = (
+ (command,)
+ + _CLOUD_SDK_USER_ACCESS_TOKEN_COMMAND
+ + ("--account=" + account,)
+ )
+ else:
+ command = (command,) + _CLOUD_SDK_USER_ACCESS_TOKEN_COMMAND
+
+ access_token = subprocess.check_output(command, stderr=subprocess.STDOUT)
+ # remove the trailing "\n"
+ return access_token.decode("utf-8").strip()
+ except (subprocess.CalledProcessError, OSError, IOError) as caught_exc:
+ new_exc = exceptions.UserAccessTokenError(
+ "Failed to obtain access token", caught_exc
+ )
+ six.raise_from(new_exc, caught_exc)
diff --git a/google/auth/_default.py b/google/auth/_default.py
index 32e81ba..d7110a1 100644
--- a/google/auth/_default.py
+++ b/google/auth/_default.py
@@ -106,10 +106,10 @@
credential_type = info.get("type")
if credential_type == _AUTHORIZED_USER_TYPE:
- from google.auth import _cloud_sdk
+ from google.oauth2 import credentials
try:
- credentials = _cloud_sdk.load_authorized_user_credentials(info)
+ credentials = credentials.Credentials.from_authorized_user_info(info)
except ValueError as caught_exc:
msg = "Failed to load authorized user credentials from {}".format(filename)
new_exc = exceptions.DefaultCredentialsError(msg, caught_exc)
diff --git a/google/auth/exceptions.py b/google/auth/exceptions.py
index e034c55..4f66dc2 100644
--- a/google/auth/exceptions.py
+++ b/google/auth/exceptions.py
@@ -28,5 +28,9 @@
failed."""
+class UserAccessTokenError(GoogleAuthError):
+ """Used to indicate ``gcloud auth print-access-token`` command failed."""
+
+
class DefaultCredentialsError(GoogleAuthError):
"""Used to indicate that acquiring default credentials failed."""
diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py
index 1adcbf6..baf3cf7 100644
--- a/google/oauth2/credentials.py
+++ b/google/oauth2/credentials.py
@@ -36,6 +36,7 @@
import six
+from google.auth import _cloud_sdk
from google.auth import _helpers
from google.auth import credentials
from google.auth import exceptions
@@ -292,3 +293,50 @@
prep = {k: v for k, v in prep.items() if k not in strip}
return json.dumps(prep)
+
+
+class UserAccessTokenCredentials(credentials.Credentials):
+ """Access token credentials for user account.
+
+ Obtain the access token for a given user account or the current active
+ user account with the ``gcloud auth print-access-token`` command.
+
+ Args:
+ account (Optional[str]): Account to get the access token for. If not
+ specified, the current active account will be used.
+ """
+
+ def __init__(self, account=None):
+ super(UserAccessTokenCredentials, self).__init__()
+ self._account = account
+
+ def with_account(self, account):
+ """Create a new instance with the given account.
+
+ Args:
+ account (str): Account to get the access token for.
+
+ Returns:
+ google.oauth2.credentials.UserAccessTokenCredentials: The created
+ credentials with the given account.
+ """
+ return self.__class__(account=account)
+
+ def refresh(self, request):
+ """Refreshes the access token.
+
+ Args:
+ request (google.auth.transport.Request): This argument is required
+ by the base class interface but not used in this implementation,
+ so just set it to `None`.
+
+ Raises:
+ google.auth.exceptions.UserAccessTokenError: If the access token
+ refresh failed.
+ """
+ self.token = _cloud_sdk.get_auth_access_token(self._account)
+
+ @_helpers.copy_docstring(credentials.Credentials)
+ def before_request(self, request, method, url, headers):
+ self.refresh(request)
+ self.apply(headers)
diff --git a/system_tests/test_mtls_http.py b/system_tests/test_mtls_http.py
index e7ea0b2..1fd8031 100644
--- a/system_tests/test_mtls_http.py
+++ b/system_tests/test_mtls_http.py
@@ -14,6 +14,7 @@
import json
from os import path
+import time
import google.auth
import google.auth.credentials
@@ -42,6 +43,9 @@
# supposed to be created.
assert authed_session.is_mtls == check_context_aware_metadata()
+ # Sleep 1 second to avoid 503 error.
+ time.sleep(1)
+
if authed_session.is_mtls:
response = authed_session.get(MTLS_ENDPOINT.format(project_id))
else:
@@ -63,6 +67,9 @@
# supposed to be created.
assert is_mtls == check_context_aware_metadata()
+ # Sleep 1 second to avoid 503 error.
+ time.sleep(1)
+
if is_mtls:
response = authed_http.request("GET", MTLS_ENDPOINT.format(project_id))
else:
diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py
index bdb63e9..76aa463 100644
--- a/tests/oauth2/test_credentials.py
+++ b/tests/oauth2/test_credentials.py
@@ -421,3 +421,31 @@
) as f:
credentials = pickle.load(f)
assert credentials.quota_project_id is None
+
+
+class TestUserAccessTokenCredentials(object):
+ def test_instance(self):
+ cred = credentials.UserAccessTokenCredentials()
+ assert cred._account is None
+
+ cred = cred.with_account("account")
+ assert cred._account == "account"
+
+ @mock.patch("google.auth._cloud_sdk.get_auth_access_token", autospec=True)
+ def test_refresh(self, get_auth_access_token):
+ get_auth_access_token.return_value = "access_token"
+ cred = credentials.UserAccessTokenCredentials()
+ cred.refresh(None)
+ assert cred.token == "access_token"
+
+ @mock.patch(
+ "google.oauth2.credentials.UserAccessTokenCredentials.apply", autospec=True
+ )
+ @mock.patch(
+ "google.oauth2.credentials.UserAccessTokenCredentials.refresh", autospec=True
+ )
+ def test_before_request(self, refresh, apply):
+ cred = credentials.UserAccessTokenCredentials()
+ cred.before_request(mock.Mock(), "GET", "https://example.com", {})
+ refresh.assert_called()
+ apply.assert_called()
diff --git a/tests/test__cloud_sdk.py b/tests/test__cloud_sdk.py
index 049ed99..3377604 100644
--- a/tests/test__cloud_sdk.py
+++ b/tests/test__cloud_sdk.py
@@ -22,7 +22,7 @@
from google.auth import _cloud_sdk
from google.auth import environment_vars
-import google.oauth2.credentials
+from google.auth import exceptions
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
@@ -137,23 +137,33 @@
assert os.path.split(config_path) == ("G:/\\", _cloud_sdk._CONFIG_DIRECTORY)
-def test_load_authorized_user_credentials():
- credentials = _cloud_sdk.load_authorized_user_credentials(AUTHORIZED_USER_FILE_DATA)
+@mock.patch("os.name", new="nt")
+@mock.patch("subprocess.check_output", autospec=True)
+def test_get_auth_access_token_windows(check_output):
+ check_output.return_value = b"access_token\n"
- assert isinstance(credentials, google.oauth2.credentials.Credentials)
-
- assert credentials.token is None
- assert credentials._refresh_token == AUTHORIZED_USER_FILE_DATA["refresh_token"]
- assert credentials._client_id == AUTHORIZED_USER_FILE_DATA["client_id"]
- assert credentials._client_secret == AUTHORIZED_USER_FILE_DATA["client_secret"]
- assert (
- credentials._token_uri
- == google.oauth2.credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
+ token = _cloud_sdk.get_auth_access_token()
+ assert token == "access_token"
+ check_output.assert_called_with(
+ ("gcloud.cmd", "auth", "print-access-token"), stderr=subprocess.STDOUT
)
-def test_load_authorized_user_credentials_bad_format():
- with pytest.raises(ValueError) as excinfo:
- _cloud_sdk.load_authorized_user_credentials({})
+@mock.patch("subprocess.check_output", autospec=True)
+def test_get_auth_access_token_with_account(check_output):
+ check_output.return_value = b"access_token\n"
- assert excinfo.match(r"missing fields")
+ token = _cloud_sdk.get_auth_access_token(account="account")
+ assert token == "access_token"
+ check_output.assert_called_with(
+ ("gcloud", "auth", "print-access-token", "--account=account"),
+ stderr=subprocess.STDOUT,
+ )
+
+
+@mock.patch("subprocess.check_output", autospec=True)
+def test_get_auth_access_token_with_exception(check_output):
+ check_output.side_effect = OSError()
+
+ with pytest.raises(exceptions.UserAccessTokenError):
+ _cloud_sdk.get_auth_access_token(account="account")