feat: support refresh callable on google.oauth2.credentials.Credentials (#812)

This is an optional parameter that can be set via the constructor.
It is used to provide the credentials with new tokens and their
expiration time on `refresh()` call.

```
def refresh_handler(request, scopes):
    # Generate a new token for the requested scopes by calling
    # an external process.
    return (
        "ACCESS_TOKEN",
        _helpers.utcnow() + datetime.timedelta(seconds=3600))

creds = google.oauth2.credentials.Credentials(
    scopes=scopes,
    refresh_handler=refresh_handler)
creds.refresh(request)
```

It is useful in the following cases:
- Useful in general when tokens are obtained by calling some
  external process on demand.
- Useful in particular for retrieving downscoped tokens from a
  token broker.

This should have no impact on existing behavior. Refresh tokens
will still have higher priority over refresh handlers.

A getter and setter is exposed to make it easy to set the callable
on unpickled credentials as the callable may not be easily serialized.

```
unpickled = pickle.loads(pickle.dumps(oauth_creds))
unpickled.refresh_handler = refresh_handler
```
diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py
index dcfa5f9..158249e 100644
--- a/google/oauth2/credentials.py
+++ b/google/oauth2/credentials.py
@@ -74,6 +74,7 @@
         quota_project_id=None,
         expiry=None,
         rapt_token=None,
+        refresh_handler=None,
     ):
         """
         Args:
@@ -103,6 +104,13 @@
                 This project may be different from the project used to
                 create the credentials.
             rapt_token (Optional[str]): The reauth Proof Token.
+            refresh_handler (Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]):
+                A callable which takes in the HTTP request callable and the list of
+                OAuth scopes and when called returns an access token string for the
+                requested scopes and its expiry datetime. This is useful when no
+                refresh tokens are provided and tokens are obtained by calling
+                some external process on demand. It is particularly useful for
+                retrieving downscoped tokens from a token broker.
         """
         super(Credentials, self).__init__()
         self.token = token
@@ -116,13 +124,20 @@
         self._client_secret = client_secret
         self._quota_project_id = quota_project_id
         self._rapt_token = rapt_token
+        self.refresh_handler = refresh_handler
 
     def __getstate__(self):
         """A __getstate__ method must exist for the __setstate__ to be called
         This is identical to the default implementation.
         See https://docs.python.org/3.7/library/pickle.html#object.__setstate__
         """
