Allow user to pass custom Session object to AuthorizedSession to make requests. (#306)
diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py
index 2268243..8250c74 100644
--- a/google/auth/transport/requests.py
+++ b/google/auth/transport/requests.py
@@ -150,33 +150,40 @@
refresh the credentials and retry the request.
refresh_timeout (Optional[int]): The timeout value in seconds for
credential refresh HTTP requests.
- kwargs: Additional arguments passed to the :class:`requests.Session`
- constructor.
+ auth_request (google.auth.transport.requests.Request):
+ (Optional) An instance of
+ :class:`~google.auth.transport.requests.Request` used when
+ refreshing credentials. If not passed,
+ an instance of :class:`~google.auth.transport.requests.Request`
+ is created.
"""
def __init__(self, credentials,
refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES,
max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS,
refresh_timeout=None,
- **kwargs):
- super(AuthorizedSession, self).__init__(**kwargs)
+ auth_request=None):
+ super(AuthorizedSession, self).__init__()
self.credentials = credentials
self._refresh_status_codes = refresh_status_codes
self._max_refresh_attempts = max_refresh_attempts
self._refresh_timeout = refresh_timeout
- auth_request_session = requests.Session()
+ if auth_request is None:
+ auth_request_session = requests.Session()
- # Using an adapter to make HTTP requests robust to network errors.
- # This adapter retrys HTTP requests when network errors occur
- # and the requests seems safely retryable.
- retry_adapter = requests.adapters.HTTPAdapter(max_retries=3)
- auth_request_session.mount("https://", retry_adapter)
+ # Using an adapter to make HTTP requests robust to network errors.
+ # This adapter retrys HTTP requests when network errors occur
+ # and the requests seems safely retryable.
+ retry_adapter = requests.adapters.HTTPAdapter(max_retries=3)
+ auth_request_session.mount("https://", retry_adapter)
+
+ # Do not pass `self` as the session here, as it can lead to
+ # infinite recursion.
+ auth_request = Request(auth_request_session)
# Request instance used by internal methods (for example,
# credentials.refresh).
- # Do not pass `self` as the session here, as it can lead to infinite
- # recursion.
- self._auth_request = Request(auth_request_session)
+ self._auth_request = auth_request
def request(self, method, url, data=None, headers=None, **kwargs):
"""Implementation of Requests' request."""
diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py
index 41dc237..311992a 100644
--- a/tests/transport/test_requests.py
+++ b/tests/transport/test_requests.py
@@ -85,6 +85,15 @@
assert authed_session.credentials == mock.sentinel.credentials
+ def test_constructor_with_auth_request(self):
+ http = mock.create_autospec(requests.Session)
+ auth_request = google.auth.transport.requests.Request(http)
+
+ authed_session = google.auth.transport.requests.AuthorizedSession(
+ mock.sentinel.credentials, auth_request=auth_request)
+
+ assert authed_session._auth_request == auth_request
+
def test_request_no_refresh(self):
credentials = mock.Mock(wraps=CredentialsStub())
response = make_response()