API Core: Retry.__init__ add on_error (#8892)
diff --git a/google/api_core/retry.py b/google/api_core/retry.py
index 747c137..79343ba 100644
--- a/google/api_core/retry.py
+++ b/google/api_core/retry.py
@@ -237,12 +237,14 @@
maximum=_DEFAULT_MAXIMUM_DELAY,
multiplier=_DEFAULT_DELAY_MULTIPLIER,
deadline=_DEFAULT_DEADLINE,
+ on_error=None
):
self._predicate = predicate
self._initial = initial
self._multiplier = multiplier
self._maximum = maximum
self._deadline = deadline
+ self._on_error = on_error
def __call__(self, func, on_error=None):
"""Wrap a callable with retry behavior.
@@ -257,6 +259,8 @@
Callable: A callable that will invoke ``func`` with retry
behavior.
"""
+ if self._on_error is not None:
+ on_error = self._on_error
@general_helpers.wraps(func)
def retry_wrapped_func(*args, **kwargs):
@@ -290,6 +294,7 @@
maximum=self._maximum,
multiplier=self._multiplier,
deadline=deadline,
+ on_error=self._on_error,
)
def with_predicate(self, predicate):
@@ -308,6 +313,7 @@
maximum=self._maximum,
multiplier=self._multiplier,
deadline=self._deadline,
+ on_error=self._on_error,
)
def with_delay(self, initial=None, maximum=None, multiplier=None):
@@ -328,16 +334,18 @@
maximum=maximum if maximum is not None else self._maximum,
multiplier=multiplier if maximum is not None else self._multiplier,
deadline=self._deadline,
+ on_error=self._on_error,
)
def __str__(self):
return (
"<Retry predicate={}, initial={:.1f}, maximum={:.1f}, "
- "multiplier={:.1f}, deadline={:.1f}>".format(
+ "multiplier={:.1f}, deadline={:.1f}, on_error={}>".format(
self._predicate,
self._initial,
self._maximum,
self._multiplier,
self._deadline,
+ self._on_error,
)
)
diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py
index 53c2396..5b5e59b 100644
--- a/tests/unit/test_retry.py
+++ b/tests/unit/test_retry.py
@@ -161,20 +161,25 @@
assert retry_._maximum == 60
assert retry_._multiplier == 2
assert retry_._deadline == 120
+ assert retry_._on_error is None
def test_constructor_options(self):
+ _some_function = mock.Mock()
+
retry_ = retry.Retry(
predicate=mock.sentinel.predicate,
initial=1,
maximum=2,
multiplier=3,
deadline=4,
+ on_error=_some_function,
)
assert retry_._predicate == mock.sentinel.predicate
assert retry_._initial == 1
assert retry_._maximum == 2
assert retry_._multiplier == 3
assert retry_._deadline == 4
+ assert retry_._on_error is _some_function
def test_with_deadline(self):
retry_ = retry.Retry()
@@ -209,7 +214,8 @@
assert re.match(
(
r"<Retry predicate=<function.*?if_exception_type.*?>, "
- r"initial=1.0, maximum=60.0, multiplier=2.0, deadline=120.0>"
+ r"initial=1.0, maximum=60.0, multiplier=2.0, deadline=120.0, "
+ r"on_error=None>"
),
str(retry_),
)
@@ -230,8 +236,7 @@
target.assert_called_once_with("meep")
sleep.assert_not_called()
- # Make uniform return half of its maximum, which will be the calculated
- # sleep time.
+ # Make uniform return half of its maximum, which is the calculated sleep time.
@mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0)
@mock.patch("time.sleep", autospec=True)
def test___call___and_execute_retry(self, sleep, uniform):
@@ -253,3 +258,55 @@
target.assert_has_calls([mock.call("meep"), mock.call("meep")])
sleep.assert_called_once_with(retry_._initial)
assert on_error.call_count == 1
+
+ @mock.patch("time.sleep", autospec=True)
+ def test___init___without_retry_executed(self, sleep):
+ _some_function = mock.Mock()
+
+ retry_ = retry.Retry(
+ predicate=retry.if_exception_type(ValueError), on_error=_some_function
+ )
+ # check the proper creation of the class
+ assert retry_._on_error is _some_function
+
+ target = mock.Mock(spec=["__call__"], side_effect=[42])
+ # __name__ is needed by functools.partial.
+ target.__name__ = "target"
+
+ wrapped = retry_(target)
+
+ result = wrapped("meep")
+
+ assert result == 42
+ target.assert_called_once_with("meep")
+ sleep.assert_not_called()
+ _some_function.assert_not_called()
+
+ # Make uniform return half of its maximum, which is the calculated sleep time.
+ @mock.patch("random.uniform", autospec=True, side_effect=lambda m, n: n / 2.0)
+ @mock.patch("time.sleep", autospec=True)
+ def test___init___when_retry_is_executed(self, sleep, uniform):
+ _some_function = mock.Mock()
+
+ retry_ = retry.Retry(
+ predicate=retry.if_exception_type(ValueError), on_error=_some_function
+ )
+ # check the proper creation of the class
+ assert retry_._on_error is _some_function
+
+ target = mock.Mock(
+ spec=["__call__"], side_effect=[ValueError(), ValueError(), 42]
+ )
+ # __name__ is needed by functools.partial.
+ target.__name__ = "target"
+
+ wrapped = retry_(target)
+ target.assert_not_called()
+
+ result = wrapped("meep")
+
+ assert result == 42
+ assert target.call_count == 3
+ assert _some_function.call_count == 2
+ target.assert_has_calls([mock.call("meep"), mock.call("meep")])
+ sleep.assert_any_call(retry_._initial)