-        return self.__dict__
+        state_dict = self.__dict__.copy()
+        # Remove _refresh_handler function as there are limitations pickling and
+        # unpickling certain callables (lambda, functools.partial instances)
+        # because they need to be importable.
+        # Instead, the refresh_handler setter should be used to repopulate this.
+        del state_dict["_refresh_handler"]
+        return state_dict
 
     def __setstate__(self, d):
         """Credentials pickled with older versions of the class do not have
@@ -138,6 +153,8 @@
         self._client_secret = d.get("_client_secret")
         self._quota_project_id = d.get("_quota_project_id")
         self._rapt_token = d.get("_rapt_token")
+        # The refresh_handler setter should be used to repopulate this.
+        self._refresh_handler = None
 
     @property
     def refresh_token(self):
@@ -187,6 +204,31 @@
         """Optional[str]: The reauth Proof Token."""
         return self._rapt_token
 
+    @property
+    def refresh_handler(self):
+        """Returns the refresh handler if available.
+
+        Returns:
+           Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]:
+               The current refresh handler.
+        """
+        return self._refresh_handler
+
+    @refresh_handler.setter
+    def refresh_handler(self, value):
+        """Updates the current refresh handler.
+
+        Args:
+            value (Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]):
+                The updated value of the refresh handler.
+
+        Raises:
+            TypeError: If the value is not a callable or None.
+        """
+        if not callable(value) and value is not None:
+            raise TypeError("The provided refresh_handler is not a callable or None.")
+        self._refresh_handler = value
+
     @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
     def with_quota_project(self, quota_project_id):
 
@@ -205,6 +247,31 @@
 
     @_helpers.copy_docstring(credentials.Credentials)
     def refresh(self, request):
+        scopes = self._scopes if self._scopes is not None else self._default_scopes
+        # Use refresh handler if available and no refresh token is
+        # available. This is useful in general when tokens are obtained by calling
+        # some external process on demand. It is particularly useful for retrieving
+        # downscoped tokens from a token broker.
+        if self._refresh_token is None and self.refresh_handler:
+            token, expiry = self.refresh_handler(request, scopes=scopes)
+            # Validate returned data.
+            if not isinstance(token, str):
+                raise exceptions.RefreshError(
+                    "The refresh_handler returned token is not a string."
+                )
+            if not isinstance(expiry, datetime):
+                raise exceptions.RefreshError(
+                    "The refresh_handler returned expiry is not a datetime object."
+                )
+            if _helpers.utcnow() >= expiry - _helpers.CLOCK_SKEW:
+                raise exceptions.RefreshError(
+                    "The credentials returned by the refresh_handler are "
+                    "already expired."
+                )
+            self.token = token
+            self.expiry = expiry
+            return
+
         if (
             self._refresh_token is None
             or self._token_uri is None
@@ -217,8 +284,6 @@
                 "token_uri, client_id, and client_secret."
             )
 
-        scopes = self._scopes if self._scopes is not None else self._default_scopes
-
         (
             access_token,
             refresh_token,
diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py
index 4a387a5..4a7f66e 100644
--- a/tests/oauth2/test_credentials.py
+++ b/tests/oauth2/test_credentials.py
@@ -66,6 +66,50 @@
         assert credentials.client_id == self.CLIENT_ID
         assert credentials.client_secret == self.CLIENT_SECRET
         assert credentials.rapt_token == self.RAPT_TOKEN
+        assert credentials.refresh_handler is None
+
+    def test_refresh_handler_setter_and_getter(self):
+        scopes = ["email", "profile"]
+        original_refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN_1", None))
+        updated_refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN_2", None))
+        creds = credentials.Credentials(
+            token=None,
+            refresh_token=None,
+            token_uri=None,
+            client_id=None,
+            client_secret=None,
+            rapt_token=None,
+            scopes=scopes,
+            default_scopes=None,
+            refresh_handler=original_refresh_handler,
+        )
+
+        assert creds.refresh_handler is original_refresh_handler
+
+        creds.refresh_handler = updated_refresh_handler
+
+        assert creds.refresh_handler is updated_refresh_handler
+
+        creds.refresh_handler = None
+
+        assert creds.refresh_handler is None
+
+    def test_invalid_refresh_handler(self):
+        scopes = ["email", "profile"]
+        with pytest.raises(TypeError) as excinfo:
+            credentials.Credentials(
+                token=None,
+                refresh_token=None,
+                token_uri=None,
+                client_id=None,
+                client_secret=None,
+                rapt_token=None,
+                scopes=scopes,
+                default_scopes=None,
+                refresh_handler=object(),
+            )
+
+        assert excinfo.match("The provided refresh_handler is not a callable or None.")
 
     @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
     @mock.patch(
@@ -131,6 +175,221 @@
         "google.auth._helpers.utcnow",
         return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
     )
+    def test_refresh_with_refresh_token_and_refresh_handler(
+        self, unused_utcnow, refresh_grant
+    ):
+        token = "token"
+        new_rapt_token = "new_rapt_token"
+        expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+        grant_response = {"id_token": mock.sentinel.id_token}
+        refresh_grant.return_value = (
+            # Access token
+            token,
+            # New refresh token
+            None,
+            # Expiry,
+            expiry,
+            # Extra data
+            grant_response,
+            # rapt_token
+            new_rapt_token,
+        )
+
+        refresh_handler = mock.Mock()
+        request = mock.create_autospec(transport.Request)
+        creds = credentials.Credentials(
+            token=None,
+            refresh_token=self.REFRESH_TOKEN,
+            token_uri=self.TOKEN_URI,
+            client_id=self.CLIENT_ID,
+            client_secret=self.CLIENT_SECRET,
+            rapt_token=self.RAPT_TOKEN,
+            refresh_handler=refresh_handler,
+        )
+
+        # Refresh credentials
+        creds.refresh(request)
+
+        # Check jwt grant call.
+        refresh_grant.assert_called_with(
+            request,
+            self.TOKEN_URI,
+            self.REFRESH_TOKEN,
+            self.CLIENT_ID,
+            self.CLIENT_SECRET,
+            None,
+            self.RAPT_TOKEN,
+        )
+
+        # Check that the credentials have the token and expiry
+        assert creds.token == token
+        assert creds.expiry == expiry
+        assert creds.id_token == mock.sentinel.id_token
+        assert creds.rapt_token == new_rapt_token
+
+        # Check that the credentials are valid (have a token and are not
+        # expired)
+        assert creds.valid
+
+        # Assert refresh handler not called as the refresh token has
+        # higher priority.
+        refresh_handler.assert_not_called()
+
+    @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+    def test_refresh_with_refresh_handler_success_scopes(self, unused_utcnow):
+        expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800)
+        refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN", expected_expiry))
+        scopes = ["email", "profile"]
+        default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+        request = mock.create_autospec(transport.Request)
+        creds = credentials.Credentials(
+            token=None,
+            refresh_token=None,
+            token_uri=None,
+            client_id=None,
+            client_secret=None,
+            rapt_token=None,
+            scopes=scopes,
+            default_scopes=default_scopes,
+            refresh_handler=refresh_handler,
+        )
+
+        creds.refresh(request)
+
+        assert creds.token == "ACCESS_TOKEN"
+        assert creds.expiry == expected_expiry
+        assert creds.valid
+        assert not creds.expired
+        # Confirm refresh handler called with the expected arguments.
+        refresh_handler.assert_called_with(request, scopes=scopes)
+
+    @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+    def test_refresh_with_refresh_handler_success_default_scopes(self, unused_utcnow):
+        expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800)
+        original_refresh_handler = mock.Mock(
+            return_value=("UNUSED_TOKEN", expected_expiry)
+        )
+        refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN", expected_expiry))
+        default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+        request = mock.create_autospec(transport.Request)
+        creds = credentials.Credentials(
+            token=None,
+            refresh_token=None,
+            token_uri=None,
+            client_id=None,
+            client_secret=None,
+            rapt_token=None,
+            scopes=None,
+            default_scopes=default_scopes,
+            refresh_handler=original_refresh_handler,
+        )
+
+        # Test newly set refresh_handler is used instead of the original one.
+        creds.refresh_handler = refresh_handler
+        creds.refresh(request)
+
+        assert creds.token == "ACCESS_TOKEN"
+        assert creds.expiry == expected_expiry
+        assert creds.valid
+        assert not creds.expired
+        # default_scopes should be used since no developer provided scopes
+        # are provided.
+        refresh_handler.assert_called_with(request, scopes=default_scopes)
+
+    @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+    def test_refresh_with_refresh_handler_invalid_token(self, unused_utcnow):
+        expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800)
+        # Simulate refresh handler does not return a valid token.
+        refresh_handler = mock.Mock(return_value=(None, expected_expiry))
+        scopes = ["email", "profile"]
+        default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+        request = mock.create_autospec(transport.Request)
+        creds = credentials.Credentials(
+            token=None,
+            refresh_token=None,
+            token_uri=None,
+            client_id=None,
+            client_secret=None,
+            rapt_token=None,
+            scopes=scopes,
+            default_scopes=default_scopes,
+            refresh_handler=refresh_handler,
+        )
+
+        with pytest.raises(
+            exceptions.RefreshError, match="returned token is not a string"
+        ):
+            creds.refresh(request)
+
+        assert creds.token is None
+        assert creds.expiry is None
+        assert not creds.valid
+        # Confirm refresh handler called with the expected arguments.
+        refresh_handler.assert_called_with(request, scopes=scopes)
+
+    def test_refresh_with_refresh_handler_invalid_expiry(self):
+        # Simulate refresh handler returns expiration time in an invalid unit.
+        refresh_handler = mock.Mock(return_value=("TOKEN", 2800))
+        scopes = ["email", "profile"]
+        default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+        request = mock.create_autospec(transport.Request)
+        creds = credentials.Credentials(
+            token=None,
+            refresh_token=None,
+            token_uri=None,
+            client_id=None,
+            client_secret=None,
+            rapt_token=None,
+            scopes=scopes,
+            default_scopes=default_scopes,
+            refresh_handler=refresh_handler,
+        )
+
+        with pytest.raises(
+            exceptions.RefreshError, match="returned expiry is not a datetime object"
+        ):
+            creds.refresh(request)
+
+        assert creds.token is None
+        assert creds.expiry is None
+        assert not creds.valid
+        # Confirm refresh handler called with the expected arguments.
+        refresh_handler.assert_called_with(request, scopes=scopes)
+
+    @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+    def test_refresh_with_refresh_handler_expired_token(self, unused_utcnow):
+        expected_expiry = datetime.datetime.min + _helpers.CLOCK_SKEW
+        # Simulate refresh handler returns an expired token.
+        refresh_handler = mock.Mock(return_value=("TOKEN", expected_expiry))
+        scopes = ["email", "profile"]
+        default_scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+        request = mock.create_autospec(transport.Request)
+        creds = credentials.Credentials(
+            token=None,
+            refresh_token=None,
+            token_uri=None,
+            client_id=None,
+            client_secret=None,
+            rapt_token=None,
+            scopes=scopes,
+            default_scopes=default_scopes,
+            refresh_handler=refresh_handler,
+        )
+
+        with pytest.raises(exceptions.RefreshError, match="already expired"):
+            creds.refresh(request)
+
+        assert creds.token is None
+        assert creds.expiry is None
+        assert not creds.valid
+        # Confirm refresh handler called with the expected arguments.
+        refresh_handler.assert_called_with(request, scopes=scopes)
+
+    @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True)
+    @mock.patch(
+        "google.auth._helpers.utcnow",
+        return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
+    )
     def test_credentials_with_scopes_requested_refresh_success(
         self, unused_utcnow, refresh_grant
     ):
@@ -527,6 +786,32 @@
         for attr in list(creds.__dict__):
             assert getattr(creds, attr) == getattr(unpickled, attr)
 
+    def test_pickle_and_unpickle_with_refresh_handler(self):
+        expected_expiry = _helpers.utcnow() + datetime.timedelta(seconds=2800)
+        refresh_handler = mock.Mock(return_value=("TOKEN", expected_expiry))
+
+        creds = credentials.Credentials(
+            token=None,
+            refresh_token=None,
+            token_uri=None,
+            client_id=None,
+            client_secret=None,
+            rapt_token=None,
+            refresh_handler=refresh_handler,
+        )
+        unpickled = pickle.loads(pickle.dumps(creds))
+
+        # make sure attributes aren't lost during pickling
+        assert list(creds.__dict__).sort() == list(unpickled.__dict__).sort()
+
+        for attr in list(creds.__dict__):
+            # For the _refresh_handler property, the unpickled creds should be
+            # set to None.
+            if attr == "_refresh_handler":
+                assert getattr(unpickled, attr) is None
+            else:
+                assert getattr(creds, attr) == getattr(unpickled, attr)
+
     def test_pickle_with_missing_attribute(self):
         creds = self.make_credentials()