feat(api-core): pass retry from result() to done() (#9)
Towards: https://github.com/googleapis/python-bigquery/issues/24
diff --git a/google/api_core/future/polling.py b/google/api_core/future/polling.py
index 6b4c687..6466838 100644
--- a/google/api_core/future/polling.py
+++ b/google/api_core/future/polling.py
@@ -78,16 +78,18 @@
# pylint: disable=redundant-returns-doc, missing-raises-doc
raise NotImplementedError()
- def _done_or_raise(self):
+ def _done_or_raise(self, retry=DEFAULT_RETRY):
"""Check if the future is done and raise if it's not."""
- if not self.done():
+ kwargs = {} if retry is DEFAULT_RETRY else {"retry": retry}
+
+ if not self.done(**kwargs):
raise _OperationNotComplete()
def running(self):
"""True if the operation is currently running."""
return not self.done()
- def _blocking_poll(self, timeout=None):
+ def _blocking_poll(self, timeout=None, retry=DEFAULT_RETRY):
"""Poll and wait for the Future to be resolved.
Args:
@@ -101,13 +103,14 @@
retry_ = self._retry.with_deadline(timeout)
try:
- retry_(self._done_or_raise)()
+ kwargs = {} if retry is DEFAULT_RETRY else {"retry": retry}
+ retry_(self._done_or_raise)(**kwargs)
except exceptions.RetryError:
raise concurrent.futures.TimeoutError(
"Operation did not complete within the designated " "timeout."
)
- def result(self, timeout=None):
+ def result(self, timeout=None, retry=DEFAULT_RETRY):
"""Get the result of the operation, blocking if necessary.
Args:
@@ -122,7 +125,8 @@
google.api_core.GoogleAPICallError: If the operation errors or if
the timeout is reached before the operation completes.
"""
- self._blocking_poll(timeout=timeout)
+ kwargs = {} if retry is DEFAULT_RETRY else {"retry": retry}
+ self._blocking_poll(timeout=timeout, **kwargs)
if self._exception is not None:
# pylint: disable=raising-bad-type
diff --git a/tests/unit/future/test_polling.py b/tests/unit/future/test_polling.py
index c67de06..2381d03 100644
--- a/tests/unit/future/test_polling.py
+++ b/tests/unit/future/test_polling.py
@@ -19,7 +19,7 @@
import mock
import pytest
-from google.api_core import exceptions
+from google.api_core import exceptions, retry
from google.api_core.future import polling
@@ -43,6 +43,8 @@
assert not future.cancelled()
assert future.running()
assert future.cancel()
+ with mock.patch.object(future, "done", return_value=True):
+ future.result()
def test_set_result():
@@ -87,7 +89,7 @@
self.poll_count = 0
self.event = threading.Event()
- def done(self):
+ def done(self, retry=polling.DEFAULT_RETRY):
self.poll_count += 1
self.event.wait()
self.set_result(42)
@@ -108,7 +110,7 @@
class PollingFutureImplTimeout(PollingFutureImplWithPoll):
- def done(self):
+ def done(self, retry=polling.DEFAULT_RETRY):
time.sleep(1)
return False
@@ -130,7 +132,7 @@
super(PollingFutureImplTransient, self).__init__()
self._errors = errors
- def done(self):
+ def done(self, retry=polling.DEFAULT_RETRY):
if self._errors:
error, self._errors = self._errors[0], self._errors[1:]
raise error("testing")
@@ -192,3 +194,49 @@
assert future.poll_count == 1
callback.assert_called_once_with(future)
callback2.assert_called_once_with(future)
+
+
+class PollingFutureImplWithoutRetry(PollingFutureImpl):
+ def done(self):
+ return True
+
+ def result(self):
+ return super(PollingFutureImplWithoutRetry, self).result()
+
+ def _blocking_poll(self, timeout):
+ return super(PollingFutureImplWithoutRetry, self)._blocking_poll(
+ timeout=timeout
+ )
+
+
+class PollingFutureImplWith_done_or_raise(PollingFutureImpl):
+ def done(self):
+ return True
+
+ def _done_or_raise(self):
+ return super(PollingFutureImplWith_done_or_raise, self)._done_or_raise()
+
+
+def test_polling_future_without_retry():
+ custom_retry = retry.Retry(
+ predicate=retry.if_exception_type(exceptions.TooManyRequests)
+ )
+ future = PollingFutureImplWithoutRetry()
+ assert future.done()
+ assert future.running()
+ assert future.result() is None
+
+ with mock.patch.object(future, "done") as done_mock:
+ future._done_or_raise()
+ done_mock.assert_called_once_with()
+
+ with mock.patch.object(future, "done") as done_mock:
+ future._done_or_raise(retry=custom_retry)
+ done_mock.assert_called_once_with(retry=custom_retry)
+
+
+def test_polling_future_with__done_or_raise():
+ future = PollingFutureImplWith_done_or_raise()
+ assert future.done()
+ assert future.running()
+ assert future.result() is None