Firestore: Add 'should_terminate' predicate for clean BiDi shutdown. (#8650)
Closes #7826.
diff --git a/google/api_core/bidi.py b/google/api_core/bidi.py
index 3b69e91..7d3716d 100644
--- a/google/api_core/bidi.py
+++ b/google/api_core/bidi.py
@@ -349,6 +349,11 @@
return self._request_queue.qsize()
+def _never_terminate(future_or_error):
+ """By default, no errors cause BiDi termination."""
+ return False
+
+
class ResumableBidiRpc(BidiRpc):
"""A :class:`BidiRpc` that can automatically resume the stream on errors.
@@ -391,6 +396,9 @@
should_recover (Callable[[Exception], bool]): A function that returns
True if the stream should be recovered. This will be called
whenever an error is encountered on the stream.
+ should_terminate (Callable[[Exception], bool]): A function that returns
+ True if the stream should be terminated. This will be called
+ whenever an error is encountered on the stream.
metadata Sequence[Tuple(str, str)]: RPC metadata to include in
the request.
throttle_reopen (bool): If ``True``, throttling will be applied to
@@ -401,12 +409,14 @@
self,
start_rpc,
should_recover,
+ should_terminate=_never_terminate,
initial_request=None,
metadata=None,
throttle_reopen=False,
):
super(ResumableBidiRpc, self).__init__(start_rpc, initial_request, metadata)
self._should_recover = should_recover
+ self._should_terminate = should_terminate
self._operational_lock = threading.RLock()
self._finalized = False
self._finalize_lock = threading.Lock()
@@ -433,7 +443,9 @@
# error, not for errors that we can recover from. Note that grpc's
# "future" here is also a grpc.RpcError.
with self._operational_lock:
- if not self._should_recover(future):
+ if self._should_terminate(future):
+ self._finalize(future)
+ elif not self._should_recover(future):
self._finalize(future)
else:
_LOGGER.debug("Re-opening stream from gRPC callback.")
@@ -496,6 +508,12 @@
with self._operational_lock:
_LOGGER.debug("Call to retryable %r caused %s.", method, exc)
+ if self._should_terminate(exc):
+ self.close()
+ _LOGGER.debug("Terminating %r due to %s.", method, exc)
+ self._finalize(exc)
+ break
+
if not self._should_recover(exc):
self.close()
_LOGGER.debug("Not retrying %r due to %s.", method, exc)
diff --git a/tests/unit/test_bidi.py b/tests/unit/test_bidi.py
index 8e9f262..4d185d3 100644
--- a/tests/unit/test_bidi.py
+++ b/tests/unit/test_bidi.py
@@ -370,16 +370,65 @@
class TestResumableBidiRpc(object):
- def test_initial_state(self):
- callback = mock.Mock()
- callback.return_value = True
- bidi_rpc = bidi.ResumableBidiRpc(None, callback)
+ def test_ctor_defaults(self):
+ start_rpc = mock.Mock()
+ should_recover = mock.Mock()
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
assert bidi_rpc.is_active is False
+ assert bidi_rpc._finalized is False
+ assert bidi_rpc._start_rpc is start_rpc
+ assert bidi_rpc._should_recover is should_recover
+ assert bidi_rpc._should_terminate is bidi._never_terminate
+ assert bidi_rpc._initial_request is None
+ assert bidi_rpc._rpc_metadata is None
+ assert bidi_rpc._reopen_throttle is None
+
+ def test_ctor_explicit(self):
+ start_rpc = mock.Mock()
+ should_recover = mock.Mock()
+ should_terminate = mock.Mock()
+ initial_request = mock.Mock()
+ metadata = {"x-foo": "bar"}
+ bidi_rpc = bidi.ResumableBidiRpc(
+ start_rpc,
+ should_recover,
+ should_terminate=should_terminate,
+ initial_request=initial_request,
+ metadata=metadata,
+ throttle_reopen=True,
+ )
+
+ assert bidi_rpc.is_active is False
+ assert bidi_rpc._finalized is False
+ assert bidi_rpc._should_recover is should_recover
+ assert bidi_rpc._should_terminate is should_terminate
+ assert bidi_rpc._initial_request is initial_request
+ assert bidi_rpc._rpc_metadata == metadata
+ assert isinstance(bidi_rpc._reopen_throttle, bidi._Throttle)
+
+ def test_done_callbacks_terminate(self):
+ cancellation = mock.Mock()
+ start_rpc = mock.Mock()
+ should_recover = mock.Mock(spec=["__call__"], return_value=True)
+ should_terminate = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(
+ start_rpc, should_recover, should_terminate=should_terminate
+ )
+ callback = mock.Mock(spec=["__call__"])
+
+ bidi_rpc.add_done_callback(callback)
+ bidi_rpc._on_call_done(cancellation)
+
+ should_terminate.assert_called_once_with(cancellation)
+ should_recover.assert_not_called()
+ callback.assert_called_once_with(cancellation)
+ assert not bidi_rpc.is_active
def test_done_callbacks_recoverable(self):
start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
- bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True)
+ should_recover = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
callback = mock.Mock(spec=["__call__"])
bidi_rpc.add_done_callback(callback)
@@ -387,16 +436,45 @@
callback.assert_not_called()
start_rpc.assert_called_once()
+ should_recover.assert_called_once_with(mock.sentinel.future)
assert bidi_rpc.is_active
def test_done_callbacks_non_recoverable(self):
- bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
+ start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
+ should_recover = mock.Mock(spec=["__call__"], return_value=False)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
callback = mock.Mock(spec=["__call__"])
bidi_rpc.add_done_callback(callback)
bidi_rpc._on_call_done(mock.sentinel.future)
callback.assert_called_once_with(mock.sentinel.future)
+ should_recover.assert_called_once_with(mock.sentinel.future)
+ assert not bidi_rpc.is_active
+
+ def test_send_terminate(self):
+ cancellation = ValueError()
+ call_1 = CallStub([cancellation], active=False)
+ call_2 = CallStub([])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2]
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=False)
+ should_terminate = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)
+
+ bidi_rpc.open()
+
+ bidi_rpc.send(mock.sentinel.request)
+
+ assert bidi_rpc.pending_requests == 1
+ assert bidi_rpc._request_queue.get() is None
+
+ should_recover.assert_not_called()
+ should_terminate.assert_called_once_with(cancellation)
+ assert bidi_rpc.call == call_1
+ assert bidi_rpc.is_active is False
+ assert call_1.cancelled is True
def test_send_recover(self):
error = ValueError()
@@ -441,6 +519,26 @@
assert bidi_rpc.pending_requests == 1
assert bidi_rpc._request_queue.get() is None
+ def test_recv_terminate(self):
+ cancellation = ValueError()
+ call = CallStub([cancellation])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True, return_value=call
+ )
+ should_recover = mock.Mock(spec=["__call__"], return_value=False)
+ should_terminate = mock.Mock(spec=["__call__"], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)
+
+ bidi_rpc.open()
+
+ bidi_rpc.recv()
+
+ should_recover.assert_not_called()
+ should_terminate.assert_called_once_with(cancellation)
+ assert bidi_rpc.call == call
+ assert bidi_rpc.is_active is False
+ assert call.cancelled is True
+
def test_recv_recover(self):
error = ValueError()
call_1 = CallStub([1, error])