feat: add reauth feature to user credentials (#727)
* feat: add reauth support to oauth2 credentials
* update
diff --git a/google/oauth2/_client.py b/google/oauth2/_client.py
index 4487163..2f4e847 100644
--- a/google/oauth2/_client.py
+++ b/google/oauth2/_client.py
@@ -35,29 +35,29 @@
from google.auth import jwt
_URLENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
+_JSON_CONTENT_TYPE = "application/json"
_JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
_REFRESH_GRANT_TYPE = "refresh_token"
-def _handle_error_response(response_body):
- """"Translates an error response into an exception.
+def _handle_error_response(response_data):
+ """Translates an error response into an exception.
Args:
- response_body (str): The decoded response data.
+ response_data (Mapping): The decoded response data.
Raises:
- google.auth.exceptions.RefreshError
+ google.auth.exceptions.RefreshError: The errors contained in response_data.
"""
try:
- error_data = json.loads(response_body)
error_details = "{}: {}".format(
- error_data["error"], error_data.get("error_description")
+ response_data["error"], response_data.get("error_description")
)
# If no details could be extracted, use the response data.
except (KeyError, ValueError):
- error_details = response_body
+ error_details = json.dumps(response_data)
- raise exceptions.RefreshError(error_details, response_body)
+ raise exceptions.RefreshError(error_details, response_data)
def _parse_expiry(response_data):
@@ -78,8 +78,11 @@
return None
-def _token_endpoint_request(request, token_uri, body):
+def _token_endpoint_request_no_throw(
+ request, token_uri, body, access_token=None, use_json=False
+):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
+ This function doesn't raise on error responses; they are returned to the caller.
Args:
request (google.auth.transport.Request): A callable used to make
@@ -87,16 +90,23 @@
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
+ access_token (Optional[str]): The access token to send in the Authorization header, if any.
+ use_json (Optional[bool]): If True, send the body as JSON; otherwise send
+ it URL-encoded. The default value is False.
Returns:
- Mapping[str, str]: The JSON-decoded response data.
-
- Raises:
- google.auth.exceptions.RefreshError: If the token endpoint returned
- an error.
+ Tuple[bool, Mapping[str, str]]: A boolean indicating whether the request
+ succeeded, and a mapping of the JSON-decoded response data.
"""
- body = urllib.parse.urlencode(body).encode("utf-8")
- headers = {"content-type": _URLENCODED_CONTENT_TYPE}
+ if use_json:
+ headers = {"Content-Type": _JSON_CONTENT_TYPE}
+ body = json.dumps(body).encode("utf-8")
+ else:
+ headers = {"Content-Type": _URLENCODED_CONTENT_TYPE}
+ body = urllib.parse.urlencode(body).encode("utf-8")
+
+ if access_token:
+ headers["Authorization"] = "Bearer {}".format(access_token)
retry = 0
# retry to fetch token for maximum of two times if any internal failure
@@ -121,8 +131,38 @@
):
retry += 1
continue
- _handle_error_response(response_body)
+ return response.status == http_client.OK, response_data
+ return response.status == http_client.OK, response_data
+
+
+def _token_endpoint_request(
+ request, token_uri, body, access_token=None, use_json=False
+):
+ """Makes a request to the OAuth 2.0 authorization server's token endpoint.
+
+ Args:
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ token_uri (str): The OAuth 2.0 authorization server's token endpoint
+ URI.
+ body (Mapping[str, str]): The parameters to send in the request body.
+ access_token (Optional[str]): The access token to send in the Authorization header, if any.
+ use_json (Optional[bool]): If True, send the body as JSON; otherwise send
+ it URL-encoded. The default value is False.
+
+ Returns:
+ Mapping[str, str]: The JSON-decoded response data.
+
+ Raises:
+ google.auth.exceptions.RefreshError: If the token endpoint returned
+ an error.
+ """
+ response_status_ok, response_data = _token_endpoint_request_no_throw(
+ request, token_uri, body, access_token=access_token, use_json=use_json
+ )
+ if not response_status_ok:
+ _handle_error_response(response_data)
return response_data
@@ -204,8 +244,43 @@
return id_token, expiry, response_data
+def _handle_refresh_grant_response(response_data, refresh_token):
+ """Extract tokens from refresh grant response.
+
+ Args:
+ response_data (Mapping[str, str]): Refresh grant response data.
+ refresh_token (str): Current refresh token.
+
+ Returns:
+ Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access token,
+ refresh token, expiration, and additional data returned by the token
+ endpoint. If response_data doesn't have refresh token, then the current
+ refresh token will be returned.
+
+ Raises:
+ google.auth.exceptions.RefreshError: If response_data does not contain
+ an access token.
+ """
+ try:
+ access_token = response_data["access_token"]
+ except KeyError as caught_exc:
+ new_exc = exceptions.RefreshError("No access token in response.", response_data)
+ six.raise_from(new_exc, caught_exc)
+
+ refresh_token = response_data.get("refresh_token", refresh_token)
+ expiry = _parse_expiry(response_data)
+
+ return access_token, refresh_token, expiry, response_data
+
+
def refresh_grant(
- request, token_uri, refresh_token, client_id, client_secret, scopes=None
+ request,
+ token_uri,
+ refresh_token,
+ client_id,
+ client_secret,
+ scopes=None,
+ rapt_token=None,
):
"""Implements the OAuth 2.0 refresh token grant.
@@ -224,10 +299,11 @@
scopes must be authorized for the refresh token. Useful if refresh
token has a wild card scope (e.g.
'https://www.googleapis.com/auth/any-api').
+ rapt_token (Optional[str]): The Reauth Proof Token (RAPT).
Returns:
- Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The
- access token, new refresh token, expiration, and additional data
+ Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access
+ token, new or current refresh token, expiration, and additional data
returned by the token endpoint.
Raises:
@@ -244,16 +320,8 @@
}
if scopes:
body["scope"] = " ".join(scopes)
+ if rapt_token:
+ body["rapt"] = rapt_token
response_data = _token_endpoint_request(request, token_uri, body)
-
- try:
- access_token = response_data["access_token"]
- except KeyError as caught_exc:
- new_exc = exceptions.RefreshError("No access token in response.", response_data)
- six.raise_from(new_exc, caught_exc)
-
- refresh_token = response_data.get("refresh_token", refresh_token)
- expiry = _parse_expiry(response_data)
-
- return access_token, refresh_token, expiry, response_data
+ return _handle_refresh_grant_response(response_data, refresh_token)