feat: add timeout parameter to `AuthorizedSession.request()` (#406)
* feat: add timeout to AuthorisedSession.request()
* Add suport for timeout as a tuple to timeout guard
The `request.Request` class also accepts a timeout as a pair
(connect_timeout, read_timeout), and some downstream libraries use
this form.
This commit makes sure that the timeout logic correctly handles
timeouts as a two-tuple.
See also:
https://2.python-requests.org/en/master/user/advanced/#timeouts
diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py
index d1971cd..1caec0b 100644
--- a/google/auth/transport/requests.py
+++ b/google/auth/transport/requests.py
@@ -18,6 +18,8 @@
import functools
import logging
+import numbers
+import time
try:
import requests
@@ -64,6 +66,50 @@
return self._response.content
+class TimeoutGuard(object):
+ """A context manager raising an error if the suite execution took too long.
+
+ Args:
+ timeout ([Union[None, float, Tuple[float, float]]]):
+ The maximum number of seconds a suite can run without the context
+ manager raising a timeout exception on exit. If passed as a tuple,
+ the smaller of the values is taken as a timeout. If ``None``, a
+ timeout error is never raised.
+ timeout_error_type (Optional[Exception]):
+ The type of the error to raise on timeout. Defaults to
+ :class:`requests.exceptions.Timeout`.
+ """
+
+ def __init__(self, timeout, timeout_error_type=requests.exceptions.Timeout):
+ self._timeout = timeout
+ self.remaining_timeout = timeout
+ self._timeout_error_type = timeout_error_type
+
+ def __enter__(self):
+ self._start = time.time()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if exc_value:
+ return # let the error bubble up automatically
+
+ if self._timeout is None:
+ return # nothing to do, the timeout was not specified
+
+ elapsed = time.time() - self._start
+ deadline_hit = False
+
+ if isinstance(self._timeout, numbers.Number):
+ self.remaining_timeout = self._timeout - elapsed
+ deadline_hit = self.remaining_timeout <= 0
+ else:
+ self.remaining_timeout = tuple(x - elapsed for x in self._timeout)
+ deadline_hit = min(self.remaining_timeout) <= 0
+
+ if deadline_hit:
+ raise self._timeout_error_type()
+
+
class Request(transport.Request):
"""Requests request adapter.
@@ -193,8 +239,19 @@
# credentials.refresh).
self._auth_request = auth_request
- def request(self, method, url, data=None, headers=None, **kwargs):
- """Implementation of Requests' request."""
+ def request(self, method, url, data=None, headers=None, timeout=None, **kwargs):
+ """Implementation of Requests' request.
+
+ Args:
+ timeout (Optional[Union[float, Tuple[float, float]]]): The number
+ of seconds to wait before raising a ``Timeout`` exception. If
+ multiple requests are made under the hood, ``timeout`` is
+ interpreted as the approximate total time of **all** requests.
+
+ If passed as a tuple ``(connect_timeout, read_timeout)``, the
+ smaller of the values is taken as the total allowed time across
+ all requests.
+ """
# pylint: disable=arguments-differ
# Requests has a ton of arguments to request, but only two
# (method, url) are required. We pass through all of the other
@@ -208,13 +265,28 @@
# and we want to pass the original headers if we recurse.
request_headers = headers.copy() if headers is not None else {}
- self.credentials.before_request(
- self._auth_request, method, url, request_headers
+ # Do not apply the timeout unconditionally in order to not override the
+ # _auth_request's default timeout.
+ auth_request = (
+ self._auth_request
+ if timeout is None
+ else functools.partial(self._auth_request, timeout=timeout)
)
- response = super(AuthorizedSession, self).request(
- method, url, data=data, headers=request_headers, **kwargs
- )
+ with TimeoutGuard(timeout) as guard:
+ self.credentials.before_request(auth_request, method, url, request_headers)
+ timeout = guard.remaining_timeout
+
+ with TimeoutGuard(timeout) as guard:
+ response = super(AuthorizedSession, self).request(
+ method,
+ url,
+ data=data,
+ headers=request_headers,
+ timeout=timeout,
+ **kwargs
+ )
+ timeout = guard.remaining_timeout
# If the response indicated that the credentials needed to be
# refreshed, then refresh the credentials and re-attempt the
@@ -233,17 +305,34 @@
self._max_refresh_attempts,
)
- auth_request_with_timeout = functools.partial(
- self._auth_request, timeout=self._refresh_timeout
- )
- self.credentials.refresh(auth_request_with_timeout)
+ if self._refresh_timeout is not None:
+ if timeout is None:
+ timeout = self._refresh_timeout
+ elif isinstance(timeout, numbers.Number):
+ timeout = min(timeout, self._refresh_timeout)
+ else:
+ timeout = tuple(min(x, self._refresh_timeout) for x in timeout)
- # Recurse. Pass in the original headers, not our modified set.
+ # Do not apply the timeout unconditionally in order to not override the
+ # _auth_request's default timeout.
+ auth_request = (
+ self._auth_request
+ if timeout is None
+ else functools.partial(self._auth_request, timeout=timeout)
+ )
+
+ with TimeoutGuard(timeout) as guard:
+ self.credentials.refresh(auth_request)
+ timeout = guard.remaining_timeout
+
+ # Recurse. Pass in the original headers, not our modified set, but
+ # do pass the adjusted timeout (i.e. the remaining time).
return self.request(
method,
url,
data=data,
headers=headers,
+ timeout=timeout,
_credential_refresh_attempt=_credential_refresh_attempt + 1,
**kwargs
)
diff --git a/noxfile.py b/noxfile.py
index aaf1bc5..e170ee5 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -16,6 +16,7 @@
TEST_DEPENDENCIES = [
"flask",
+ "freezegun",
"mock",
"oauth2client",
"pytest",
diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py
index 0e165ac..0026974 100644
--- a/tests/transport/test_requests.py
+++ b/tests/transport/test_requests.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import datetime
+import functools
+
+import freezegun
import mock
+import pytest
import requests
import requests.adapters
from six.moves import http_client
@@ -22,6 +27,12 @@
from tests.transport import compliance
+@pytest.fixture
+def frozen_time():
+ with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen:
+ yield frozen
+
+
class TestRequestResponse(compliance.RequestResponseTests):
def make_request(self):
return google.auth.transport.requests.Request()
@@ -34,6 +45,52 @@
assert http.request.call_args[1]["timeout"] == 5
+class TestTimeoutGuard(object):
+ def make_guard(self, *args, **kwargs):
+ return google.auth.transport.requests.TimeoutGuard(*args, **kwargs)
+
+ def test_tracks_elapsed_time_w_numeric_timeout(self, frozen_time):
+ with self.make_guard(timeout=10) as guard:
+ frozen_time.tick(delta=3.8)
+ assert guard.remaining_timeout == 6.2
+
+ def test_tracks_elapsed_time_w_tuple_timeout(self, frozen_time):
+ with self.make_guard(timeout=(16, 19)) as guard:
+ frozen_time.tick(delta=3.8)
+ assert guard.remaining_timeout == (12.2, 15.2)
+
+ def test_noop_if_no_timeout(self, frozen_time):
+ with self.make_guard(timeout=None) as guard:
+ frozen_time.tick(delta=datetime.timedelta(days=3650))
+ # NOTE: no timeout error raised, despite years have passed
+ assert guard.remaining_timeout is None
+
+ def test_timeout_error_w_numeric_timeout(self, frozen_time):
+ with pytest.raises(requests.exceptions.Timeout):
+ with self.make_guard(timeout=10) as guard:
+ frozen_time.tick(delta=10.001)
+ assert guard.remaining_timeout == pytest.approx(-0.001)
+
+ def test_timeout_error_w_tuple_timeout(self, frozen_time):
+ with pytest.raises(requests.exceptions.Timeout):
+ with self.make_guard(timeout=(11, 10)) as guard:
+ frozen_time.tick(delta=10.001)
+ assert guard.remaining_timeout == pytest.approx((0.999, -0.001))
+
+ def test_custom_timeout_error_type(self, frozen_time):
+ class FooError(Exception):
+ pass
+
+ with pytest.raises(FooError):
+ with self.make_guard(timeout=1, timeout_error_type=FooError):
+ frozen_time.tick(2)
+
+ def test_lets_suite_errors_bubble_up(self, frozen_time):
+ with pytest.raises(IndexError):
+ with self.make_guard(timeout=1):
+ [1, 2, 3][3]
+
+
class CredentialsStub(google.auth.credentials.Credentials):
def __init__(self, token="token"):
super(CredentialsStub, self).__init__()
@@ -49,6 +106,18 @@
self.token += "1"
+class TimeTickCredentialsStub(CredentialsStub):
+ """Credentials that spend some (mocked) time when refreshing a token."""
+
+ def __init__(self, time_tick, token="token"):
+ self._time_tick = time_tick
+ super(TimeTickCredentialsStub, self).__init__(token=token)
+
+ def refresh(self, request):
+ self._time_tick()
+ super(TimeTickCredentialsStub, self).refresh(requests)
+
+
class AdapterStub(requests.adapters.BaseAdapter):
def __init__(self, responses, headers=None):
super(AdapterStub, self).__init__()
@@ -69,6 +138,18 @@
return
+class TimeTickAdapterStub(AdapterStub):
+ """Adapter that spends some (mocked) time when making a request."""
+
+ def __init__(self, time_tick, responses, headers=None):
+ self._time_tick = time_tick
+ super(TimeTickAdapterStub, self).__init__(responses, headers=headers)
+
+ def send(self, request, **kwargs):
+ self._time_tick()
+ return super(TimeTickAdapterStub, self).send(request, **kwargs)
+
+
def make_response(status=http_client.OK, data=None):
response = requests.Response()
response.status_code = status
@@ -121,7 +202,9 @@
[make_response(status=http_client.UNAUTHORIZED), final_response]
)
- authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
+ authed_session = google.auth.transport.requests.AuthorizedSession(
+ credentials, refresh_timeout=60
+ )
authed_session.mount(self.TEST_URL, adapter)
result = authed_session.request("GET", self.TEST_URL)
@@ -136,3 +219,72 @@
assert adapter.requests[1].url == self.TEST_URL
assert adapter.requests[1].headers["authorization"] == "token1"
+
+ def test_request_timeout(self, frozen_time):
+ tick_one_second = functools.partial(frozen_time.tick, delta=1.0)
+
+ credentials = mock.Mock(
+ wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
+ )
+ adapter = TimeTickAdapterStub(
+ time_tick=tick_one_second,
+ responses=[
+ make_response(status=http_client.UNAUTHORIZED),
+ make_response(status=http_client.OK),
+ ],
+ )
+
+ authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
+ authed_session.mount(self.TEST_URL, adapter)
+
+ # Because at least two requests have to be made, and each takes one
+ # second, the total timeout specified will be exceeded.
+ with pytest.raises(requests.exceptions.Timeout):
+ authed_session.request("GET", self.TEST_URL, timeout=1.9)
+
+ def test_request_timeout_w_refresh_timeout(self, frozen_time):
+ tick_one_second = functools.partial(frozen_time.tick, delta=1.0)
+
+ credentials = mock.Mock(
+ wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
+ )
+ adapter = TimeTickAdapterStub(
+ time_tick=tick_one_second,
+ responses=[
+ make_response(status=http_client.UNAUTHORIZED),
+ make_response(status=http_client.OK),
+ ],
+ )
+
+ authed_session = google.auth.transport.requests.AuthorizedSession(
+ credentials, refresh_timeout=1.9
+ )
+ authed_session.mount(self.TEST_URL, adapter)
+
+ # The timeout is long, but the short refresh timeout will prevail.
+ with pytest.raises(requests.exceptions.Timeout):
+ authed_session.request("GET", self.TEST_URL, timeout=60)
+
+ def test_request_timeout_w_refresh_timeout_and_tuple_timeout(self, frozen_time):
+ tick_one_second = functools.partial(frozen_time.tick, delta=1.0)
+
+ credentials = mock.Mock(
+ wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
+ )
+ adapter = TimeTickAdapterStub(
+ time_tick=tick_one_second,
+ responses=[
+ make_response(status=http_client.UNAUTHORIZED),
+ make_response(status=http_client.OK),
+ ],
+ )
+
+ authed_session = google.auth.transport.requests.AuthorizedSession(
+ credentials, refresh_timeout=100
+ )
+ authed_session.mount(self.TEST_URL, adapter)
+
+ # The shortest timeout will prevail and cause a Timeout error, despite
+ # other timeouts being quite long.
+ with pytest.raises(requests.exceptions.Timeout):
+ authed_session.request("GET", self.TEST_URL, timeout=(100, 2.9))