feat: add timeout to AuthorizedSession.request() (#397)
diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py
index d1971cd..f21c524 100644
--- a/google/auth/transport/requests.py
+++ b/google/auth/transport/requests.py
@@ -18,6 +18,7 @@
import functools
import logging
+import time
try:
import requests
@@ -64,6 +65,33 @@
return self._response.content
+class TimeoutGuard(object):
+ """A context manager raising an error if the suite execution took too long.
+ """
+
+ def __init__(self, timeout, timeout_error_type=requests.exceptions.Timeout):
+ self._timeout = timeout
+ self.remaining_timeout = timeout
+ self._timeout_error_type = timeout_error_type
+
+ def __enter__(self):
+ self._start = time.time()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if exc_value:
+ return # let the error bubble up automatically
+
+ if self._timeout is None:
+ return # nothing to do, the timeout was not specified
+
+ elapsed = time.time() - self._start
+ self.remaining_timeout = self._timeout - elapsed
+
+ if self.remaining_timeout <= 0:
+ raise self._timeout_error_type()
+
+
class Request(transport.Request):
"""Requests request adapter.
@@ -193,8 +221,12 @@
# credentials.refresh).
self._auth_request = auth_request
- def request(self, method, url, data=None, headers=None, **kwargs):
- """Implementation of Requests' request."""
+ def request(self, method, url, data=None, headers=None, timeout=None, **kwargs):
+ """Implementation of Requests' request.
+
+ The ``timeout`` argument is interpreted as the approximate total time
+ of **all** requests that are made under the hood.
+ """
# pylint: disable=arguments-differ
# Requests has a ton of arguments to request, but only two
# (method, url) are required. We pass through all of the other
@@ -208,13 +240,28 @@
# and we want to pass the original headers if we recurse.
request_headers = headers.copy() if headers is not None else {}
- self.credentials.before_request(
- self._auth_request, method, url, request_headers
+ # Do not apply the timeout unconditionally in order to not override the
+ # _auth_request's default timeout.
+ auth_request = (
+ self._auth_request
+ if timeout is None
+ else functools.partial(self._auth_request, timeout=timeout)
)
- response = super(AuthorizedSession, self).request(
- method, url, data=data, headers=request_headers, **kwargs
- )
+ with TimeoutGuard(timeout) as guard:
+ self.credentials.before_request(auth_request, method, url, request_headers)
+ timeout = guard.remaining_timeout
+
+ with TimeoutGuard(timeout) as guard:
+ response = super(AuthorizedSession, self).request(
+ method,
+ url,
+ data=data,
+ headers=request_headers,
+ timeout=timeout,
+ **kwargs
+ )
+ timeout = guard.remaining_timeout
# If the response indicated that the credentials needed to be
# refreshed, then refresh the credentials and re-attempt the
@@ -233,17 +280,33 @@
self._max_refresh_attempts,
)
- auth_request_with_timeout = functools.partial(
- self._auth_request, timeout=self._refresh_timeout
- )
- self.credentials.refresh(auth_request_with_timeout)
+ if self._refresh_timeout is not None:
+ timeout = (
+ self._refresh_timeout
+ if timeout is None
+ else min(timeout, self._refresh_timeout)
+ )
- # Recurse. Pass in the original headers, not our modified set.
+ # Do not apply the timeout unconditionally in order to not override the
+ # _auth_request's default timeout.
+ auth_request = (
+ self._auth_request
+ if timeout is None
+ else functools.partial(self._auth_request, timeout=timeout)
+ )
+
+ with TimeoutGuard(timeout) as guard:
+ self.credentials.refresh(auth_request)
+ timeout = guard.remaining_timeout
+
+ # Recurse. Pass in the original headers, not our modified set, but
+ # do pass the adjusted timeout (i.e. the remaining time).
return self.request(
method,
url,
data=data,
headers=headers,
+ timeout=timeout,
_credential_refresh_attempt=_credential_refresh_attempt + 1,
**kwargs
)
diff --git a/noxfile.py b/noxfile.py
index aaf1bc5..e170ee5 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -16,6 +16,7 @@
TEST_DEPENDENCIES = [
"flask",
+ "freezegun",
"mock",
"oauth2client",
"pytest",
diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py
index 0e165ac..252e4a6 100644
--- a/tests/transport/test_requests.py
+++ b/tests/transport/test_requests.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import datetime
+import functools
+
+import freezegun
import mock
+import pytest
import requests
import requests.adapters
from six.moves import http_client
@@ -22,6 +27,12 @@
from tests.transport import compliance
+@pytest.fixture
+def frozen_time():
+ with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen:
+ yield frozen
+
+
class TestRequestResponse(compliance.RequestResponseTests):
def make_request(self):
return google.auth.transport.requests.Request()
@@ -34,6 +45,41 @@
assert http.request.call_args[1]["timeout"] == 5
+class TestTimeoutGuard(object):
+ def make_guard(self, *args, **kwargs):
+ return google.auth.transport.requests.TimeoutGuard(*args, **kwargs)
+
+ def test_tracks_elapsed_time(self, frozen_time):
+ with self.make_guard(timeout=10) as guard:
+ frozen_time.tick(delta=3.8)
+ assert guard.remaining_timeout == 6.2
+
+ def test_noop_if_no_timeout(self, frozen_time):
+ with self.make_guard(timeout=None) as guard:
+ frozen_time.tick(delta=datetime.timedelta(days=3650))
+ # NOTE: no timeout error raised, despite years have passed
+ assert guard.remaining_timeout is None
+
+ def test_error_on_timeout(self, frozen_time):
+ with pytest.raises(requests.exceptions.Timeout):
+ with self.make_guard(timeout=10) as guard:
+ frozen_time.tick(delta=10.001)
+ assert guard.remaining_timeout == pytest.approx(-0.001)
+
+ def test_custom_timeout_error_type(self, frozen_time):
+ class FooError(Exception):
+ pass
+
+ with pytest.raises(FooError):
+ with self.make_guard(timeout=1, timeout_error_type=FooError):
+ frozen_time.tick(2)
+
+ def test_lets_errors_bubble_up(self, frozen_time):
+ with pytest.raises(IndexError):
+ with self.make_guard(timeout=1):
+ [1, 2, 3][3]
+
+
class CredentialsStub(google.auth.credentials.Credentials):
def __init__(self, token="token"):
super(CredentialsStub, self).__init__()
@@ -49,6 +95,18 @@
self.token += "1"
+class TimeTickCredentialsStub(CredentialsStub):
+ """Credentials that spend some (mocked) time when refreshing a token."""
+
+ def __init__(self, time_tick, token="token"):
+ self._time_tick = time_tick
+ super(TimeTickCredentialsStub, self).__init__(token=token)
+
+ def refresh(self, request):
+ self._time_tick()
+ super(TimeTickCredentialsStub, self).refresh(requests)
+
+
class AdapterStub(requests.adapters.BaseAdapter):
def __init__(self, responses, headers=None):
super(AdapterStub, self).__init__()
@@ -69,6 +127,18 @@
return
+class TimeTickAdapterStub(AdapterStub):
+ """Adapter that spends some (mocked) time when making a request."""
+
+ def __init__(self, time_tick, responses, headers=None):
+ self._time_tick = time_tick
+ super(TimeTickAdapterStub, self).__init__(responses, headers=headers)
+
+ def send(self, request, **kwargs):
+ self._time_tick()
+ return super(TimeTickAdapterStub, self).send(request, **kwargs)
+
+
def make_response(status=http_client.OK, data=None):
response = requests.Response()
response.status_code = status
@@ -121,7 +191,9 @@
[make_response(status=http_client.UNAUTHORIZED), final_response]
)
- authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
+ authed_session = google.auth.transport.requests.AuthorizedSession(
+ credentials, refresh_timeout=60
+ )
authed_session.mount(self.TEST_URL, adapter)
result = authed_session.request("GET", self.TEST_URL)
@@ -136,3 +208,44 @@
assert adapter.requests[1].url == self.TEST_URL
assert adapter.requests[1].headers["authorization"] == "token1"
+
+ def test_request_timout(self, frozen_time):
+ tick_one_second = functools.partial(frozen_time.tick, delta=1.0)
+
+ credentials = mock.Mock(
+ wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
+ )
+ adapter = TimeTickAdapterStub(
+ time_tick=tick_one_second,
+ responses=[
+ make_response(status=http_client.UNAUTHORIZED),
+ make_response(status=http_client.OK),
+ ],
+ )
+
+ authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
+ authed_session.mount(self.TEST_URL, adapter)
+
+ # Because at least two requests have to be made, and each takes one
+ # second, the total timeout specified will be exceeded.
+ with pytest.raises(requests.exceptions.Timeout):
+ authed_session.request("GET", self.TEST_URL, timeout=1.9)
+
+ def test_request_timeout_w_refresh_timeout(self, frozen_time):
+ credentials = mock.Mock(wraps=CredentialsStub())
+ adapter = TimeTickAdapterStub(
+ time_tick=functools.partial(frozen_time.tick, delta=1.0), # one second
+ responses=[
+ make_response(status=http_client.UNAUTHORIZED),
+ make_response(status=http_client.OK),
+ ],
+ )
+
+ authed_session = google.auth.transport.requests.AuthorizedSession(
+ credentials, refresh_timeout=0.9
+ )
+ authed_session.mount(self.TEST_URL, adapter)
+
+ # The timeout is long, but the short refresh timeout will prevail.
+ with pytest.raises(requests.exceptions.Timeout):
+ authed_session.request("GET", self.TEST_URL, timeout=60)