fix: make gRPC auth plugin non-blocking + add default timeout value for requests transport (#390)
This commit includes the following changes:
- `transport.grpc.AuthMetadataPlugin` is now non-blocking as gRPC requires
- `transport.requests.Request` now has a default timeout value of 120 seconds so that token refreshing will not be stuck
Resolves: #351
diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py
index 9a1bc6d..80f6e81 100644
--- a/google/auth/transport/grpc.py
+++ b/google/auth/transport/grpc.py
@@ -16,6 +16,8 @@
from __future__ import absolute_import
+from concurrent import futures
+
import six
try:
@@ -51,6 +53,7 @@
super(AuthMetadataPlugin, self).__init__()
self._credentials = credentials
self._request = request
+ self._pool = futures.ThreadPoolExecutor(max_workers=1)
def _get_authorization_headers(self, context):
"""Gets the authorization headers for a request.
@@ -66,6 +69,13 @@
return list(six.iteritems(headers))
+ @staticmethod
+ def _callback_wrapper(callback):
+ def wrapped(future):
+ callback(future.result(), None)
+
+ return wrapped
+
def __call__(self, context, callback):
"""Passes authorization metadata into the given callback.
@@ -74,7 +84,11 @@
callback (grpc.AuthMetadataPluginCallback): The callback that will
be invoked to pass in the authorization metadata.
"""
- callback(self._get_authorization_headers(context), None)
+ future = self._pool.submit(self._get_authorization_headers, context)
+ future.add_done_callback(self._callback_wrapper(callback))
+
+ def __del__(self):
+ self._pool.shutdown(wait=False)
def secure_authorized_channel(
diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py
index 564a0cd..d1971cd 100644
--- a/google/auth/transport/requests.py
+++ b/google/auth/transport/requests.py
@@ -95,7 +95,7 @@
self.session = session
def __call__(
- self, url, method="GET", body=None, headers=None, timeout=None, **kwargs
+ self, url, method="GET", body=None, headers=None, timeout=120, **kwargs
):
"""Make an HTTP request using requests.
diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py
index 810d038..ca12385 100644
--- a/tests/transport/test_grpc.py
+++ b/tests/transport/test_grpc.py
@@ -13,6 +13,7 @@
# limitations under the License.
import datetime
+import time
import mock
import pytest
@@ -58,6 +59,8 @@
plugin(context, callback)
+ time.sleep(2)
+
callback.assert_called_once_with(
[(u"authorization", u"Bearer {}".format(credentials.token))], None
)
@@ -76,6 +79,8 @@
plugin(context, callback)
+ time.sleep(2)
+
assert credentials.token == "token1"
callback.assert_called_once_with(
[(u"authorization", u"Bearer {}".format(credentials.token))], None