feat: add timeout parameter to `AuthorizedSession.request()` (#406)

* feat: add timeout to AuthorisedSession.request()

* Add suport for timeout as a tuple to timeout guard

The `request.Request` class also accepts a timeout as a pair
(connect_timeout, read_timeout), and some downstream libraries use
this form.

This commit makes sure that the timeout logic correctly handles
timeouts as a two-tuple.

See also:
https://2.python-requests.org/en/master/user/advanced/#timeouts
diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py
index d1971cd..1caec0b 100644
--- a/google/auth/transport/requests.py
+++ b/google/auth/transport/requests.py
@@ -18,6 +18,8 @@
 
 import functools
 import logging
+import numbers
+import time
 
 try:
     import requests
@@ -64,6 +66,50 @@
         return self._response.content
 
 
+class TimeoutGuard(object):
+    """A context manager raising an error if the suite execution took too long.
+
+    Args:
+        timeout ([Union[None, float, Tuple[float, float]]]):
+            The maximum number of seconds a suite can run without the context
+            manager raising a timeout exception on exit. If passed as a tuple,
+            the smaller of the values is taken as a timeout. If ``None``, a
+            timeout error is never raised.
+        timeout_error_type (Optional[Exception]):
+            The type of the error to raise on timeout. Defaults to
+            :class:`requests.exceptions.Timeout`.
+    """
+
+    def __init__(self, timeout, timeout_error_type=requests.exceptions.Timeout):
+        self._timeout = timeout
+        self.remaining_timeout = timeout
+        self._timeout_error_type = timeout_error_type
+
+    def __enter__(self):
+        self._start = time.time()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if exc_value:
+            return  # let the error bubble up automatically
+
+        if self._timeout is None:
+            return  # nothing to do, the timeout was not specified
+
+        elapsed = time.time() - self._start
+        deadline_hit = False
+
+        if isinstance(self._timeout, numbers.Number):
+            self.remaining_timeout = self._timeout - elapsed
+            deadline_hit = self.remaining_timeout <= 0
+        else:
+            self.remaining_timeout = tuple(x - elapsed for x in self._timeout)
+            deadline_hit = min(self.remaining_timeout) <= 0
+
+        if deadline_hit:
+            raise self._timeout_error_type()
+
+
 class Request(transport.Request):
     """Requests request adapter.
 
@@ -193,8 +239,19 @@
         # credentials.refresh).
         self._auth_request = auth_request
 
-    def request(self, method, url, data=None, headers=None, **kwargs):
-        """Implementation of Requests' request."""
+    def request(self, method, url, data=None, headers=None, timeout=None, **kwargs):
+        """Implementation of Requests' request.
+
+        Args:
+            timeout (Optional[Union[float, Tuple[float, float]]]): The number
+                of seconds to wait before raising a ``Timeout`` exception. If
+                multiple requests are made under the hood, ``timeout`` is
+                interpreted as the approximate total time of **all** requests.
+
+                If passed as a tuple ``(connect_timeout, read_timeout)``, the
+                smaller of the values is taken as the total allowed time across
+                all requests.
+        """
         # pylint: disable=arguments-differ
         # Requests has a ton of arguments to request, but only two
         # (method, url) are required. We pass through all of the other
@@ -208,13 +265,28 @@
         # and we want to pass the original headers if we recurse.
         request_headers = headers.copy() if headers is not None else {}
 
-        self.credentials.before_request(
-            self._auth_request, method, url, request_headers
+        # Do not apply the timeout unconditionally in order to not override the
+        # _auth_request's default timeout.
+        auth_request = (
+            self._auth_request
+            if timeout is None
+            else functools.partial(self._auth_request, timeout=timeout)
         )
 
-        response = super(AuthorizedSession, self).request(
-            method, url, data=data, headers=request_headers, **kwargs
-        )
+        with TimeoutGuard(timeout) as guard:
+            self.credentials.before_request(auth_request, method, url, request_headers)
+        timeout = guard.remaining_timeout
+
+        with TimeoutGuard(timeout) as guard:
+            response = super(AuthorizedSession, self).request(
+                method,
+                url,
+                data=data,
+                headers=request_headers,
+                timeout=timeout,
+                **kwargs
+            )
+        timeout = guard.remaining_timeout
 
         # If the response indicated that the credentials needed to be
         # refreshed, then refresh the credentials and re-attempt the
@@ -233,17 +305,34 @@
                 self._max_refresh_attempts,
             )
 
-            auth_request_with_timeout = functools.partial(
-                self._auth_request, timeout=self._refresh_timeout
-            )
-            self.credentials.refresh(auth_request_with_timeout)
+            if self._refresh_timeout is not None:
+                if timeout is None:
+                    timeout = self._refresh_timeout
+                elif isinstance(timeout, numbers.Number):
+                    timeout = min(timeout, self._refresh_timeout)
+                else:
+                    timeout = tuple(min(x, self._refresh_timeout) for x in timeout)
 
-            # Recurse. Pass in the original headers, not our modified set.
+            # Do not apply the timeout unconditionally in order to not override the
+            # _auth_request's default timeout.
+            auth_request = (
+                self._auth_request
+                if timeout is None
+                else functools.partial(self._auth_request, timeout=timeout)
+            )
+
+            with TimeoutGuard(timeout) as guard:
+                self.credentials.refresh(auth_request)
+            timeout = guard.remaining_timeout
+
+            # Recurse. Pass in the original headers, not our modified set, but
+            # do pass the adjusted timeout (i.e. the remaining time).
             return self.request(
                 method,
                 url,
                 data=data,
                 headers=headers,
+                timeout=timeout,
                 _credential_refresh_attempt=_credential_refresh_attempt + 1,
                 **kwargs
             )
diff --git a/noxfile.py b/noxfile.py
index aaf1bc5..e170ee5 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -16,6 +16,7 @@
 
 TEST_DEPENDENCIES = [
     "flask",
+    "freezegun",
     "mock",
     "oauth2client",
     "pytest",
diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py
index 0e165ac..0026974 100644
--- a/tests/transport/test_requests.py
+++ b/tests/transport/test_requests.py
@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datetime
+import functools
+
+import freezegun
 import mock
+import pytest
 import requests
 import requests.adapters
 from six.moves import http_client
@@ -22,6 +27,12 @@
 from tests.transport import compliance
 
 
+@pytest.fixture
+def frozen_time():
+    with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen:
+        yield frozen
+
+
 class TestRequestResponse(compliance.RequestResponseTests):
     def make_request(self):
         return google.auth.transport.requests.Request()
@@ -34,6 +45,52 @@
         assert http.request.call_args[1]["timeout"] == 5
 
 
+class TestTimeoutGuard(object):
+    def make_guard(self, *args, **kwargs):
+        return google.auth.transport.requests.TimeoutGuard(*args, **kwargs)
+
+    def test_tracks_elapsed_time_w_numeric_timeout(self, frozen_time):
+        with self.make_guard(timeout=10) as guard:
+            frozen_time.tick(delta=3.8)
+        assert guard.remaining_timeout == 6.2
+
+    def test_tracks_elapsed_time_w_tuple_timeout(self, frozen_time):
+        with self.make_guard(timeout=(16, 19)) as guard:
+            frozen_time.tick(delta=3.8)
+        assert guard.remaining_timeout == (12.2, 15.2)
+
+    def test_noop_if_no_timeout(self, frozen_time):
+        with self.make_guard(timeout=None) as guard:
+            frozen_time.tick(delta=datetime.timedelta(days=3650))
+        # NOTE: no timeout error raised, despite years have passed
+        assert guard.remaining_timeout is None
+
+    def test_timeout_error_w_numeric_timeout(self, frozen_time):
+        with pytest.raises(requests.exceptions.Timeout):
+            with self.make_guard(timeout=10) as guard:
+                frozen_time.tick(delta=10.001)
+        assert guard.remaining_timeout == pytest.approx(-0.001)
+
+    def test_timeout_error_w_tuple_timeout(self, frozen_time):
+        with pytest.raises(requests.exceptions.Timeout):
+            with self.make_guard(timeout=(11, 10)) as guard:
+                frozen_time.tick(delta=10.001)
+        assert guard.remaining_timeout == pytest.approx((0.999, -0.001))
+
+    def test_custom_timeout_error_type(self, frozen_time):
+        class FooError(Exception):
+            pass
+
+        with pytest.raises(FooError):
+            with self.make_guard(timeout=1, timeout_error_type=FooError):
+                frozen_time.tick(2)
+
+    def test_lets_suite_errors_bubble_up(self, frozen_time):
+        with pytest.raises(IndexError):
+            with self.make_guard(timeout=1):
+                [1, 2, 3][3]
+
+
 class CredentialsStub(google.auth.credentials.Credentials):
     def __init__(self, token="token"):
         super(CredentialsStub, self).__init__()
@@ -49,6 +106,18 @@
         self.token += "1"
 
 
+class TimeTickCredentialsStub(CredentialsStub):
+    """Credentials that spend some (mocked) time when refreshing a token."""
+
+    def __init__(self, time_tick, token="token"):
+        self._time_tick = time_tick
+        super(TimeTickCredentialsStub, self).__init__(token=token)
+
+    def refresh(self, request):
+        self._time_tick()
+        super(TimeTickCredentialsStub, self).refresh(requests)
+
+
 class AdapterStub(requests.adapters.BaseAdapter):
     def __init__(self, responses, headers=None):
         super(AdapterStub, self).__init__()
@@ -69,6 +138,18 @@
         return
 
 
+class TimeTickAdapterStub(AdapterStub):
+    """Adapter that spends some (mocked) time when making a request."""
+
+    def __init__(self, time_tick, responses, headers=None):
+        self._time_tick = time_tick
+        super(TimeTickAdapterStub, self).__init__(responses, headers=headers)
+
+    def send(self, request, **kwargs):
+        self._time_tick()
+        return super(TimeTickAdapterStub, self).send(request, **kwargs)
+
+
 def make_response(status=http_client.OK, data=None):
     response = requests.Response()
     response.status_code = status
@@ -121,7 +202,9 @@
             [make_response(status=http_client.UNAUTHORIZED), final_response]
         )
 
-        authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
+        authed_session = google.auth.transport.requests.AuthorizedSession(
+            credentials, refresh_timeout=60
+        )
         authed_session.mount(self.TEST_URL, adapter)
 
         result = authed_session.request("GET", self.TEST_URL)
@@ -136,3 +219,72 @@
 
         assert adapter.requests[1].url == self.TEST_URL
         assert adapter.requests[1].headers["authorization"] == "token1"
+
+    def test_request_timeout(self, frozen_time):
+        tick_one_second = functools.partial(frozen_time.tick, delta=1.0)
+
+        credentials = mock.Mock(
+            wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
+        )
+        adapter = TimeTickAdapterStub(
+            time_tick=tick_one_second,
+            responses=[
+                make_response(status=http_client.UNAUTHORIZED),
+                make_response(status=http_client.OK),
+            ],
+        )
+
+        authed_session = google.auth.transport.requests.AuthorizedSession(credentials)
+        authed_session.mount(self.TEST_URL, adapter)
+
+        # Because at least two requests have to be made, and each takes one
+        # second, the total timeout specified will be exceeded.
+        with pytest.raises(requests.exceptions.Timeout):
+            authed_session.request("GET", self.TEST_URL, timeout=1.9)
+
+    def test_request_timeout_w_refresh_timeout(self, frozen_time):
+        tick_one_second = functools.partial(frozen_time.tick, delta=1.0)
+
+        credentials = mock.Mock(
+            wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
+        )
+        adapter = TimeTickAdapterStub(
+            time_tick=tick_one_second,
+            responses=[
+                make_response(status=http_client.UNAUTHORIZED),
+                make_response(status=http_client.OK),
+            ],
+        )
+
+        authed_session = google.auth.transport.requests.AuthorizedSession(
+            credentials, refresh_timeout=1.9
+        )
+        authed_session.mount(self.TEST_URL, adapter)
+
+        # The timeout is long, but the short refresh timeout will prevail.
+        with pytest.raises(requests.exceptions.Timeout):
+            authed_session.request("GET", self.TEST_URL, timeout=60)
+
+    def test_request_timeout_w_refresh_timeout_and_tuple_timeout(self, frozen_time):
+        tick_one_second = functools.partial(frozen_time.tick, delta=1.0)
+
+        credentials = mock.Mock(
+            wraps=TimeTickCredentialsStub(time_tick=tick_one_second)
+        )
+        adapter = TimeTickAdapterStub(
+            time_tick=tick_one_second,
+            responses=[
+                make_response(status=http_client.UNAUTHORIZED),
+                make_response(status=http_client.OK),
+            ],
+        )
+
+        authed_session = google.auth.transport.requests.AuthorizedSession(
+            credentials, refresh_timeout=100
+        )
+        authed_session.mount(self.TEST_URL, adapter)
+
+        # The shortest timeout will prevail and cause a Timeout error, despite
+        # other timeouts being quite long.
+        with pytest.raises(requests.exceptions.Timeout):
+            authed_session.request("GET", self.TEST_URL, timeout=(100, 2.9))