Firestore: Add 'should_terminate' predicate for clean BiDi shutdown. (#8650)

Closes #7826.
diff --git a/google/api_core/bidi.py b/google/api_core/bidi.py
index 3b69e91..7d3716d 100644
--- a/google/api_core/bidi.py
+++ b/google/api_core/bidi.py
@@ -349,6 +349,11 @@
         return self._request_queue.qsize()
 
 
+def _never_terminate(future_or_error):
+    """By default, no errors cause BiDi termination."""
+    return False
+
+
 class ResumableBidiRpc(BidiRpc):
     """A :class:`BidiRpc` that can automatically resume the stream on errors.
 
@@ -391,6 +396,9 @@
         should_recover (Callable[[Exception], bool]): A function that returns
             True if the stream should be recovered. This will be called
             whenever an error is encountered on the stream.
+        should_terminate (Callable[[Exception], bool]): A function that returns
+            True if the stream should be terminated. This will be called
+            whenever an error is encountered on the stream.
         metadata Sequence[Tuple(str, str)]: RPC metadata to include in
             the request.
         throttle_reopen (bool): If ``True``, throttling will be applied to
@@ -401,12 +409,14 @@
         self,
         start_rpc,
         should_recover,
+        should_terminate=_never_terminate,
         initial_request=None,
         metadata=None,
         throttle_reopen=False,
     ):
         super(ResumableBidiRpc, self).__init__(start_rpc, initial_request, metadata)
         self._should_recover = should_recover
+        self._should_terminate = should_terminate
         self._operational_lock = threading.RLock()
         self._finalized = False
         self._finalize_lock = threading.Lock()
@@ -433,7 +443,9 @@
         # error, not for errors that we can recover from. Note that grpc's
         # "future" here is also a grpc.RpcError.
         with self._operational_lock:
-            if not self._should_recover(future):
+            if self._should_terminate(future):
+                self._finalize(future)
+            elif not self._should_recover(future):
                 self._finalize(future)
             else:
                 _LOGGER.debug("Re-opening stream from gRPC callback.")
@@ -496,6 +508,12 @@
                 with self._operational_lock:
                     _LOGGER.debug("Call to retryable %r caused %s.", method, exc)
 
+                    if self._should_terminate(exc):
+                        self.close()
+                        _LOGGER.debug("Terminating %r due to %s.", method, exc)
+                        self._finalize(exc)
+                        break
+
                     if not self._should_recover(exc):
                         self.close()
                         _LOGGER.debug("Not retrying %r due to %s.", method, exc)
diff --git a/tests/unit/test_bidi.py b/tests/unit/test_bidi.py
index 8e9f262..4d185d3 100644
--- a/tests/unit/test_bidi.py
+++ b/tests/unit/test_bidi.py
@@ -370,16 +370,65 @@
 
 
 class TestResumableBidiRpc(object):
-    def test_initial_state(self):
-        callback = mock.Mock()
-        callback.return_value = True
-        bidi_rpc = bidi.ResumableBidiRpc(None, callback)
+    def test_ctor_defaults(self):
+        start_rpc = mock.Mock()
+        should_recover = mock.Mock()
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
 
         assert bidi_rpc.is_active is False
+        assert bidi_rpc._finalized is False
+        assert bidi_rpc._start_rpc is start_rpc
+        assert bidi_rpc._should_recover is should_recover
+        assert bidi_rpc._should_terminate is bidi._never_terminate
+        assert bidi_rpc._initial_request is None
+        assert bidi_rpc._rpc_metadata is None
+        assert bidi_rpc._reopen_throttle is None
+
+    def test_ctor_explicit(self):
+        start_rpc = mock.Mock()
+        should_recover = mock.Mock()
+        should_terminate = mock.Mock()
+        initial_request = mock.Mock()
+        metadata = {"x-foo": "bar"}
+        bidi_rpc = bidi.ResumableBidiRpc(
+            start_rpc,
+            should_recover,
+            should_terminate=should_terminate,
+            initial_request=initial_request,
+            metadata=metadata,
+            throttle_reopen=True,
+        )
+
+        assert bidi_rpc.is_active is False
+        assert bidi_rpc._finalized is False
+        assert bidi_rpc._should_recover is should_recover
+        assert bidi_rpc._should_terminate is should_terminate
+        assert bidi_rpc._initial_request is initial_request
+        assert bidi_rpc._rpc_metadata == metadata
+        assert isinstance(bidi_rpc._reopen_throttle, bidi._Throttle)
+
+    def test_done_callbacks_terminate(self):
+        cancellation = mock.Mock()
+        start_rpc = mock.Mock()
+        should_recover = mock.Mock(spec=["__call__"], return_value=True)
+        should_terminate = mock.Mock(spec=["__call__"], return_value=True)
+        bidi_rpc = bidi.ResumableBidiRpc(
+            start_rpc, should_recover, should_terminate=should_terminate
+        )
+        callback = mock.Mock(spec=["__call__"])
+
+        bidi_rpc.add_done_callback(callback)
+        bidi_rpc._on_call_done(cancellation)
+
+        should_terminate.assert_called_once_with(cancellation)
+        should_recover.assert_not_called()
+        callback.assert_called_once_with(cancellation)
+        assert not bidi_rpc.is_active
 
     def test_done_callbacks_recoverable(self):
         start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
-        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True)
+        should_recover = mock.Mock(spec=["__call__"], return_value=True)
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
         callback = mock.Mock(spec=["__call__"])
 
         bidi_rpc.add_done_callback(callback)
@@ -387,16 +436,45 @@
 
         callback.assert_not_called()
         start_rpc.assert_called_once()
+        should_recover.assert_called_once_with(mock.sentinel.future)
         assert bidi_rpc.is_active
 
     def test_done_callbacks_non_recoverable(self):
-        bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
+        start_rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
+        should_recover = mock.Mock(spec=["__call__"], return_value=False)
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
         callback = mock.Mock(spec=["__call__"])
 
         bidi_rpc.add_done_callback(callback)
         bidi_rpc._on_call_done(mock.sentinel.future)
 
         callback.assert_called_once_with(mock.sentinel.future)
+        should_recover.assert_called_once_with(mock.sentinel.future)
+        assert not bidi_rpc.is_active
+
+    def test_send_terminate(self):
+        cancellation = ValueError()
+        call_1 = CallStub([cancellation], active=False)
+        call_2 = CallStub([])
+        start_rpc = mock.create_autospec(
+            grpc.StreamStreamMultiCallable, instance=True, side_effect=[call_1, call_2]
+        )
+        should_recover = mock.Mock(spec=["__call__"], return_value=False)
+        should_terminate = mock.Mock(spec=["__call__"], return_value=True)
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)
+
+        bidi_rpc.open()
+
+        bidi_rpc.send(mock.sentinel.request)
+
+        assert bidi_rpc.pending_requests == 1
+        assert bidi_rpc._request_queue.get() is None
+
+        should_recover.assert_not_called()
+        should_terminate.assert_called_once_with(cancellation)
+        assert bidi_rpc.call == call_1
+        assert bidi_rpc.is_active is False
+        assert call_1.cancelled is True
 
     def test_send_recover(self):
         error = ValueError()
@@ -441,6 +519,26 @@
         assert bidi_rpc.pending_requests == 1
         assert bidi_rpc._request_queue.get() is None
 
+    def test_recv_terminate(self):
+        cancellation = ValueError()
+        call = CallStub([cancellation])
+        start_rpc = mock.create_autospec(
+            grpc.StreamStreamMultiCallable, instance=True, return_value=call
+        )
+        should_recover = mock.Mock(spec=["__call__"], return_value=False)
+        should_terminate = mock.Mock(spec=["__call__"], return_value=True)
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover, should_terminate=should_terminate)
+
+        bidi_rpc.open()
+
+        bidi_rpc.recv()
+
+        should_recover.assert_not_called()
+        should_terminate.assert_called_once_with(cancellation)
+        assert bidi_rpc.call == call
+        assert bidi_rpc.is_active is False
+        assert call.cancelled is True
+
     def test_recv_recover(self):
         error = ValueError()
         call_1 = CallStub([1, error])