Move bidi to api-core (#6211)


diff --git a/google/api_core/bidi.py b/google/api_core/bidi.py
new file mode 100644
index 0000000..7c995c5
--- /dev/null
+++ b/google/api_core/bidi.py
@@ -0,0 +1,597 @@
+# Copyright 2017, Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Bi-directional streaming RPC helpers."""
+
+import logging
+import threading
+
+from six.moves import queue
+
+from google.api_core import exceptions
+
+_LOGGER = logging.getLogger(__name__)
+_BIDIRECTIONAL_CONSUMER_NAME = 'Thread-ConsumeBidirectionalStream'
+
+
+class _RequestQueueGenerator(object):
+    """A helper for sending requests to a gRPC stream from a Queue.
+
+    This generator takes requests off a given queue and yields them to gRPC.
+
+    This helper is useful when you have an indeterminate, indefinite, or
+    otherwise open-ended set of requests to send through a request-streaming
+    (or bidirectional) RPC.
+
+    The reason this is necessary is because gRPC takes an iterator as the
+    request for request-streaming RPCs. gRPC consumes this iterator in another
+    thread to allow it to block while generating requests for the stream.
+    However, if the generator blocks indefinitely gRPC will not be able to
+    clean up the thread as it'll be blocked on `next(iterator)` and not be able
+    to check the channel status to stop iterating. This helper mitigates that
+    by waiting on the queue with a timeout and checking the RPC state before
+    yielding.
+
+    Finally, it allows for retrying without swapping queues because if it does
+    pull an item off the queue when the RPC is inactive, it'll immediately put
+    it back and then exit. This is necessary because yielding the item in this
+    case will cause gRPC to discard it. In practice, this means that the order
+    of messages is not guaranteed. If such a thing is necessary it would be
+    easy to use a priority queue.
+
+    Example::
+
+        requests = request_queue_generator(q)
+        call = stub.StreamingRequest(iter(requests))
+        requests.call = call
+
+        for response in call:
+            print(response)
+            q.put(...)
+
+    Note that it is possible to accomplish this behavior without "spinning"
+    (using a queue timeout). One possible way would be to use more threads to
+    multiplex the grpc end event with the queue, another possible way is to
+    use selectors and a custom event/queue object. Both of these approaches
+    are significant from an engineering perspective for small benefit - the
+    CPU consumed by spinning is pretty minuscule.
+
+    Args:
+        queue (queue.Queue): The request queue.
+        period (float): The number of seconds to wait for items from the queue
+            before checking if the RPC is cancelled. In practice, this
+            determines the maximum amount of time the request consumption
+            thread will live after the RPC is cancelled.
+        initial_request (Union[protobuf.Message,
+                Callable[None, protobuf.Message]]): The initial request to
+            yield. This is done independently of the request queue to allow fo
+            easily restarting streams that require some initial configuration
+            request.
+    """
+    def __init__(self, queue, period=1, initial_request=None):
+        self._queue = queue
+        self._period = period
+        self._initial_request = initial_request
+        self.call = None
+
+    def _is_active(self):
+        # Note: there is a possibility that this starts *before* the call
+        # property is set. So we have to check if self.call is set before
+        # seeing if it's active.
+        if self.call is not None and not self.call.is_active():
+            return False
+        else:
+            return True
+
+    def __iter__(self):
+        if self._initial_request is not None:
+            if callable(self._initial_request):
+                yield self._initial_request()
+            else:
+                yield self._initial_request
+
+        while True:
+            try:
+                item = self._queue.get(timeout=self._period)
+            except queue.Empty:
+                if not self._is_active():
+                    _LOGGER.debug(
+                        'Empty queue and inactive call, exiting request '
+                        'generator.')
+                    return
+                else:
+                    # call is still active, keep waiting for queue items.
+                    continue
+
+            # The consumer explicitly sent "None", indicating that the request
+            # should end.
+            if item is None:
+                _LOGGER.debug('Cleanly exiting request generator.')
+                return
+
+            if not self._is_active():
+                # We have an item, but the call is closed. We should put the
+                # item back on the queue so that the next call can consume it.
+                self._queue.put(item)
+                _LOGGER.debug(
+                    'Inactive call, replacing item on queue and exiting '
+                    'request generator.')
+                return
+
+            yield item
+
+
+class BidiRpc(object):
+    """A helper for consuming a bi-directional streaming RPC.
+
+    This maps gRPC's built-in interface which uses a request iterator and a
+    response iterator into a socket-like :func:`send` and :func:`recv`. This
+    is a more useful pattern for long-running or asymmetric streams (streams
+    where there is not a direct correlation between the requests and
+    responses).
+
+    Example::
+
+        initial_request = example_pb2.StreamingRpcRequest(
+            setting='example')
+        rpc = BidiRpc(stub.StreamingRpc, initial_request=initial_request)
+
+        rpc.open()
+
+        while rpc.is_active():
+            print(rpc.recv())
+            rpc.send(example_pb2.StreamingRpcRequest(
+                data='example'))
+
+    This does *not* retry the stream on errors. See :class:`ResumableBidiRpc`.
+
+    Args:
+        start_rpc (grpc.StreamStreamMultiCallable): The gRPC method used to
+            start the RPC.
+        initial_request (Union[protobuf.Message,
+                Callable[None, protobuf.Message]]): The initial request to
+            yield. This is useful if an initial request is needed to start the
+            stream.
+    """
+    def __init__(self, start_rpc, initial_request=None):
+        self._start_rpc = start_rpc
+        self._initial_request = initial_request
+        self._request_queue = queue.Queue()
+        self._request_generator = None
+        self._is_active = False
+        self._callbacks = []
+        self.call = None
+
+    def add_done_callback(self, callback):
+        """Adds a callback that will be called when the RPC terminates.
+
+        This occurs when the RPC errors or is successfully terminated.
+
+        Args:
+            callback (Callable[[grpc.Future], None]): The callback to execute.
+                It will be provided with the same gRPC future as the underlying
+                stream which will also be a :class:`grpc.Call`.
+        """
+        self._callbacks.append(callback)
+
+    def _on_call_done(self, future):
+        for callback in self._callbacks:
+            callback(future)
+
+    def open(self):
+        """Opens the stream."""
+        if self.is_active:
+            raise ValueError('Can not open an already open stream.')
+
+        request_generator = _RequestQueueGenerator(
+            self._request_queue, initial_request=self._initial_request)
+        call = self._start_rpc(iter(request_generator))
+
+        request_generator.call = call
+
+        # TODO: api_core should expose the future interface for wrapped
+        # callables as well.
+        if hasattr(call, '_wrapped'):  # pragma: NO COVER
+            call._wrapped.add_done_callback(self._on_call_done)
+        else:
+            call.add_done_callback(self._on_call_done)
+
+        self._request_generator = request_generator
+        self.call = call
+
+    def close(self):
+        """Closes the stream."""
+        if self.call is None:
+            return
+
+        self._request_queue.put(None)
+        self.call.cancel()
+        self._request_generator = None
+        # Don't set self.call to None. Keep it around so that send/recv can
+        # raise the error.
+
+    def send(self, request):
+        """Queue a message to be sent on the stream.
+
+        Send is non-blocking.
+
+        If the underlying RPC has been closed, this will raise.
+
+        Args:
+            request (protobuf.Message): The request to send.
+        """
+        if self.call is None:
+            raise ValueError(
+                'Can not send() on an RPC that has never been open()ed.')
+
+        # Don't use self.is_active(), as ResumableBidiRpc will overload it
+        # to mean something semantically different.
+        if self.call.is_active():
+            self._request_queue.put(request)
+        else:
+            # calling next should cause the call to raise.
+            next(self.call)
+
+    def recv(self):
+        """Wait for a message to be returned from the stream.
+
+        Recv is blocking.
+
+        If the underlying RPC has been closed, this will raise.
+
+        Returns:
+            protobuf.Message: The received message.
+        """
+        if self.call is None:
+            raise ValueError(
+                'Can not recv() on an RPC that has never been open()ed.')
+
+        return next(self.call)
+
+    @property
+    def is_active(self):
+        """bool: True if this stream is currently open and active."""
+        return self.call is not None and self.call.is_active()
+
+    @property
+    def pending_requests(self):
+        """int: Returns an estimate of the number of queued requests."""
+        return self._request_queue.qsize()
+
+
+class ResumableBidiRpc(BidiRpc):
+    """A :class:`BidiRpc` that can automatically resume the stream on errors.
+
+    It uses the ``should_recover`` arg to determine if it should re-establish
+    the stream on error.
+
+    Example::
+
+        def should_recover(exc):
+            return (
+                isinstance(exc, grpc.RpcError) and
+                exc.code() == grpc.StatusCode.UNVAILABLE)
+
+        initial_request = example_pb2.StreamingRpcRequest(
+            setting='example')
+
+        rpc = ResumeableBidiRpc(
+            stub.StreamingRpc,
+            initial_request=initial_request,
+            should_recover=should_recover)
+
+        rpc.open()
+
+        while rpc.is_active():
+            print(rpc.recv())
+            rpc.send(example_pb2.StreamingRpcRequest(
+                data='example'))
+
+    Args:
+        start_rpc (grpc.StreamStreamMultiCallable): The gRPC method used to
+            start the RPC.
+        initial_request (Union[protobuf.Message,
+                Callable[None, protobuf.Message]]): The initial request to
+            yield. This is useful if an initial request is needed to start the
+            stream.
+        should_recover (Callable[[Exception], bool]): A function that returns
+            True if the stream should be recovered. This will be called
+            whenever an error is encountered on the stream.
+    """
+    def __init__(self, start_rpc, should_recover, initial_request=None):
+        super(ResumableBidiRpc, self).__init__(start_rpc, initial_request)
+        self._should_recover = should_recover
+        self._operational_lock = threading.RLock()
+        self._finalized = False
+        self._finalize_lock = threading.Lock()
+
+    def _finalize(self, result):
+        with self._finalize_lock:
+            if self._finalized:
+                return
+
+            for callback in self._callbacks:
+                callback(result)
+
+            self._finalized = True
+
+    def _on_call_done(self, future):
+        # Unlike the base class, we only execute the callbacks on a terminal
+        # error, not for errors that we can recover from. Note that grpc's
+        # "future" here is also a grpc.RpcError.
+        with self._operational_lock:
+            if not self._should_recover(future):
+                self._finalize(future)
+            else:
+                _LOGGER.debug('Re-opening stream from gRPC callback.')
+                self._reopen()
+
+    def _reopen(self):
+        with self._operational_lock:
+            # Another thread already managed to re-open this stream.
+            if self.call is not None and self.call.is_active():
+                _LOGGER.debug('Stream was already re-established.')
+                return
+
+            self.call = None
+            # Request generator should exit cleanly since the RPC its bound to
+            # has exited.
+            self.request_generator = None
+
+            # Note: we do not currently do any sort of backoff here. The
+            # assumption is that re-establishing the stream under normal
+            # circumstances will happen in intervals greater than 60s.
+            # However, it is possible in a degenerative case that the server
+            # closes the stream rapidly which would lead to thrashing here,
+            # but hopefully in those cases the server would return a non-
+            # retryable error.
+
+            try:
+                self.open()
+            # If re-opening or re-calling the method fails for any reason,
+            # consider it a terminal error and finalize the stream.
+            except Exception as exc:
+                _LOGGER.debug('Failed to re-open stream due to %s', exc)
+                self._finalize(exc)
+                raise
+
+            _LOGGER.info('Re-established stream')
+
+    def _recoverable(self, method, *args, **kwargs):
+        """Wraps a method to recover the stream and retry on error.
+
+        If a retryable error occurs while making the call, then the stream will
+        be re-opened and the method will be retried. This happens indefinitely
+        so long as the error is a retryable one. If an error occurs while
+        re-opening the stream, then this method will raise immediately and
+        trigger finalization of this object.
+
+        Args:
+            method (Callable[..., Any]): The method to call.
+            args: The args to pass to the method.
+            kwargs: The kwargs to pass to the method.
+        """
+        while True:
+            try:
+                return method(*args, **kwargs)
+
+            except Exception as exc:
+                with self._operational_lock:
+                    _LOGGER.debug(
+                        'Call to retryable %r caused %s.', method, exc)
+
+                    if not self._should_recover(exc):
+                        self.close()
+                        _LOGGER.debug(
+                            'Not retrying %r due to %s.', method, exc)
+                        self._finalize(exc)
+                        raise exc
+
+                    _LOGGER.debug(
+                        'Re-opening stream from retryable %r.', method)
+                    self._reopen()
+
+    def _send(self, request):
+        # Grab a reference to the RPC call. Because another thread (notably
+        # the gRPC error thread) can modify self.call (by invoking reopen),
+        # we should ensure our reference can not change underneath us.
+        # If self.call is modified (such as replaced with a new RPC call) then
+        # this will use the "old" RPC, which should result in the same
+        # exception passed into gRPC's error handler being raised here, which
+        # will be handled by the usual error handling in retryable.
+        with self._operational_lock:
+            call = self.call
+
+        if call is None:
+            raise ValueError(
+                'Can not send() on an RPC that has never been open()ed.')
+
+        # Don't use self.is_active(), as ResumableBidiRpc will overload it
+        # to mean something semantically different.
+        if call.is_active():
+            self._request_queue.put(request)
+            pass
+        else:
+            # calling next should cause the call to raise.
+            next(call)
+
+    def send(self, request):
+        return self._recoverable(self._send, request)
+
+    def _recv(self):
+        with self._operational_lock:
+            call = self.call
+
+        if call is None:
+            raise ValueError(
+                'Can not recv() on an RPC that has never been open()ed.')
+
+        return next(call)
+
+    def recv(self):
+        return self._recoverable(self._recv)
+
+    @property
+    def is_active(self):
+        """bool: True if this stream is currently open and active."""
+        # Use the operational lock. It's entirely possible for something
+        # to check the active state *while* the RPC is being retried.
+        # Also, use finalized to track the actual terminal state here.
+        # This is because if the stream is re-established by the gRPC thread
+        # it's technically possible to check this between when gRPC marks the
+        # RPC as inactive and when gRPC executes our callback that re-opens
+        # the stream.
+        with self._operational_lock:
+            return self.call is not None and not self._finalized
+
+
+class BackgroundConsumer(object):
+    """A bi-directional stream consumer that runs in a separate thread.
+
+    This maps the consumption of a stream into a callback-based model. It also
+    provides :func:`pause` and :func:`resume` to allow for flow-control.
+
+    Example::
+
+        def should_recover(exc):
+            return (
+                isinstance(exc, grpc.RpcError) and
+                exc.code() == grpc.StatusCode.UNVAILABLE)
+
+        initial_request = example_pb2.StreamingRpcRequest(
+            setting='example')
+
+        rpc = ResumeableBidiRpc(
+            stub.StreamingRpc,
+            initial_request=initial_request,
+            should_recover=should_recover)
+
+        def on_response(response):
+            print(response)
+
+        consumer = BackgroundConsumer(rpc, on_response)
+        consume.start()
+
+    Note that error handling *must* be done by using the provided
+    ``bidi_rpc``'s ``add_done_callback``. This helper will automatically exit
+    whenever the RPC itself exits and will not provide any error details.
+
+    Args:
+        bidi_rpc (BidiRpc): The RPC to consume. Should not have been
+            ``open()``ed yet.
+        on_response (Callable[[protobuf.Message], None]): The callback to
+            be called for every response on the stream.
+    """
+    def __init__(self, bidi_rpc, on_response):
+        self._bidi_rpc = bidi_rpc
+        self._on_response = on_response
+        self._paused = False
+        self._wake = threading.Condition()
+        self._thread = None
+        self._operational_lock = threading.Lock()
+
+    def _on_call_done(self, future):
+        # Resume the thread if it's paused, this prevents blocking forever
+        # when the RPC has terminated.
+        self.resume()
+
+    def _thread_main(self):
+        try:
+            self._bidi_rpc.add_done_callback(self._on_call_done)
+            self._bidi_rpc.open()
+
+            while self._bidi_rpc.is_active:
+                # Do not allow the paused status to change at all during this
+                # section. There is a condition where we could be resumed
+                # between checking if we are paused and calling wake.wait(),
+                # which means that we will miss the notification to wake up
+                # (oops!) and wait for a notification that will never come.
+                # Keeping the lock throughout avoids that.
+                # In the future, we could use `Condition.wait_for` if we drop
+                # Python 2.7.
+                with self._wake:
+                    if self._paused:
+                        _LOGGER.debug('paused, waiting for waking.')
+                        self._wake.wait()
+                        _LOGGER.debug('woken.')
+
+                _LOGGER.debug('waiting for recv.')
+                response = self._bidi_rpc.recv()
+                _LOGGER.debug('recved response.')
+                self._on_response(response)
+
+        except exceptions.GoogleAPICallError as exc:
+            _LOGGER.debug(
+                '%s caught error %s and will exit. Generally this is due to '
+                'the RPC itself being cancelled and the error will be '
+                'surfaced to the calling code.',
+                _BIDIRECTIONAL_CONSUMER_NAME, exc, exc_info=True)
+
+        except Exception as exc:
+            _LOGGER.exception(
+                '%s caught unexpected exception %s and will exit.',
+                _BIDIRECTIONAL_CONSUMER_NAME, exc)
+
+        else:
+            _LOGGER.error(
+                'The bidirectional RPC exited.')
+
+        _LOGGER.info('%s exiting', _BIDIRECTIONAL_CONSUMER_NAME)
+
+    def start(self):
+        """Start the background thread and begin consuming the thread."""
+        with self._operational_lock:
+            thread = threading.Thread(
+                name=_BIDIRECTIONAL_CONSUMER_NAME,
+                target=self._thread_main)
+            thread.daemon = True
+            thread.start()
+            self._thread = thread
+            _LOGGER.debug('Started helper thread %s', thread.name)
+
+    def stop(self):
+        """Stop consuming the stream and shutdown the background thread."""
+        with self._operational_lock:
+            self._bidi_rpc.close()
+
+            if self._thread is not None:
+                # Resume the thread to wake it up in case it is sleeping.
+                self.resume()
+                self._thread.join()
+
+            self._thread = None
+
+    @property
+    def is_active(self):
+        """bool: True if the background thread is active."""
+        return self._thread is not None and self._thread.is_alive()
+
+    def pause(self):
+        """Pauses the response stream.
+
+        This does *not* pause the request stream.
+        """
+        with self._wake:
+            self._paused = True
+
+    def resume(self):
+        """Resumes the response stream."""
+        with self._wake:
+            self._paused = False
+            self._wake.notifyAll()
+
+    @property
+    def is_paused(self):
+        """bool: True if the response stream is paused."""
+        return self._paused
diff --git a/tests/unit/test_bidi.py b/tests/unit/test_bidi.py
new file mode 100644
index 0000000..a377706
--- /dev/null
+++ b/tests/unit/test_bidi.py
@@ -0,0 +1,650 @@
+# Copyright 2018, Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import threading
+
+import grpc
+import mock
+import pytest
+from six.moves import queue
+
+from google.api_core import exceptions
+from google.api_core import bidi
+
+
+class Test_RequestQueueGenerator(object):
+
+    def test_bounded_consume(self):
+        call = mock.create_autospec(grpc.Call, instance=True)
+        call.is_active.return_value = True
+
+        def queue_generator(rpc):
+            yield mock.sentinel.A
+            yield queue.Empty()
+            yield mock.sentinel.B
+            rpc.is_active.return_value = False
+            yield mock.sentinel.C
+
+        q = mock.create_autospec(queue.Queue, instance=True)
+        q.get.side_effect = queue_generator(call)
+
+        generator = bidi._RequestQueueGenerator(q)
+        generator.call = call
+
+        items = list(generator)
+
+        assert items == [mock.sentinel.A, mock.sentinel.B]
+
+    def test_yield_initial_and_exit(self):
+        q = mock.create_autospec(queue.Queue, instance=True)
+        q.get.side_effect = queue.Empty()
+        call = mock.create_autospec(grpc.Call, instance=True)
+        call.is_active.return_value = False
+
+        generator = bidi._RequestQueueGenerator(
+            q, initial_request=mock.sentinel.A)
+        generator.call = call
+
+        items = list(generator)
+
+        assert items == [mock.sentinel.A]
+
+    def test_yield_initial_callable_and_exit(self):
+        q = mock.create_autospec(queue.Queue, instance=True)
+        q.get.side_effect = queue.Empty()
+        call = mock.create_autospec(grpc.Call, instance=True)
+        call.is_active.return_value = False
+
+        generator = bidi._RequestQueueGenerator(
+            q, initial_request=lambda: mock.sentinel.A)
+        generator.call = call
+
+        items = list(generator)
+
+        assert items == [mock.sentinel.A]
+
+    def test_exit_when_inactive_with_item(self):
+        q = mock.create_autospec(queue.Queue, instance=True)
+        q.get.side_effect = [mock.sentinel.A, queue.Empty()]
+        call = mock.create_autospec(grpc.Call, instance=True)
+        call.is_active.return_value = False
+
+        generator = bidi._RequestQueueGenerator(q)
+        generator.call = call
+
+        items = list(generator)
+
+        assert items == []
+        # Make sure it put the item back.
+        q.put.assert_called_once_with(mock.sentinel.A)
+
+    def test_exit_when_inactive_empty(self):
+        q = mock.create_autospec(queue.Queue, instance=True)
+        q.get.side_effect = queue.Empty()
+        call = mock.create_autospec(grpc.Call, instance=True)
+        call.is_active.return_value = False
+
+        generator = bidi._RequestQueueGenerator(q)
+        generator.call = call
+
+        items = list(generator)
+
+        assert items == []
+
+    def test_exit_with_stop(self):
+        q = mock.create_autospec(queue.Queue, instance=True)
+        q.get.side_effect = [None, queue.Empty()]
+        call = mock.create_autospec(grpc.Call, instance=True)
+        call.is_active.return_value = True
+
+        generator = bidi._RequestQueueGenerator(q)
+        generator.call = call
+
+        items = list(generator)
+
+        assert items == []
+
+
+class _CallAndFuture(grpc.Call, grpc.Future):
+    pass
+
+
+def make_rpc():
+    """Makes a mock RPC used to test Bidi classes."""
+    call = mock.create_autospec(_CallAndFuture, instance=True)
+    rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
+
+    def rpc_side_effect(request):
+        call.is_active.return_value = True
+        call.request = request
+        return call
+
+    rpc.side_effect = rpc_side_effect
+
+    def cancel_side_effect():
+        call.is_active.return_value = False
+
+    call.cancel.side_effect = cancel_side_effect
+
+    return rpc, call
+
+
+class ClosedCall(object):
+    # NOTE: This is needed because defining `.next` on an **instance**
+    #       rather than the **class** will not be iterable in Python 2.
+    #       This is problematic since a `Mock` just sets members.
+
+    def __init__(self, exception):
+        self.exception = exception
+
+    def __next__(self):
+        raise self.exception
+
+    next = __next__  # Python 2
+
+    def is_active(self):
+        return False
+
+
+class TestBidiRpc(object):
+    def test_initial_state(self):
+        bidi_rpc = bidi.BidiRpc(None)
+
+        assert bidi_rpc.is_active is False
+
+    def test_done_callbacks(self):
+        bidi_rpc = bidi.BidiRpc(None)
+        callback = mock.Mock(spec=['__call__'])
+
+        bidi_rpc.add_done_callback(callback)
+        bidi_rpc._on_call_done(mock.sentinel.future)
+
+        callback.assert_called_once_with(mock.sentinel.future)
+
+    def test_open(self):
+        rpc, call = make_rpc()
+        bidi_rpc = bidi.BidiRpc(rpc)
+
+        bidi_rpc.open()
+
+        assert bidi_rpc.call == call
+        assert bidi_rpc.is_active
+        call.add_done_callback.assert_called_once_with(bidi_rpc._on_call_done)
+
+    def test_open_error_already_open(self):
+        rpc, _ = make_rpc()
+        bidi_rpc = bidi.BidiRpc(rpc)
+
+        bidi_rpc.open()
+
+        with pytest.raises(ValueError):
+            bidi_rpc.open()
+
+    def test_close(self):
+        rpc, call = make_rpc()
+        bidi_rpc = bidi.BidiRpc(rpc)
+        bidi_rpc.open()
+
+        bidi_rpc.close()
+
+        call.cancel.assert_called_once()
+        assert bidi_rpc.call == call
+        assert bidi_rpc.is_active is False
+        # ensure the request queue was signaled to stop.
+        assert bidi_rpc.pending_requests == 1
+        assert bidi_rpc._request_queue.get() is None
+
+    def test_close_no_rpc(self):
+        bidi_rpc = bidi.BidiRpc(None)
+        bidi_rpc.close()
+
+    def test_send(self):
+        rpc, call = make_rpc()
+        bidi_rpc = bidi.BidiRpc(rpc)
+        bidi_rpc.open()
+
+        bidi_rpc.send(mock.sentinel.request)
+
+        assert bidi_rpc.pending_requests == 1
+        assert bidi_rpc._request_queue.get() is mock.sentinel.request
+
+    def test_send_not_open(self):
+        rpc, call = make_rpc()
+        bidi_rpc = bidi.BidiRpc(rpc)
+
+        with pytest.raises(ValueError):
+            bidi_rpc.send(mock.sentinel.request)
+
+    def test_send_dead_rpc(self):
+        error = ValueError()
+        bidi_rpc = bidi.BidiRpc(None)
+        bidi_rpc.call = ClosedCall(error)
+
+        with pytest.raises(ValueError) as exc_info:
+            bidi_rpc.send(mock.sentinel.request)
+
+        assert exc_info.value == error
+
+    def test_recv(self):
+        bidi_rpc = bidi.BidiRpc(None)
+        bidi_rpc.call = iter([mock.sentinel.response])
+
+        response = bidi_rpc.recv()
+
+        assert response == mock.sentinel.response
+
+    def test_recv_not_open(self):
+        rpc, call = make_rpc()
+        bidi_rpc = bidi.BidiRpc(rpc)
+
+        with pytest.raises(ValueError):
+            bidi_rpc.recv()
+
+
+class CallStub(object):
+    def __init__(self, values, active=True):
+        self.values = iter(values)
+        self._is_active = active
+        self.cancelled = False
+
+    def __next__(self):
+        item = next(self.values)
+        if isinstance(item, Exception):
+            self._is_active = False
+            raise item
+        return item
+
+    next = __next__  # Python 2
+
+    def is_active(self):
+        return self._is_active
+
+    def add_done_callback(self, callback):
+        pass
+
+    def cancel(self):
+        self.cancelled = True
+
+
+class TestResumableBidiRpc(object):
+    def test_initial_state(self):
+        bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: True)
+
+        assert bidi_rpc.is_active is False
+
+    def test_done_callbacks_recoverable(self):
+        start_rpc = mock.create_autospec(
+            grpc.StreamStreamMultiCallable, instance=True)
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True)
+        callback = mock.Mock(spec=['__call__'])
+
+        bidi_rpc.add_done_callback(callback)
+        bidi_rpc._on_call_done(mock.sentinel.future)
+
+        callback.assert_not_called()
+        start_rpc.assert_called_once()
+        assert bidi_rpc.is_active
+
+    def test_done_callbacks_non_recoverable(self):
+        bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
+        callback = mock.Mock(spec=['__call__'])
+
+        bidi_rpc.add_done_callback(callback)
+        bidi_rpc._on_call_done(mock.sentinel.future)
+
+        callback.assert_called_once_with(mock.sentinel.future)
+
+    def test_send_recover(self):
+        error = ValueError()
+        call_1 = CallStub([error], active=False)
+        call_2 = CallStub([])
+        start_rpc = mock.create_autospec(
+            grpc.StreamStreamMultiCallable,
+            instance=True,
+            side_effect=[call_1, call_2])
+        should_recover = mock.Mock(spec=['__call__'], return_value=True)
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+        bidi_rpc.open()
+
+        bidi_rpc.send(mock.sentinel.request)
+
+        assert bidi_rpc.pending_requests == 1
+        assert bidi_rpc._request_queue.get() is mock.sentinel.request
+
+        should_recover.assert_called_once_with(error)
+        assert bidi_rpc.call == call_2
+        assert bidi_rpc.is_active is True
+
+    def test_send_failure(self):
+        error = ValueError()
+        call = CallStub([error], active=False)
+        start_rpc = mock.create_autospec(
+            grpc.StreamStreamMultiCallable,
+            instance=True,
+            return_value=call)
+        should_recover = mock.Mock(spec=['__call__'], return_value=False)
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+        bidi_rpc.open()
+
+        with pytest.raises(ValueError) as exc_info:
+            bidi_rpc.send(mock.sentinel.request)
+
+        assert exc_info.value == error
+        should_recover.assert_called_once_with(error)
+        assert bidi_rpc.call == call
+        assert bidi_rpc.is_active is False
+        assert call.cancelled is True
+        assert bidi_rpc.pending_requests == 1
+        assert bidi_rpc._request_queue.get() is None
+
+    def test_recv_recover(self):
+        error = ValueError()
+        call_1 = CallStub([1, error])
+        call_2 = CallStub([2, 3])
+        start_rpc = mock.create_autospec(
+            grpc.StreamStreamMultiCallable,
+            instance=True,
+            side_effect=[call_1, call_2])
+        should_recover = mock.Mock(spec=['__call__'], return_value=True)
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+        bidi_rpc.open()
+
+        values = []
+        for n in range(3):
+            values.append(bidi_rpc.recv())
+
+        assert values == [1, 2, 3]
+        should_recover.assert_called_once_with(error)
+        assert bidi_rpc.call == call_2
+        assert bidi_rpc.is_active is True
+
+    def test_recv_recover_already_recovered(self):
+        call_1 = CallStub([])
+        call_2 = CallStub([])
+        start_rpc = mock.create_autospec(
+            grpc.StreamStreamMultiCallable,
+            instance=True,
+            side_effect=[call_1, call_2])
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True)
+
+        bidi_rpc.open()
+
+        bidi_rpc._reopen()
+
+        assert bidi_rpc.call is call_1
+        assert bidi_rpc.is_active is True
+
+    def test_recv_failure(self):
+        error = ValueError()
+        call = CallStub([error])
+        start_rpc = mock.create_autospec(
+            grpc.StreamStreamMultiCallable,
+            instance=True,
+            return_value=call)
+        should_recover = mock.Mock(spec=['__call__'], return_value=False)
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+        bidi_rpc.open()
+
+        with pytest.raises(ValueError) as exc_info:
+            bidi_rpc.recv()
+
+        assert exc_info.value == error
+        should_recover.assert_called_once_with(error)
+        assert bidi_rpc.call == call
+        assert bidi_rpc.is_active is False
+        assert call.cancelled is True
+
+    def test_reopen_failure_on_rpc_restart(self):
+        error1 = ValueError('1')
+        error2 = ValueError('2')
+        call = CallStub([error1])
+        # Invoking start RPC a second time will trigger an error.
+        start_rpc = mock.create_autospec(
+            grpc.StreamStreamMultiCallable,
+            instance=True,
+            side_effect=[call, error2])
+        should_recover = mock.Mock(spec=['__call__'], return_value=True)
+        callback = mock.Mock(spec=['__call__'])
+
+        bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+        bidi_rpc.add_done_callback(callback)
+
+        bidi_rpc.open()
+
+        with pytest.raises(ValueError) as exc_info:
+            bidi_rpc.recv()
+
+        assert exc_info.value == error2
+        should_recover.assert_called_once_with(error1)
+        assert bidi_rpc.call is None
+        assert bidi_rpc.is_active is False
+        callback.assert_called_once_with(error2)
+
+    def test_send_not_open(self):
+        bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
+
+        with pytest.raises(ValueError):
+            bidi_rpc.send(mock.sentinel.request)
+
+    def test_recv_not_open(self):
+        bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
+
+        with pytest.raises(ValueError):
+            bidi_rpc.recv()
+
+    def test_finalize_idempotent(self):
+        error1 = ValueError('1')
+        error2 = ValueError('2')
+        callback = mock.Mock(spec=['__call__'])
+        should_recover = mock.Mock(spec=['__call__'], return_value=False)
+
+        bidi_rpc = bidi.ResumableBidiRpc(
+            mock.sentinel.start_rpc, should_recover)
+
+        bidi_rpc.add_done_callback(callback)
+
+        bidi_rpc._on_call_done(error1)
+        bidi_rpc._on_call_done(error2)
+
+        callback.assert_called_once_with(error1)
+
+
+class TestBackgroundConsumer(object):
+    def test_consume_once_then_exit(self):
+        bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+        bidi_rpc.is_active = True
+        bidi_rpc.recv.side_effect = [mock.sentinel.response_1]
+        recved = threading.Event()
+
+        def on_response(response):
+            assert response == mock.sentinel.response_1
+            bidi_rpc.is_active = False
+            recved.set()
+
+        consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+        consumer.start()
+
+        recved.wait()
+
+        bidi_rpc.recv.assert_called_once()
+        assert bidi_rpc.is_active is False
+
+        consumer.stop()
+
+        bidi_rpc.close.assert_called_once()
+        assert consumer.is_active is False
+
+    def test_pause_resume_and_close(self):
+        # This test is relatively complex. It attempts to start the consumer,
+        # consume one item, pause the consumer, check the state of the world,
+        # then resume the consumer. Doing this in a deterministic fashion
+        # requires a bit more mocking and patching than usual.
+
+        bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+        bidi_rpc.is_active = True
+
+        def close_side_effect():
+            bidi_rpc.is_active = False
+
+        bidi_rpc.close.side_effect = close_side_effect
+
+        # These are used to coordinate the two threads to ensure deterministic
+        # execution.
+        should_continue = threading.Event()
+        responses_and_events = {
+            mock.sentinel.response_1: threading.Event(),
+            mock.sentinel.response_2: threading.Event()
+        }
+        bidi_rpc.recv.side_effect = [
+            mock.sentinel.response_1, mock.sentinel.response_2]
+
+        recved_responses = []
+        consumer = None
+
+        def on_response(response):
+            if response == mock.sentinel.response_1:
+                consumer.pause()
+
+            recved_responses.append(response)
+            responses_and_events[response].set()
+            should_continue.wait()
+
+        consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+        consumer.start()
+
+        # Wait for the first response to be recved.
+        responses_and_events[mock.sentinel.response_1].wait()
+
+        # Ensure only one item has been recved and that the consumer is paused.
+        assert recved_responses == [mock.sentinel.response_1]
+        assert consumer.is_paused is True
+        assert consumer.is_active is True
+
+        # Unpause the consumer, wait for the second item, then close the
+        # consumer.
+        should_continue.set()
+        consumer.resume()
+
+        responses_and_events[mock.sentinel.response_2].wait()
+
+        assert recved_responses == [
+            mock.sentinel.response_1, mock.sentinel.response_2]
+
+        consumer.stop()
+
+        assert consumer.is_active is False
+
+    def test_wake_on_error(self):
+        should_continue = threading.Event()
+
+        bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+        bidi_rpc.is_active = True
+        bidi_rpc.add_done_callback.side_effect = (
+            lambda _: should_continue.set())
+
+        consumer = bidi.BackgroundConsumer(bidi_rpc, mock.sentinel.on_response)
+
+        # Start the consumer paused, which should immediately put it into wait
+        # state.
+        consumer.pause()
+        consumer.start()
+
+        # Wait for add_done_callback to be called
+        should_continue.wait()
+        bidi_rpc.add_done_callback.assert_called_once_with(
+            consumer._on_call_done)
+
+        # The consumer should now be blocked on waiting to be unpaused.
+        assert consumer.is_active
+        assert consumer.is_paused
+
+        # Trigger the done callback, it should unpause the consumer and cause
+        # it to exit.
+        bidi_rpc.is_active = False
+        consumer._on_call_done(bidi_rpc)
+
+        # It may take a few cycles for the thread to exit.
+        while consumer.is_active:
+            pass
+
+    def test_consumer_expected_error(self, caplog):
+        caplog.set_level(logging.DEBUG)
+
+        bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+        bidi_rpc.is_active = True
+        bidi_rpc.recv.side_effect = exceptions.ServiceUnavailable('Gone away')
+
+        on_response = mock.Mock(spec=['__call__'])
+
+        consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+        consumer.start()
+
+        # Wait for the consumer's thread to exit.
+        while consumer.is_active:
+            pass
+
+        on_response.assert_not_called()
+        bidi_rpc.recv.assert_called_once()
+        assert 'caught error' in caplog.text
+
+    def test_consumer_unexpected_error(self, caplog):
+        caplog.set_level(logging.DEBUG)
+
+        bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+        bidi_rpc.is_active = True
+        bidi_rpc.recv.side_effect = ValueError()
+
+        on_response = mock.Mock(spec=['__call__'])
+
+        consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+        consumer.start()
+
+        # Wait for the consumer's thread to exit.
+        while consumer.is_active:
+            pass
+
+        on_response.assert_not_called()
+        bidi_rpc.recv.assert_called_once()
+        assert 'caught unexpected exception' in caplog.text
+
+    def test_double_stop(self, caplog):
+        caplog.set_level(logging.DEBUG)
+        bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+        bidi_rpc.is_active = True
+        on_response = mock.Mock(spec=['__call__'])
+
+        def close_side_effect():
+            bidi_rpc.is_active = False
+
+        bidi_rpc.close.side_effect = close_side_effect
+
+        consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+        consumer.start()
+        assert consumer.is_active is True
+
+        consumer.stop()
+        assert consumer.is_active is False
+
+        # calling stop twice should not result in an error.
+        consumer.stop()