feat: add reauth feature to user credentials (#727)
* feat: add reauth support to oauth2 credentials
* update
diff --git a/google/auth/exceptions.py b/google/auth/exceptions.py
index b6f686b..57f181e 100644
--- a/google/auth/exceptions.py
+++ b/google/auth/exceptions.py
@@ -48,3 +48,12 @@
class OAuthError(GoogleAuthError):
"""Used to indicate an error occurred during an OAuth related HTTP
request."""
+
+
+class ReauthFailError(RefreshError):
+ """An exception for when reauth failed."""
+
+ def __init__(self, message=None):
+ super(ReauthFailError, self).__init__(
+ "Reauthentication failed. {0}".format(message)
+ )
diff --git a/google/oauth2/_client.py b/google/oauth2/_client.py
index 4487163..2f4e847 100644
--- a/google/oauth2/_client.py
+++ b/google/oauth2/_client.py
@@ -35,29 +35,29 @@
from google.auth import jwt
_URLENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
+_JSON_CONTENT_TYPE = "application/json"
_JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
_REFRESH_GRANT_TYPE = "refresh_token"
-def _handle_error_response(response_body):
- """"Translates an error response into an exception.
+def _handle_error_response(response_data):
+ """Translates an error response into an exception.
Args:
- response_body (str): The decoded response data.
+ response_data (Mapping): The decoded response data.
Raises:
- google.auth.exceptions.RefreshError
+ google.auth.exceptions.RefreshError: The errors contained in response_data.
"""
try:
- error_data = json.loads(response_body)
error_details = "{}: {}".format(
- error_data["error"], error_data.get("error_description")
+ response_data["error"], response_data.get("error_description")
)
# If no details could be extracted, use the response data.
except (KeyError, ValueError):
- error_details = response_body
+ error_details = json.dumps(response_data)
- raise exceptions.RefreshError(error_details, response_body)
+ raise exceptions.RefreshError(error_details, response_data)
def _parse_expiry(response_data):
@@ -78,8 +78,11 @@
return None
-def _token_endpoint_request(request, token_uri, body):
+def _token_endpoint_request_no_throw(
+ request, token_uri, body, access_token=None, use_json=False
+):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
+ This function doesn't throw on response errors.
Args:
request (google.auth.transport.Request): A callable used to make
@@ -87,16 +90,23 @@
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
+ access_token (Optional(str)): The access token needed to make the request.
+ use_json (Optional(bool)): Use urlencoded format or json format for the
+ content type. The default value is False.
Returns:
- Mapping[str, str]: The JSON-decoded response data.
-
- Raises:
- google.auth.exceptions.RefreshError: If the token endpoint returned
- an error.
+ Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
+ successful, and a mapping for the JSON-decoded response data.
"""
- body = urllib.parse.urlencode(body).encode("utf-8")
- headers = {"content-type": _URLENCODED_CONTENT_TYPE}
+ if use_json:
+ headers = {"Content-Type": _JSON_CONTENT_TYPE}
+ body = json.dumps(body).encode("utf-8")
+ else:
+ headers = {"Content-Type": _URLENCODED_CONTENT_TYPE}
+ body = urllib.parse.urlencode(body).encode("utf-8")
+
+ if access_token:
+ headers["Authorization"] = "Bearer {}".format(access_token)
retry = 0
# retry to fetch token for maximum of two times if any internal failure
@@ -121,8 +131,38 @@
):
retry += 1
continue
- _handle_error_response(response_body)
+ return response.status == http_client.OK, response_data
+ return response.status == http_client.OK, response_data
+
+
+def _token_endpoint_request(
+ request, token_uri, body, access_token=None, use_json=False
+):
+ """Makes a request to the OAuth 2.0 authorization server's token endpoint.
+
+ Args:
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ token_uri (str): The OAuth 2.0 authorizations server's token endpoint
+ URI.
+ body (Mapping[str, str]): The parameters to send in the request body.
+ access_token (Optional(str)): The access token needed to make the request.
+ use_json (Optional(bool)): Use urlencoded format or json format for the
+ content type. The default value is False.
+
+ Returns:
+ Mapping[str, str]: The JSON-decoded response data.
+
+ Raises:
+ google.auth.exceptions.RefreshError: If the token endpoint returned
+ an error.
+ """
+ response_status_ok, response_data = _token_endpoint_request_no_throw(
+ request, token_uri, body, access_token=access_token, use_json=use_json
+ )
+ if not response_status_ok:
+ _handle_error_response(response_data)
return response_data
@@ -204,8 +244,43 @@
return id_token, expiry, response_data
+def _handle_refresh_grant_response(response_data, refresh_token):
+ """Extract tokens from refresh grant response.
+
+ Args:
+ response_data (Mapping[str, str]): Refresh grant response data.
+ refresh_token (str): Current refresh token.
+
+ Returns:
+ Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access token,
+ refresh token, expiration, and additional data returned by the token
+ endpoint. If response_data doesn't have refresh token, then the current
+ refresh token will be returned.
+
+ Raises:
+ google.auth.exceptions.RefreshError: If the token endpoint returned
+ an error.
+ """
+ try:
+ access_token = response_data["access_token"]
+ except KeyError as caught_exc:
+ new_exc = exceptions.RefreshError("No access token in response.", response_data)
+ six.raise_from(new_exc, caught_exc)
+
+ refresh_token = response_data.get("refresh_token", refresh_token)
+ expiry = _parse_expiry(response_data)
+
+ return access_token, refresh_token, expiry, response_data
+
+
def refresh_grant(
- request, token_uri, refresh_token, client_id, client_secret, scopes=None
+ request,
+ token_uri,
+ refresh_token,
+ client_id,
+ client_secret,
+ scopes=None,
+ rapt_token=None,
):
"""Implements the OAuth 2.0 refresh token grant.
@@ -224,10 +299,11 @@
scopes must be authorized for the refresh token. Useful if refresh
token has a wild card scope (e.g.
'https://www.googleapis.com/auth/any-api').
+ rapt_token (Optional(str)): The reauth Proof Token.
Returns:
- Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The
- access token, new refresh token, expiration, and additional data
+ Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access
+ token, new or current refresh token, expiration, and additional data
returned by the token endpoint.
Raises:
@@ -244,16 +320,8 @@
}
if scopes:
body["scope"] = " ".join(scopes)
+ if rapt_token:
+ body["rapt"] = rapt_token
response_data = _token_endpoint_request(request, token_uri, body)
-
- try:
- access_token = response_data["access_token"]
- except KeyError as caught_exc:
- new_exc = exceptions.RefreshError("No access token in response.", response_data)
- six.raise_from(new_exc, caught_exc)
-
- refresh_token = response_data.get("refresh_token", refresh_token)
- expiry = _parse_expiry(response_data)
-
- return access_token, refresh_token, expiry, response_data
+ return _handle_refresh_grant_response(response_data, refresh_token)
diff --git a/google/oauth2/challenges.py b/google/oauth2/challenges.py
new file mode 100644
index 0000000..d0b070e
--- /dev/null
+++ b/google/oauth2/challenges.py
@@ -0,0 +1,157 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Challenges for reauthentication.
+"""
+
+import abc
+import base64
+import getpass
+import sys
+
+import six
+
+from google.auth import _helpers
+from google.auth import exceptions
+
+
+REAUTH_ORIGIN = "https://accounts.google.com"
+
+
+def get_user_password(text):
+ """Get password from user.
+
+ Override this function with a different logic if you are using this library
+ outside a CLI.
+
+ Args:
+ text (str): message for the password prompt.
+
+ Returns:
+ str: password string.
+ """
+ return getpass.getpass(text)
+
+
+@six.add_metaclass(abc.ABCMeta)
+class ReauthChallenge(object):
+ """Base class for reauth challenges."""
+
+ @property
+ @abc.abstractmethod
+ def name(self): # pragma: NO COVER
+ """Returns the name of the challenge."""
+ raise NotImplementedError("name property must be implemented")
+
+ @property
+ @abc.abstractmethod
+ def is_locally_eligible(self): # pragma: NO COVER
+ """Returns true if a challenge is supported locally on this machine."""
+ raise NotImplementedError("is_locally_eligible property must be implemented")
+
+ @abc.abstractmethod
+ def obtain_challenge_input(self, metadata): # pragma: NO COVER
+ """Performs logic required to obtain credentials and returns it.
+
+ Args:
+ metadata (Mapping): challenge metadata returned in the 'challenges' field in
+ the initial reauth request. Includes the 'challengeType' field
+ and other challenge-specific fields.
+
+ Returns:
+ response that will be send to the reauth service as the content of
+ the 'proposalResponse' field in the request body. Usually a dict
+ with the keys specific to the challenge. For example,
+ ``{'credential': password}`` for password challenge.
+ """
+ raise NotImplementedError("obtain_challenge_input method must be implemented")
+
+
+class PasswordChallenge(ReauthChallenge):
+ """Challenge that asks for user's password."""
+
+ @property
+ def name(self):
+ return "PASSWORD"
+
+ @property
+ def is_locally_eligible(self):
+ return True
+
+ @_helpers.copy_docstring(ReauthChallenge)
+ def obtain_challenge_input(self, unused_metadata):
+ passwd = get_user_password("Please enter your password:")
+ if not passwd:
+ passwd = " " # avoid the server crashing in case of no password :D
+ return {"credential": passwd}
+
+
+class SecurityKeyChallenge(ReauthChallenge):
+ """Challenge that asks for user's security key touch."""
+
+ @property
+ def name(self):
+ return "SECURITY_KEY"
+
+ @property
+ def is_locally_eligible(self):
+ return True
+
+ @_helpers.copy_docstring(ReauthChallenge)
+ def obtain_challenge_input(self, metadata):
+ try:
+ import pyu2f.convenience.authenticator
+ import pyu2f.errors
+ import pyu2f.model
+ except ImportError:
+ raise exceptions.ReauthFailError(
+ "pyu2f dependency is required to use Security key reauth feature. "
+ "It can be installed via `pip install pyu2f` or `pip install google-auth[reauth]`."
+ )
+ sk = metadata["securityKey"]
+ challenges = sk["challenges"]
+ app_id = sk["applicationId"]
+
+ challenge_data = []
+ for c in challenges:
+ kh = c["keyHandle"].encode("ascii")
+ key = pyu2f.model.RegisteredKey(bytearray(base64.urlsafe_b64decode(kh)))
+ challenge = c["challenge"].encode("ascii")
+ challenge = base64.urlsafe_b64decode(challenge)
+ challenge_data.append({"key": key, "challenge": challenge})
+
+ try:
+ api = pyu2f.convenience.authenticator.CreateCompositeAuthenticator(
+ REAUTH_ORIGIN
+ )
+ response = api.Authenticate(
+ app_id, challenge_data, print_callback=sys.stderr.write
+ )
+ return {"securityKey": response}
+ except pyu2f.errors.U2FError as e:
+ if e.code == pyu2f.errors.U2FError.DEVICE_INELIGIBLE:
+ sys.stderr.write("Ineligible security key.\n")
+ elif e.code == pyu2f.errors.U2FError.TIMEOUT:
+ sys.stderr.write("Timed out while waiting for security key touch.\n")
+ else:
+ raise e
+ except pyu2f.errors.NoDeviceFoundError:
+ sys.stderr.write("No security key found.\n")
+ return None
+
+
+AVAILABLE_CHALLENGES = {
+ challenge.name: challenge
+ for challenge in [SecurityKeyChallenge(), PasswordChallenge()]
+}
diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py
index 464cc48..dcfa5f9 100644
--- a/google/oauth2/credentials.py
+++ b/google/oauth2/credentials.py
@@ -41,7 +41,7 @@
from google.auth import _helpers
from google.auth import credentials
from google.auth import exceptions
-from google.oauth2 import _client
+from google.oauth2 import reauth
# The Google OAuth 2.0 token endpoint. Used for authorized user credentials.
@@ -55,6 +55,10 @@
quota project, use :meth:`with_quota_project` or ::
credentials = credentials.with_quota_project('myproject-123)
+
+ If reauth is enabled, `pyu2f` dependency has to be installed in order to use security
+ key reauth feature. Dependency can be installed via `pip install pyu2f` or `pip install
+ google-auth[reauth]`.
"""
def __init__(
@@ -69,6 +73,7 @@
default_scopes=None,
quota_project_id=None,
expiry=None,
+ rapt_token=None,
):
"""
Args:
@@ -97,6 +102,7 @@
quota_project_id (Optional[str]): The project ID used for quota and billing.
This project may be different from the project used to
create the credentials.
+ rapt_token (Optional[str]): The reauth Proof Token.
"""
super(Credentials, self).__init__()
self.token = token
@@ -109,6 +115,7 @@
self._client_id = client_id
self._client_secret = client_secret
self._quota_project_id = quota_project_id
+ self._rapt_token = rapt_token
def __getstate__(self):
"""A __getstate__ method must exist for the __setstate__ to be called
@@ -130,6 +137,7 @@
self._client_id = d.get("_client_id")
self._client_secret = d.get("_client_secret")
self._quota_project_id = d.get("_quota_project_id")
+ self._rapt_token = d.get("_rapt_token")
@property
def refresh_token(self):
@@ -174,6 +182,11 @@
the initial token is requested and can not be changed."""
return False
+ @property
+ def rapt_token(self):
+ """Optional[str]: The reauth Proof Token."""
+ return self._rapt_token
+
@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):
@@ -187,6 +200,7 @@
scopes=self.scopes,
default_scopes=self.default_scopes,
quota_project_id=quota_project_id,
+ rapt_token=self.rapt_token,
)
@_helpers.copy_docstring(credentials.Credentials)
@@ -205,23 +219,31 @@
scopes = self._scopes if self._scopes is not None else self._default_scopes
- access_token, refresh_token, expiry, grant_response = _client.refresh_grant(
+ (
+ access_token,
+ refresh_token,
+ expiry,
+ grant_response,
+ rapt_token,
+ ) = reauth.refresh_grant(
request,
self._token_uri,
self._refresh_token,
self._client_id,
self._client_secret,
- scopes,
+ scopes=scopes,
+ rapt_token=self._rapt_token,
)
self.token = access_token
self.expiry = expiry
self._refresh_token = refresh_token
self._id_token = grant_response.get("id_token")
+ self._rapt_token = rapt_token
- if scopes and "scopes" in grant_response:
+ if scopes and "scope" in grant_response:
requested_scopes = frozenset(scopes)
- granted_scopes = frozenset(grant_response["scopes"].split())
+ granted_scopes = frozenset(grant_response["scope"].split())
scopes_requested_but_not_granted = requested_scopes - granted_scopes
if scopes_requested_but_not_granted:
raise exceptions.RefreshError(
@@ -323,6 +345,7 @@
"client_id": self.client_id,
"client_secret": self.client_secret,
"scopes": self.scopes,
+ "rapt_token": self.rapt_token,
}
if self.expiry: # flatten expiry timestamp
prep["expiry"] = self.expiry.isoformat() + "Z"
diff --git a/google/oauth2/reauth.py b/google/oauth2/reauth.py
new file mode 100644
index 0000000..d539d7c
--- /dev/null
+++ b/google/oauth2/reauth.py
@@ -0,0 +1,341 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A module that provides functions for handling rapt authentication.
+
+Reauth is a process of obtaining additional authentication (such as password,
+security token, etc.) while refreshing OAuth 2.0 credentials for a user.
+
+Credentials that use the Reauth flow must have the reauth scope,
+``https://www.googleapis.com/auth/accounts.reauth``.
+
+This module provides a high-level function for executing the Reauth process,
+:func:`refresh_grant`, and lower-level helpers for doing the individual
+steps of the reauth process.
+
+Those steps are:
+
+1. Obtaining a list of challenges from the reauth server.
+2. Running through each challenge and sending the result back to the reauth
+ server.
+3. Refreshing the access token using the returned rapt token.
+"""
+
+import sys
+
+from six.moves import range
+
+from google.auth import exceptions
+from google.oauth2 import _client
+from google.oauth2 import challenges
+
+
+_REAUTH_SCOPE = "https://www.googleapis.com/auth/accounts.reauth"
+_REAUTH_API = "https://reauth.googleapis.com/v2/sessions"
+
+_REAUTH_NEEDED_ERROR = "invalid_grant"
+_REAUTH_NEEDED_ERROR_INVALID_RAPT = "invalid_rapt"
+_REAUTH_NEEDED_ERROR_RAPT_REQUIRED = "rapt_required"
+
+_AUTHENTICATED = "AUTHENTICATED"
+_CHALLENGE_REQUIRED = "CHALLENGE_REQUIRED"
+_CHALLENGE_PENDING = "CHALLENGE_PENDING"
+
+
+# Override this global variable to set custom max number of rounds of reauth
+# challenges should be run.
+RUN_CHALLENGE_RETRY_LIMIT = 5
+
+
+def is_interactive():
+ """Check if we are in an interractive environment.
+
+ Override this function with a different logic if you are using this library
+ outside a CLI.
+
+ If the rapt token needs refreshing, the user needs to answer the challenges.
+ If the user is not in an interractive environment, the challenges can not
+ be answered and we just wait for timeout for no reason.
+
+ Returns:
+ bool: True if is interactive environment, False otherwise.
+ """
+
+ return sys.stdin.isatty()
+
+
+def _get_challenges(
+ request, supported_challenge_types, access_token, requested_scopes=None
+):
+ """Does initial request to reauth API to get the challenges.
+
+ Args:
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ supported_challenge_types (Sequence[str]): list of challenge names
+ supported by the manager.
+ access_token (str): Access token with reauth scopes.
+ requested_scopes (Optional(Sequence[str])): Authorized scopes for the credentials.
+
+ Returns:
+ dict: The response from the reauth API.
+ """
+ body = {"supportedChallengeTypes": supported_challenge_types}
+ if requested_scopes:
+ body["oauthScopesForDomainPolicyLookup"] = requested_scopes
+
+ return _client._token_endpoint_request(
+ request, _REAUTH_API + ":start", body, access_token=access_token, use_json=True
+ )
+
+
+def _send_challenge_result(
+ request, session_id, challenge_id, client_input, access_token
+):
+ """Attempt to refresh access token by sending next challenge result.
+
+ Args:
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ session_id (str): session id returned by the initial reauth call.
+ challenge_id (str): challenge id returned by the initial reauth call.
+ client_input: dict with a challenge-specific client input. For example:
+ ``{'credential': password}`` for password challenge.
+ access_token (str): Access token with reauth scopes.
+
+ Returns:
+ dict: The response from the reauth API.
+ """
+ body = {
+ "sessionId": session_id,
+ "challengeId": challenge_id,
+ "action": "RESPOND",
+ "proposalResponse": client_input,
+ }
+
+ return _client._token_endpoint_request(
+ request,
+ _REAUTH_API + "/{}:continue".format(session_id),
+ body,
+ access_token=access_token,
+ use_json=True,
+ )
+
+
+def _run_next_challenge(msg, request, access_token):
+ """Get the next challenge from msg and run it.
+
+ Args:
+ msg (dict): Reauth API response body (either from the initial request to
+ https://reauth.googleapis.com/v2/sessions:start or from sending the
+ previous challenge response to
+ https://reauth.googleapis.com/v2/sessions/id:continue)
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ access_token (str): reauth access token
+
+ Returns:
+ dict: The response from the reauth API.
+
+ Raises:
+ google.auth.exceptions.ReauthError: if reauth failed.
+ """
+ for challenge in msg["challenges"]:
+ if challenge["status"] != "READY":
+ # Skip non-activated challenges.
+ continue
+ c = challenges.AVAILABLE_CHALLENGES.get(challenge["challengeType"], None)
+ if not c:
+ raise exceptions.ReauthFailError(
+ "Unsupported challenge type {0}. Supported types: {1}".format(
+ challenge["challengeType"],
+ ",".join(list(challenges.AVAILABLE_CHALLENGES.keys())),
+ )
+ )
+ if not c.is_locally_eligible:
+ raise exceptions.ReauthFailError(
+ "Challenge {0} is not locally eligible".format(
+ challenge["challengeType"]
+ )
+ )
+ client_input = c.obtain_challenge_input(challenge)
+ if not client_input:
+ return None
+ return _send_challenge_result(
+ request,
+ msg["sessionId"],
+ challenge["challengeId"],
+ client_input,
+ access_token,
+ )
+ return None
+
+
+def _obtain_rapt(request, access_token, requested_scopes):
+ """Given an http request method and reauth access token, get rapt token.
+
+ Args:
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ access_token (str): reauth access token
+ requested_scopes (Sequence[str]): scopes required by the client application
+
+ Returns:
+ str: The rapt token.
+
+ Raises:
+ google.auth.exceptions.ReauthError: if reauth failed
+ """
+ msg = _get_challenges(
+ request,
+ list(challenges.AVAILABLE_CHALLENGES.keys()),
+ access_token,
+ requested_scopes,
+ )
+
+ if msg["status"] == _AUTHENTICATED:
+ return msg["encodedProofOfReauthToken"]
+
+ for _ in range(0, RUN_CHALLENGE_RETRY_LIMIT):
+ if not (
+ msg["status"] == _CHALLENGE_REQUIRED or msg["status"] == _CHALLENGE_PENDING
+ ):
+ raise exceptions.ReauthFailError(
+ "Reauthentication challenge failed due to API error: {}".format(
+ msg["status"]
+ )
+ )
+
+ if not is_interactive():
+ raise exceptions.ReauthFailError(
+ "Reauthentication challenge could not be answered because you are not"
+ " in an interactive session."
+ )
+
+ msg = _run_next_challenge(msg, request, access_token)
+
+ if msg["status"] == _AUTHENTICATED:
+ return msg["encodedProofOfReauthToken"]
+
+ # If we got here it means we didn't get authenticated.
+ raise exceptions.ReauthFailError("Failed to obtain rapt token.")
+
+
+def get_rapt_token(
+ request, client_id, client_secret, refresh_token, token_uri, scopes=None
+):
+ """Given an http request method and refresh_token, get rapt token.
+
+ Args:
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ client_id (str): client id to get access token for reauth scope.
+ client_secret (str): client secret for the client_id
+ refresh_token (str): refresh token to refresh access token
+ token_uri (str): uri to refresh access token
+ scopes (Optional(Sequence[str])): scopes required by the client application
+
+ Returns:
+ str: The rapt token.
+ Raises:
+ google.auth.exceptions.RefreshError: If reauth failed.
+ """
+ sys.stderr.write("Reauthentication required.\n")
+
+ # Get access token for reauth.
+ access_token, _, _, _ = _client.refresh_grant(
+ request=request,
+ client_id=client_id,
+ client_secret=client_secret,
+ refresh_token=refresh_token,
+ token_uri=token_uri,
+ scopes=[_REAUTH_SCOPE],
+ )
+
+ # Get rapt token from reauth API.
+ rapt_token = _obtain_rapt(request, access_token, requested_scopes=scopes)
+
+ return rapt_token
+
+
+def refresh_grant(
+ request,
+ token_uri,
+ refresh_token,
+ client_id,
+ client_secret,
+ scopes=None,
+ rapt_token=None,
+):
+ """Implements the reauthentication flow.
+
+ Args:
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ token_uri (str): The OAuth 2.0 authorizations server's token endpoint
+ URI.
+ refresh_token (str): The refresh token to use to get a new access
+ token.
+ client_id (str): The OAuth 2.0 application's client ID.
+ client_secret (str): The Oauth 2.0 appliaction's client secret.
+ scopes (Optional(Sequence[str])): Scopes to request. If present, all
+ scopes must be authorized for the refresh token. Useful if refresh
+ token has a wild card scope (e.g.
+ 'https://www.googleapis.com/auth/any-api').
+ rapt_token (Optional(str)): The rapt token for reauth.
+
+ Returns:
+ Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The
+ access token, new refresh token, expiration, and additional data
+ returned by the token endpoint.
+
+ Raises:
+ google.auth.exceptions.RefreshError: If the token endpoint returned
+ an error.
+ """
+ body = {
+ "grant_type": _client._REFRESH_GRANT_TYPE,
+ "client_id": client_id,
+ "client_secret": client_secret,
+ "refresh_token": refresh_token,
+ }
+ if scopes:
+ body["scope"] = " ".join(scopes)
+ if rapt_token:
+ body["rapt"] = rapt_token
+
+ response_status_ok, response_data = _client._token_endpoint_request_no_throw(
+ request, token_uri, body
+ )
+ if (
+ not response_status_ok
+ and response_data.get("error") == _REAUTH_NEEDED_ERROR
+ and (
+ response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_INVALID_RAPT
+ or response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_RAPT_REQUIRED
+ )
+ ):
+ rapt_token = get_rapt_token(
+ request, client_id, client_secret, refresh_token, token_uri, scopes=scopes
+ )
+ body["rapt"] = rapt_token
+ (response_status_ok, response_data) = _client._token_endpoint_request_no_throw(
+ request, token_uri, body
+ )
+
+ if not response_status_ok:
+ _client._handle_error_response(response_data)
+ return _client._handle_refresh_grant_response(response_data, refresh_token) + (
+ rapt_token,
+ )
diff --git a/noxfile.py b/noxfile.py
index 3b4863c..0bd7f6c 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -25,6 +25,7 @@
"pytest",
"pytest-cov",
"pytest-localserver",
+ "pyu2f",
"requests",
"urllib3",
"cryptography",
diff --git a/setup.py b/setup.py
index 16ba98c..ef723f8 100644
--- a/setup.py
+++ b/setup.py
@@ -33,6 +33,7 @@
extras = {
"aiohttp": "aiohttp >= 3.6.2, < 4.0.0dev; python_version>='3.6'",
"pyopenssl": "pyopenssl>=20.0.0",
+ "reauth": "pyu2f>=0.1.5",
}
with io.open("README.rst", "r") as fh:
diff --git a/tests/oauth2/test__client.py b/tests/oauth2/test__client.py
index c3ae2af..54686df 100644
--- a/tests/oauth2/test__client.py
+++ b/tests/oauth2/test__client.py
@@ -48,7 +48,7 @@
def test__handle_error_response():
- response_data = json.dumps({"error": "help", "error_description": "I'm alive"})
+ response_data = {"error": "help", "error_description": "I'm alive"}
with pytest.raises(exceptions.RefreshError) as excinfo:
_client._handle_error_response(response_data)
@@ -57,12 +57,12 @@
def test__handle_error_response_non_json():
- response_data = "Help, I'm alive"
+ response_data = {"foo": "bar"}
with pytest.raises(exceptions.RefreshError) as excinfo:
_client._handle_error_response(response_data)
- assert excinfo.match(r"Help, I\'m alive")
+ assert excinfo.match(r"{\"foo\": \"bar\"}")
@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
@@ -95,7 +95,7 @@
request.assert_called_with(
method="POST",
url="http://example.com",
- headers={"content-type": "application/x-www-form-urlencoded"},
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
body="test=params".encode("utf-8"),
)
@@ -103,6 +103,32 @@
assert result == {"test": "response"}
+def test__token_endpoint_request_use_json():
+ request = make_request({"test": "response"})
+
+ result = _client._token_endpoint_request(
+ request,
+ "http://example.com",
+ {"test": "params"},
+ access_token="access_token",
+ use_json=True,
+ )
+
+ # Check request call
+ request.assert_called_with(
+ method="POST",
+ url="http://example.com",
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": "Bearer access_token",
+ },
+ body=b'{"test": "params"}',
+ )
+
+ # Check result
+ assert result == {"test": "response"}
+
+
def test__token_endpoint_request_error():
request = make_request({}, status=http_client.BAD_REQUEST)
@@ -220,7 +246,12 @@
)
token, refresh_token, expiry, extra_data = _client.refresh_grant(
- request, "http://example.com", "refresh_token", "client_id", "client_secret"
+ request,
+ "http://example.com",
+ "refresh_token",
+ "client_id",
+ "client_secret",
+ rapt_token="rapt_token",
)
# Check request call
@@ -231,6 +262,7 @@
"refresh_token": "refresh_token",
"client_id": "client_id",
"client_secret": "client_secret",
+ "rapt": "rapt_token",
},
)
diff --git a/tests/oauth2/test_challenges.py b/tests/oauth2/test_challenges.py
new file mode 100644
index 0000000..019b908
--- /dev/null
+++ b/tests/oauth2/test_challenges.py
@@ -0,0 +1,132 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the reauth module."""
+
+import base64
+import sys
+
+import mock
+import pytest
+import pyu2f
+
+from google.auth import exceptions
+from google.oauth2 import challenges
+
+
+def test_get_user_password():
+ with mock.patch("getpass.getpass", return_value="foo"):
+ assert challenges.get_user_password("") == "foo"
+
+
+def test_security_key():
+ metadata = {
+ "status": "READY",
+ "challengeId": 2,
+ "challengeType": "SECURITY_KEY",
+ "securityKey": {
+ "applicationId": "security_key_application_id",
+ "challenges": [
+ {
+ "keyHandle": "some_key",
+ "challenge": base64.urlsafe_b64encode(
+ "some_challenge".encode("ascii")
+ ).decode("ascii"),
+ }
+ ],
+ },
+ }
+ mock_key = mock.Mock()
+
+ challenge = challenges.SecurityKeyChallenge()
+
+ # Test the case that security key challenge is passed.
+ with mock.patch("pyu2f.model.RegisteredKey", return_value=mock_key):
+ with mock.patch(
+ "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate"
+ ) as mock_authenticate:
+ mock_authenticate.return_value = "security key response"
+ assert challenge.name == "SECURITY_KEY"
+ assert challenge.is_locally_eligible
+ assert challenge.obtain_challenge_input(metadata) == {
+ "securityKey": "security key response"
+ }
+ mock_authenticate.assert_called_with(
+ "security_key_application_id",
+ [{"key": mock_key, "challenge": b"some_challenge"}],
+ print_callback=sys.stderr.write,
+ )
+
+ # Test various types of exceptions.
+ with mock.patch("pyu2f.model.RegisteredKey", return_value=mock_key):
+ with mock.patch(
+ "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate"
+ ) as mock_authenticate:
+ mock_authenticate.side_effect = pyu2f.errors.U2FError(
+ pyu2f.errors.U2FError.DEVICE_INELIGIBLE
+ )
+ assert challenge.obtain_challenge_input(metadata) is None
+
+ with mock.patch(
+ "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate"
+ ) as mock_authenticate:
+ mock_authenticate.side_effect = pyu2f.errors.U2FError(
+ pyu2f.errors.U2FError.TIMEOUT
+ )
+ assert challenge.obtain_challenge_input(metadata) is None
+
+ with mock.patch(
+ "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate"
+ ) as mock_authenticate:
+ mock_authenticate.side_effect = pyu2f.errors.U2FError(
+ pyu2f.errors.U2FError.BAD_REQUEST
+ )
+ with pytest.raises(pyu2f.errors.U2FError):
+ challenge.obtain_challenge_input(metadata)
+
+ with mock.patch(
+ "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate"
+ ) as mock_authenticate:
+ mock_authenticate.side_effect = pyu2f.errors.NoDeviceFoundError()
+ assert challenge.obtain_challenge_input(metadata) is None
+
+ with mock.patch(
+ "pyu2f.convenience.authenticator.CompositeAuthenticator.Authenticate"
+ ) as mock_authenticate:
+ mock_authenticate.side_effect = pyu2f.errors.UnsupportedVersionException()
+ with pytest.raises(pyu2f.errors.UnsupportedVersionException):
+ challenge.obtain_challenge_input(metadata)
+
+ with mock.patch.dict("sys.modules"):
+ sys.modules["pyu2f"] = None
+ with pytest.raises(exceptions.ReauthFailError) as excinfo:
+ challenge.obtain_challenge_input(metadata)
+ assert excinfo.match(r"pyu2f dependency is required")
+
+
+@mock.patch("getpass.getpass", return_value="foo")
+def test_password_challenge(getpass_mock):
+ challenge = challenges.PasswordChallenge()
+
+ with mock.patch("getpass.getpass", return_value="foo"):
+ assert challenge.is_locally_eligible
+ assert challenge.name == "PASSWORD"
+ assert challenges.PasswordChallenge().obtain_challenge_input({}) == {
+ "credential": "foo"
+ }
+
+ with mock.patch("getpass.getpass", return_value=None):
+ assert challenges.PasswordChallenge().obtain_challenge_input({}) == {
+ "credential": " "
+ }
diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py
index b885d29..4a387a5 100644
--- a/tests/oauth2/test_credentials.py
+++ b/tests/oauth2/test_credentials.py
@@ -38,6 +38,7 @@
class TestCredentials(object):
TOKEN_URI = "https://example.com/oauth2/token"
REFRESH_TOKEN = "refresh_token"
+ RAPT_TOKEN = "rapt_token"
CLIENT_ID = "client_id"
CLIENT_SECRET = "client_secret"
@@ -49,6 +50,7 @@
token_uri=cls.TOKEN_URI,
client_id=cls.CLIENT_ID,
client_secret=cls.CLIENT_SECRET,
+ rapt_token=cls.RAPT_TOKEN,
)
def test_default_state(self):
@@ -63,14 +65,16 @@
assert credentials.token_uri == self.TOKEN_URI
assert credentials.client_id == self.CLIENT_ID
assert credentials.client_secret == self.CLIENT_SECRET
+ assert credentials.rapt_token == self.RAPT_TOKEN
- @mock.patch("google.oauth2._client.refresh_grant", autospec=True)
+ @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
@mock.patch(
"google.auth._helpers.utcnow",
return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
)
def test_refresh_success(self, unused_utcnow, refresh_grant):
token = "token"
+ new_rapt_token = "new_rapt_token"
expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
grant_response = {"id_token": mock.sentinel.id_token}
refresh_grant.return_value = (
@@ -82,6 +86,8 @@
expiry,
# Extra data
grant_response,
+ # rapt_token
+ new_rapt_token,
)
request = mock.create_autospec(transport.Request)
@@ -98,12 +104,14 @@
self.CLIENT_ID,
self.CLIENT_SECRET,
None,
+ self.RAPT_TOKEN,
)
# Check that the credentials have the token and expiry
assert credentials.token == token
assert credentials.expiry == expiry
assert credentials.id_token == mock.sentinel.id_token
+ assert credentials.rapt_token == new_rapt_token
# Check that the credentials are valid (have a token and are not
# expired)
@@ -118,7 +126,7 @@
request.assert_not_called()
- @mock.patch("google.oauth2._client.refresh_grant", autospec=True)
+ @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
@mock.patch(
"google.auth._helpers.utcnow",
return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
@@ -129,8 +137,9 @@
scopes = ["email", "profile"]
default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
token = "token"
+ new_rapt_token = "new_rapt_token"
expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
- grant_response = {"id_token": mock.sentinel.id_token}
+ grant_response = {"id_token": mock.sentinel.id_token, "scope": "email profile"}
refresh_grant.return_value = (
# Access token
token,
@@ -140,6 +149,8 @@
expiry,
# Extra data
grant_response,
+ # rapt token
+ new_rapt_token,
)
request = mock.create_autospec(transport.Request)
@@ -151,6 +162,7 @@
client_secret=self.CLIENT_SECRET,
scopes=scopes,
default_scopes=default_scopes,
+ rapt_token=self.RAPT_TOKEN,
)
# Refresh credentials
@@ -164,6 +176,7 @@
self.CLIENT_ID,
self.CLIENT_SECRET,
scopes,
+ self.RAPT_TOKEN,
)
# Check that the credentials have the token and expiry
@@ -171,12 +184,13 @@
assert creds.expiry == expiry
assert creds.id_token == mock.sentinel.id_token
assert creds.has_scopes(scopes)
+ assert creds.rapt_token == new_rapt_token
# Check that the credentials are valid (have a token and are not
# expired.)
assert creds.valid
- @mock.patch("google.oauth2._client.refresh_grant", autospec=True)
+ @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
@mock.patch(
"google.auth._helpers.utcnow",
return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
@@ -186,6 +200,7 @@
):
default_scopes = ["email", "profile"]
token = "token"
+ new_rapt_token = "new_rapt_token"
expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
grant_response = {"id_token": mock.sentinel.id_token}
refresh_grant.return_value = (
@@ -197,6 +212,8 @@
expiry,
# Extra data
grant_response,
+ # rapt token
+ new_rapt_token,
)
request = mock.create_autospec(transport.Request)
@@ -207,6 +224,7 @@
client_id=self.CLIENT_ID,
client_secret=self.CLIENT_SECRET,
default_scopes=default_scopes,
+ rapt_token=self.RAPT_TOKEN,
)
# Refresh credentials
@@ -220,6 +238,7 @@
self.CLIENT_ID,
self.CLIENT_SECRET,
default_scopes,
+ self.RAPT_TOKEN,
)
# Check that the credentials have the token and expiry
@@ -227,12 +246,13 @@
assert creds.expiry == expiry
assert creds.id_token == mock.sentinel.id_token
assert creds.has_scopes(default_scopes)
+ assert creds.rapt_token == new_rapt_token
# Check that the credentials are valid (have a token and are not
# expired.)
assert creds.valid
- @mock.patch("google.oauth2._client.refresh_grant", autospec=True)
+ @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
@mock.patch(
"google.auth._helpers.utcnow",
return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
@@ -242,6 +262,7 @@
):
scopes = ["email", "profile"]
token = "token"
+ new_rapt_token = "new_rapt_token"
expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
grant_response = {
"id_token": mock.sentinel.id_token,
@@ -256,6 +277,8 @@
expiry,
# Extra data
grant_response,
+ # rapt token
+ new_rapt_token,
)
request = mock.create_autospec(transport.Request)
@@ -266,6 +289,7 @@
client_id=self.CLIENT_ID,
client_secret=self.CLIENT_SECRET,
scopes=scopes,
+ rapt_token=self.RAPT_TOKEN,
)
# Refresh credentials
@@ -279,6 +303,7 @@
self.CLIENT_ID,
self.CLIENT_SECRET,
scopes,
+ self.RAPT_TOKEN,
)
# Check that the credentials have the token and expiry
@@ -286,12 +311,13 @@
assert creds.expiry == expiry
assert creds.id_token == mock.sentinel.id_token
assert creds.has_scopes(scopes)
+ assert creds.rapt_token == new_rapt_token
# Check that the credentials are valid (have a token and are not
# expired.)
assert creds.valid
- @mock.patch("google.oauth2._client.refresh_grant", autospec=True)
+ @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
@mock.patch(
"google.auth._helpers.utcnow",
return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
@@ -302,10 +328,11 @@
scopes = ["email", "profile"]
scopes_returned = ["email"]
token = "token"
+ new_rapt_token = "new_rapt_token"
expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
grant_response = {
"id_token": mock.sentinel.id_token,
- "scopes": " ".join(scopes_returned),
+ "scope": " ".join(scopes_returned),
}
refresh_grant.return_value = (
# Access token
@@ -316,6 +343,8 @@
expiry,
# Extra data
grant_response,
+ # rapt token
+ new_rapt_token,
)
request = mock.create_autospec(transport.Request)
@@ -326,6 +355,7 @@
client_id=self.CLIENT_ID,
client_secret=self.CLIENT_SECRET,
scopes=scopes,
+ rapt_token=self.RAPT_TOKEN,
)
# Refresh credentials
@@ -342,6 +372,7 @@
self.CLIENT_ID,
self.CLIENT_SECRET,
scopes,
+ self.RAPT_TOKEN,
)
# Check that the credentials have the token and expiry
@@ -349,6 +380,7 @@
assert creds.expiry == expiry
assert creds.id_token == mock.sentinel.id_token
assert creds.has_scopes(scopes)
+ assert creds.rapt_token == new_rapt_token
# Check that the credentials are valid (have a token and are not
# expired.)
diff --git a/tests/oauth2/test_reauth.py b/tests/oauth2/test_reauth.py
new file mode 100644
index 0000000..e9ffa8a
--- /dev/null
+++ b/tests/oauth2/test_reauth.py
@@ -0,0 +1,308 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+import mock
+import pytest
+
+from google.auth import exceptions
+from google.oauth2 import reauth
+
+
+MOCK_REQUEST = mock.Mock()
+CHALLENGES_RESPONSE_TEMPLATE = {
+ "status": "CHALLENGE_REQUIRED",
+ "sessionId": "123",
+ "challenges": [
+ {
+ "status": "READY",
+ "challengeId": 1,
+ "challengeType": "PASSWORD",
+ "securityKey": {},
+ }
+ ],
+}
+CHALLENGES_RESPONSE_AUTHENTICATED = {
+ "status": "AUTHENTICATED",
+ "sessionId": "123",
+ "encodedProofOfReauthToken": "new_rapt_token",
+}
+
+
+class MockChallenge(object):
+ def __init__(self, name, locally_eligible, challenge_input):
+ self.name = name
+ self.is_locally_eligible = locally_eligible
+ self.challenge_input = challenge_input
+
+ def obtain_challenge_input(self, metadata):
+ return self.challenge_input
+
+
+def test_is_interactive():
+ with mock.patch("sys.stdin.isatty", return_value=True):
+ assert reauth.is_interactive()
+
+
+def test__get_challenges():
+ with mock.patch(
+ "google.oauth2._client._token_endpoint_request"
+ ) as mock_token_endpoint_request:
+ reauth._get_challenges(MOCK_REQUEST, ["SAML"], "token")
+ mock_token_endpoint_request.assert_called_with(
+ MOCK_REQUEST,
+ reauth._REAUTH_API + ":start",
+ {"supportedChallengeTypes": ["SAML"]},
+ access_token="token",
+ use_json=True,
+ )
+
+
+def test__get_challenges_with_scopes():
+ with mock.patch(
+ "google.oauth2._client._token_endpoint_request"
+ ) as mock_token_endpoint_request:
+ reauth._get_challenges(
+ MOCK_REQUEST, ["SAML"], "token", requested_scopes=["scope"]
+ )
+ mock_token_endpoint_request.assert_called_with(
+ MOCK_REQUEST,
+ reauth._REAUTH_API + ":start",
+ {
+ "supportedChallengeTypes": ["SAML"],
+ "oauthScopesForDomainPolicyLookup": ["scope"],
+ },
+ access_token="token",
+ use_json=True,
+ )
+
+
+def test__send_challenge_result():
+ with mock.patch(
+ "google.oauth2._client._token_endpoint_request"
+ ) as mock_token_endpoint_request:
+ reauth._send_challenge_result(
+ MOCK_REQUEST, "123", "1", {"credential": "password"}, "token"
+ )
+ mock_token_endpoint_request.assert_called_with(
+ MOCK_REQUEST,
+ reauth._REAUTH_API + "/123:continue",
+ {
+ "sessionId": "123",
+ "challengeId": "1",
+ "action": "RESPOND",
+ "proposalResponse": {"credential": "password"},
+ },
+ access_token="token",
+ use_json=True,
+ )
+
+
+def test__run_next_challenge_not_ready():
+ challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE)
+ challenges_response["challenges"][0]["status"] = "STATUS_UNSPECIFIED"
+ assert (
+ reauth._run_next_challenge(challenges_response, MOCK_REQUEST, "token") is None
+ )
+
+
+def test__run_next_challenge_not_supported():
+ challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE)
+ challenges_response["challenges"][0]["challengeType"] = "CHALLENGE_TYPE_UNSPECIFIED"
+ with pytest.raises(exceptions.ReauthFailError) as excinfo:
+ reauth._run_next_challenge(challenges_response, MOCK_REQUEST, "token")
+ assert excinfo.match(r"Unsupported challenge type CHALLENGE_TYPE_UNSPECIFIED")
+
+
+def test__run_next_challenge_not_locally_eligible():
+ mock_challenge = MockChallenge("PASSWORD", False, "challenge_input")
+ with mock.patch(
+ "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge}
+ ):
+ with pytest.raises(exceptions.ReauthFailError) as excinfo:
+ reauth._run_next_challenge(
+ CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token"
+ )
+ assert excinfo.match(r"Challenge PASSWORD is not locally eligible")
+
+
+def test__run_next_challenge_no_challenge_input():
+ mock_challenge = MockChallenge("PASSWORD", True, None)
+ with mock.patch(
+ "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge}
+ ):
+ assert (
+ reauth._run_next_challenge(
+ CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token"
+ )
+ is None
+ )
+
+
+def test__run_next_challenge_success():
+ mock_challenge = MockChallenge("PASSWORD", True, {"credential": "password"})
+ with mock.patch(
+ "google.oauth2.challenges.AVAILABLE_CHALLENGES", {"PASSWORD": mock_challenge}
+ ):
+ with mock.patch(
+ "google.oauth2.reauth._send_challenge_result"
+ ) as mock_send_challenge_result:
+ reauth._run_next_challenge(
+ CHALLENGES_RESPONSE_TEMPLATE, MOCK_REQUEST, "token"
+ )
+ mock_send_challenge_result.assert_called_with(
+ MOCK_REQUEST, "123", 1, {"credential": "password"}, "token"
+ )
+
+
+def test__obtain_rapt_authenticated():
+ with mock.patch(
+ "google.oauth2.reauth._get_challenges",
+ return_value=CHALLENGES_RESPONSE_AUTHENTICATED,
+ ):
+ assert reauth._obtain_rapt(MOCK_REQUEST, "token", None) == "new_rapt_token"
+
+
+def test__obtain_rapt_authenticated_after_run_next_challenge():
+ with mock.patch(
+ "google.oauth2.reauth._get_challenges",
+ return_value=CHALLENGES_RESPONSE_TEMPLATE,
+ ):
+ with mock.patch(
+ "google.oauth2.reauth._run_next_challenge",
+ side_effect=[
+ CHALLENGES_RESPONSE_TEMPLATE,
+ CHALLENGES_RESPONSE_AUTHENTICATED,
+ ],
+ ):
+ with mock.patch("google.oauth2.reauth.is_interactive", return_value=True):
+ assert (
+ reauth._obtain_rapt(MOCK_REQUEST, "token", None) == "new_rapt_token"
+ )
+
+
+def test__obtain_rapt_unsupported_status():
+ challenges_response = copy.deepcopy(CHALLENGES_RESPONSE_TEMPLATE)
+ challenges_response["status"] = "STATUS_UNSPECIFIED"
+ with mock.patch(
+ "google.oauth2.reauth._get_challenges", return_value=challenges_response
+ ):
+ with pytest.raises(exceptions.ReauthFailError) as excinfo:
+ reauth._obtain_rapt(MOCK_REQUEST, "token", None)
+ assert excinfo.match(r"API error: STATUS_UNSPECIFIED")
+
+
+def test__obtain_rapt_not_interactive():
+ with mock.patch(
+ "google.oauth2.reauth._get_challenges",
+ return_value=CHALLENGES_RESPONSE_TEMPLATE,
+ ):
+ with mock.patch("google.oauth2.reauth.is_interactive", return_value=False):
+ with pytest.raises(exceptions.ReauthFailError) as excinfo:
+ reauth._obtain_rapt(MOCK_REQUEST, "token", None)
+ assert excinfo.match(r"not in an interactive session")
+
+
+def test__obtain_rapt_not_authenticated():
+ with mock.patch(
+ "google.oauth2.reauth._get_challenges",
+ return_value=CHALLENGES_RESPONSE_TEMPLATE,
+ ):
+ with mock.patch("google.oauth2.reauth.RUN_CHALLENGE_RETRY_LIMIT", 0):
+ with pytest.raises(exceptions.ReauthFailError) as excinfo:
+ reauth._obtain_rapt(MOCK_REQUEST, "token", None)
+ assert excinfo.match(r"Reauthentication failed")
+
+
+def test_get_rapt_token():
+ with mock.patch(
+ "google.oauth2._client.refresh_grant", return_value=("token", None, None, None)
+ ) as mock_refresh_grant:
+ with mock.patch(
+ "google.oauth2.reauth._obtain_rapt", return_value="new_rapt_token"
+ ) as mock_obtain_rapt:
+ assert (
+ reauth.get_rapt_token(
+ MOCK_REQUEST,
+ "client_id",
+ "client_secret",
+ "refresh_token",
+ "token_uri",
+ )
+ == "new_rapt_token"
+ )
+ mock_refresh_grant.assert_called_with(
+ request=MOCK_REQUEST,
+ client_id="client_id",
+ client_secret="client_secret",
+ refresh_token="refresh_token",
+ token_uri="token_uri",
+ scopes=[reauth._REAUTH_SCOPE],
+ )
+ mock_obtain_rapt.assert_called_with(
+ MOCK_REQUEST, "token", requested_scopes=None
+ )
+
+
+def test_refresh_grant_failed():
+ with mock.patch(
+ "google.oauth2._client._token_endpoint_request_no_throw"
+ ) as mock_token_request:
+ mock_token_request.return_value = (False, {"error": "Bad request"})
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ reauth.refresh_grant(
+ MOCK_REQUEST,
+ "token_uri",
+ "refresh_token",
+ "client_id",
+ "client_secret",
+ scopes=["foo", "bar"],
+ rapt_token="rapt_token",
+ )
+ assert excinfo.match(r"Bad request")
+ mock_token_request.assert_called_with(
+ MOCK_REQUEST,
+ "token_uri",
+ {
+ "grant_type": "refresh_token",
+ "client_id": "client_id",
+ "client_secret": "client_secret",
+ "refresh_token": "refresh_token",
+ "scope": "foo bar",
+ "rapt": "rapt_token",
+ },
+ )
+
+
+def test_refresh_grant_success():
+ with mock.patch(
+ "google.oauth2._client._token_endpoint_request_no_throw"
+ ) as mock_token_request:
+ mock_token_request.side_effect = [
+ (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}),
+ (True, {"access_token": "access_token"}),
+ ]
+ with mock.patch(
+ "google.oauth2.reauth.get_rapt_token", return_value="new_rapt_token"
+ ):
+ assert reauth.refresh_grant(
+ MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret"
+ ) == (
+ "access_token",
+ "refresh_token",
+ None,
+ {"access_token": "access_token"},
+ "new_rapt_token",
+ )