feat: support refresh callable on google.oauth2.credentials.Credentials (#812)
`refresh_handler` is an optional parameter that can be set via the constructor.
It is called on `refresh()` to provide the credentials with a new access
token and its expiration time.
```
def refresh_handler(request, scopes):
    # Generate a new token for the requested scopes by calling
    # an external process.
    return (
        "ACCESS_TOKEN",
        _helpers.utcnow() + datetime.timedelta(seconds=3600))

creds = google.oauth2.credentials.Credentials(
    scopes=scopes,
    refresh_handler=refresh_handler)
creds.refresh(request)
```
It is useful in the following cases:
- In general, when tokens are obtained by calling some
  external process on demand.
- In particular, when retrieving downscoped tokens from a
  token broker.
This should have no impact on existing behavior. When both are set, a
refresh token still takes priority over the refresh handler.
A getter and a setter are exposed to make it easy to set the callable
on unpickled credentials, as the callable may not be easily serialized.
```
unpickled = pickle.loads(pickle.dumps(oauth_creds))
unpickled.refresh_handler = refresh_handler
```
diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py
index 4a387a5..4a7f66e 100644
--- a/tests/oauth2/test_credentials.py
+++ b/tests/oauth2/test_credentials.py
@@ -66,6 +66,50 @@
assert credentials.client_id == self.CLIENT_ID
assert credentials.client_secret == self.CLIENT_SECRET
assert credentials.rapt_token == self.RAPT_TOKEN
+ assert credentials.refresh_handler is None
+
+ def test_refresh_handler_setter_and_getter(self):
+ scopes = ["email", "profile"]
+ original_refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN_1", None))
+ updated_refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN_2", None))
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=None,
+ refresh_handler=original_refresh_handler,
+ )
+
+ assert creds.refresh_handler is original_refresh_handler
+
+ creds.refresh_handler = updated_refresh_handler
+
+ assert creds.refresh_handler is updated_refresh_handler
+
+ creds.refresh_handler = None
+
+ assert creds.refresh_handler is None
+
+ def test_invalid_refresh_handler(self):
+ scopes = ["email", "profile"]
+ with pytest.raises(TypeError) as excinfo:
+ credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=None,
+ refresh_handler=object(),
+ )
+
+ assert excinfo.match("The provided refresh_handler is not a callable or None.")
@mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
@mock.patch(
@@ -131,6 +175,221 @@
"google.auth._helpers.utcnow",
return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
)
+ def test_refresh_with_refresh_token_and_refresh_handler(
+ self, unused_utcnow, refresh_grant
+ ):
+ token = "token"
+ new_rapt_token = "new_rapt_token"
+ expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+ grant_response = {"id_token": mock.sentinel.id_token}
+ refresh_grant.return_value = (
+ # Access token
+ token,
+ # New refresh token
+ None,
+ # Expiry,
+ expiry,
+ # Extra data
+ grant_response,
+ # rapt_token
+ new_rapt_token,
+ )
+
+ refresh_handler = mock.Mock()
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ rapt_token=self.RAPT_TOKEN,
+ refresh_handler=refresh_handler,
+ )
+
+ # Refresh credentials
+ creds.refresh(request)
+
+ # Check refresh grant call.
+ refresh_grant.assert_called_with(
+ request,
+ self.TOKEN_URI,
+ self.REFRESH_TOKEN,
+ self.CLIENT_ID,
+ self.CLIENT_SECRET,
+ None,
+ self.RAPT_TOKEN,
+ )
+
+ # Check that the credentials have the token and expiry
+ assert creds.token == token
+ assert creds.expiry == expiry
+ assert creds.id_token == mock.sentinel.id_token
+ assert creds.rapt_token == new_rapt_token
+
+ # Check that the credentials are valid (have a token and are not
+ # expired)
+ assert creds.valid
+
+ # Assert refresh handler not called as the refresh token has
+ # higher priority.
+ refresh_handler.assert_not_called()
+
+ @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+ def test_refresh_with_refresh_handler_success_scopes(self, unused_utcnow):
+ expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800)
+ refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN", expected_expiry))
+ scopes = ["email", "profile"]
+ default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=default_scopes,
+ refresh_handler=refresh_handler,
+ )
+
+ creds.refresh(request)
+
+ assert creds.token == "ACCESS_TOKEN"
+ assert creds.expiry == expected_expiry
+ assert creds.valid
+ assert not creds.expired
+ # Confirm refresh handler called with the expected arguments.
+ refresh_handler.assert_called_with(request, scopes=scopes)
+
+ @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+ def test_refresh_with_refresh_handler_success_default_scopes(self, unused_utcnow):
+ expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800)
+ original_refresh_handler = mock.Mock(
+ return_value=("UNUSED_TOKEN", expected_expiry)
+ )
+ refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN", expected_expiry))
+ default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=None,
+ default_scopes=default_scopes,
+ refresh_handler=original_refresh_handler,
+ )
+
+ # Test newly set refresh_handler is used instead of the original one.
+ creds.refresh_handler = refresh_handler
+ creds.refresh(request)
+
+ assert creds.token == "ACCESS_TOKEN"
+ assert creds.expiry == expected_expiry
+ assert creds.valid
+ assert not creds.expired
+ # default_scopes should be used since no developer-provided scopes
+ # were set.
+ refresh_handler.assert_called_with(request, scopes=default_scopes)
+
+ @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+ def test_refresh_with_refresh_handler_invalid_token(self, unused_utcnow):
+ expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800)
+ # Simulate refresh handler does not return a valid token.
+ refresh_handler = mock.Mock(return_value=(None, expected_expiry))
+ scopes = ["email", "profile"]
+ default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=default_scopes,
+ refresh_handler=refresh_handler,
+ )
+
+ with pytest.raises(
+ exceptions.RefreshError, match="returned token is not a string"
+ ):
+ creds.refresh(request)
+
+ assert creds.token is None
+ assert creds.expiry is None
+ assert not creds.valid
+ # Confirm refresh handler called with the expected arguments.
+ refresh_handler.assert_called_with(request, scopes=scopes)
+
+ def test_refresh_with_refresh_handler_invalid_expiry(self):
+ # Simulate refresh handler returns expiration time in an invalid unit.
+ refresh_handler = mock.Mock(return_value=("TOKEN", 2800))
+ scopes = ["email", "profile"]
+ default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=default_scopes,
+ refresh_handler=refresh_handler,
+ )
+
+ with pytest.raises(
+ exceptions.RefreshError, match="returned expiry is not a datetime object"
+ ):
+ creds.refresh(request)
+
+ assert creds.token is None
+ assert creds.expiry is None
+ assert not creds.valid
+ # Confirm refresh handler called with the expected arguments.
+ refresh_handler.assert_called_with(request, scopes=scopes)
+
+ @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+ def test_refresh_with_refresh_handler_expired_token(self, unused_utcnow):
+ expected_expiry = datetime.datetime.min + _helpers.CLOCK_SKEW
+ # Simulate refresh handler returns an expired token.
+ refresh_handler = mock.Mock(return_value=("TOKEN", expected_expiry))
+ scopes = ["email", "profile"]
+ default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+ request = mock.create_autospec(transport.Request)
+ creds = credentials.Credentials(
+ token=None,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ rapt_token=None,
+ scopes=scopes,
+ default_scopes=default_scopes,
+ refresh_handler=refresh_handler,
+ )
+
+ with pytest.raises(exceptions.RefreshError, match="already expired"):
+ creds.refresh(request)
+
+ assert creds.token is None
+ assert creds.expiry is None
+ assert not creds.valid
+ # Confirm refresh handler called with the expected arguments.
+ refresh_handler.assert_called_with(request, scopes=scopes)
+
+ @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
+ @mock.patch(
+ "google.auth._helpers.utcnow",
+ return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
+ )
def test_credentials_with_scopes_requested_refresh_success(
self, unused_utcnow, refresh_grant
):
@@ -527,6 +786,32 @@
for attr in list(creds.__dict__):
assert getattr(creds, attr) == getattr(unpickled, attr)
+    def test_pickle_and_unpickle_with_refresh_handler(self):
+        expected_expiry = _helpers.utcnow() + datetime.timedelta(seconds=2800)
+        refresh_handler = mock.Mock(return_value=("TOKEN", expected_expiry))
+
+        creds = credentials.Credentials(
+            token=None,
+            refresh_token=None,
+            token_uri=None,
+            client_id=None,
+            client_secret=None,
+            rapt_token=None,
+            refresh_handler=refresh_handler,
+        )
+        unpickled = pickle.loads(pickle.dumps(creds))
+
+        # Attribute names must survive pickling; sorted() because .sort() returns None.
+        assert sorted(creds.__dict__) == sorted(unpickled.__dict__)
+
+        for attr in list(creds.__dict__):
+            # For the _refresh_handler property, the unpickled creds should be
+            # set to None.
+            if attr == "_refresh_handler":
+                assert getattr(unpickled, attr) is None
+            else:
+                assert getattr(creds, attr) == getattr(unpickled, attr)
+
def test_pickle_with_missing_attribute(self):
creds = self.make_credentials()