Move bidi to api-core (#6211)
diff --git a/google/api_core/bidi.py b/google/api_core/bidi.py
new file mode 100644
index 0000000..7c995c5
--- /dev/null
+++ b/google/api_core/bidi.py
@@ -0,0 +1,597 @@
+# Copyright 2017, Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Bi-directional streaming RPC helpers."""
+
+import logging
+import threading
+
+from six.moves import queue
+
+from google.api_core import exceptions
+
+_LOGGER = logging.getLogger(__name__)
+_BIDIRECTIONAL_CONSUMER_NAME = 'Thread-ConsumeBidirectionalStream'
+
+
+class _RequestQueueGenerator(object):
+ """A helper for sending requests to a gRPC stream from a Queue.
+
+ This generator takes requests off a given queue and yields them to gRPC.
+
+ This helper is useful when you have an indeterminate, indefinite, or
+ otherwise open-ended set of requests to send through a request-streaming
+ (or bidirectional) RPC.
+
+ The reason this is necessary is because gRPC takes an iterator as the
+ request for request-streaming RPCs. gRPC consumes this iterator in another
+ thread to allow it to block while generating requests for the stream.
+ However, if the generator blocks indefinitely gRPC will not be able to
+ clean up the thread as it'll be blocked on `next(iterator)` and not be able
+ to check the channel status to stop iterating. This helper mitigates that
+ by waiting on the queue with a timeout and checking the RPC state before
+ yielding.
+
+ Finally, it allows for retrying without swapping queues because if it does
+ pull an item off the queue when the RPC is inactive, it'll immediately put
+ it back and then exit. This is necessary because yielding the item in this
+ case will cause gRPC to discard it. In practice, this means that the order
+ of messages is not guaranteed. If such a thing is necessary it would be
+ easy to use a priority queue.
+
+ Example::
+
+ requests = request_queue_generator(q)
+ call = stub.StreamingRequest(iter(requests))
+ requests.call = call
+
+ for response in call:
+ print(response)
+ q.put(...)
+
+ Note that it is possible to accomplish this behavior without "spinning"
+ (using a queue timeout). One possible way would be to use more threads to
+ multiplex the grpc end event with the queue, another possible way is to
+ use selectors and a custom event/queue object. Both of these approaches
+ are significant from an engineering perspective for small benefit - the
+ CPU consumed by spinning is pretty minuscule.
+
+ Args:
+ queue (queue.Queue): The request queue.
+ period (float): The number of seconds to wait for items from the queue
+ before checking if the RPC is cancelled. In practice, this
+ determines the maximum amount of time the request consumption
+ thread will live after the RPC is cancelled.
+ initial_request (Union[protobuf.Message,
+ Callable[None, protobuf.Message]]): The initial request to
+ yield. This is done independently of the request queue to allow fo
+ easily restarting streams that require some initial configuration
+ request.
+ """
+ def __init__(self, queue, period=1, initial_request=None):
+ self._queue = queue
+ self._period = period
+ self._initial_request = initial_request
+ self.call = None
+
+ def _is_active(self):
+ # Note: there is a possibility that this starts *before* the call
+ # property is set. So we have to check if self.call is set before
+ # seeing if it's active.
+ if self.call is not None and not self.call.is_active():
+ return False
+ else:
+ return True
+
+ def __iter__(self):
+ if self._initial_request is not None:
+ if callable(self._initial_request):
+ yield self._initial_request()
+ else:
+ yield self._initial_request
+
+ while True:
+ try:
+ item = self._queue.get(timeout=self._period)
+ except queue.Empty:
+ if not self._is_active():
+ _LOGGER.debug(
+ 'Empty queue and inactive call, exiting request '
+ 'generator.')
+ return
+ else:
+ # call is still active, keep waiting for queue items.
+ continue
+
+ # The consumer explicitly sent "None", indicating that the request
+ # should end.
+ if item is None:
+ _LOGGER.debug('Cleanly exiting request generator.')
+ return
+
+ if not self._is_active():
+ # We have an item, but the call is closed. We should put the
+ # item back on the queue so that the next call can consume it.
+ self._queue.put(item)
+ _LOGGER.debug(
+ 'Inactive call, replacing item on queue and exiting '
+ 'request generator.')
+ return
+
+ yield item
+
+
+class BidiRpc(object):
+ """A helper for consuming a bi-directional streaming RPC.
+
+ This maps gRPC's built-in interface which uses a request iterator and a
+ response iterator into a socket-like :func:`send` and :func:`recv`. This
+ is a more useful pattern for long-running or asymmetric streams (streams
+ where there is not a direct correlation between the requests and
+ responses).
+
+ Example::
+
+ initial_request = example_pb2.StreamingRpcRequest(
+ setting='example')
+ rpc = BidiRpc(stub.StreamingRpc, initial_request=initial_request)
+
+ rpc.open()
+
+ while rpc.is_active():
+ print(rpc.recv())
+ rpc.send(example_pb2.StreamingRpcRequest(
+ data='example'))
+
+ This does *not* retry the stream on errors. See :class:`ResumableBidiRpc`.
+
+ Args:
+ start_rpc (grpc.StreamStreamMultiCallable): The gRPC method used to
+ start the RPC.
+ initial_request (Union[protobuf.Message,
+ Callable[None, protobuf.Message]]): The initial request to
+ yield. This is useful if an initial request is needed to start the
+ stream.
+ """
+ def __init__(self, start_rpc, initial_request=None):
+ self._start_rpc = start_rpc
+ self._initial_request = initial_request
+ self._request_queue = queue.Queue()
+ self._request_generator = None
+ self._is_active = False
+ self._callbacks = []
+ self.call = None
+
+ def add_done_callback(self, callback):
+ """Adds a callback that will be called when the RPC terminates.
+
+ This occurs when the RPC errors or is successfully terminated.
+
+ Args:
+ callback (Callable[[grpc.Future], None]): The callback to execute.
+ It will be provided with the same gRPC future as the underlying
+ stream which will also be a :class:`grpc.Call`.
+ """
+ self._callbacks.append(callback)
+
+ def _on_call_done(self, future):
+ for callback in self._callbacks:
+ callback(future)
+
+ def open(self):
+ """Opens the stream."""
+ if self.is_active:
+ raise ValueError('Can not open an already open stream.')
+
+ request_generator = _RequestQueueGenerator(
+ self._request_queue, initial_request=self._initial_request)
+ call = self._start_rpc(iter(request_generator))
+
+ request_generator.call = call
+
+ # TODO: api_core should expose the future interface for wrapped
+ # callables as well.
+ if hasattr(call, '_wrapped'): # pragma: NO COVER
+ call._wrapped.add_done_callback(self._on_call_done)
+ else:
+ call.add_done_callback(self._on_call_done)
+
+ self._request_generator = request_generator
+ self.call = call
+
+ def close(self):
+ """Closes the stream."""
+ if self.call is None:
+ return
+
+ self._request_queue.put(None)
+ self.call.cancel()
+ self._request_generator = None
+ # Don't set self.call to None. Keep it around so that send/recv can
+ # raise the error.
+
+ def send(self, request):
+ """Queue a message to be sent on the stream.
+
+ Send is non-blocking.
+
+ If the underlying RPC has been closed, this will raise.
+
+ Args:
+ request (protobuf.Message): The request to send.
+ """
+ if self.call is None:
+ raise ValueError(
+ 'Can not send() on an RPC that has never been open()ed.')
+
+ # Don't use self.is_active(), as ResumableBidiRpc will overload it
+ # to mean something semantically different.
+ if self.call.is_active():
+ self._request_queue.put(request)
+ else:
+ # calling next should cause the call to raise.
+ next(self.call)
+
+ def recv(self):
+ """Wait for a message to be returned from the stream.
+
+ Recv is blocking.
+
+ If the underlying RPC has been closed, this will raise.
+
+ Returns:
+ protobuf.Message: The received message.
+ """
+ if self.call is None:
+ raise ValueError(
+ 'Can not recv() on an RPC that has never been open()ed.')
+
+ return next(self.call)
+
+ @property
+ def is_active(self):
+ """bool: True if this stream is currently open and active."""
+ return self.call is not None and self.call.is_active()
+
+ @property
+ def pending_requests(self):
+ """int: Returns an estimate of the number of queued requests."""
+ return self._request_queue.qsize()
+
+
+class ResumableBidiRpc(BidiRpc):
+ """A :class:`BidiRpc` that can automatically resume the stream on errors.
+
+ It uses the ``should_recover`` arg to determine if it should re-establish
+ the stream on error.
+
+ Example::
+
+ def should_recover(exc):
+ return (
+ isinstance(exc, grpc.RpcError) and
+ exc.code() == grpc.StatusCode.UNVAILABLE)
+
+ initial_request = example_pb2.StreamingRpcRequest(
+ setting='example')
+
+ rpc = ResumeableBidiRpc(
+ stub.StreamingRpc,
+ initial_request=initial_request,
+ should_recover=should_recover)
+
+ rpc.open()
+
+ while rpc.is_active():
+ print(rpc.recv())
+ rpc.send(example_pb2.StreamingRpcRequest(
+ data='example'))
+
+ Args:
+ start_rpc (grpc.StreamStreamMultiCallable): The gRPC method used to
+ start the RPC.
+ initial_request (Union[protobuf.Message,
+ Callable[None, protobuf.Message]]): The initial request to
+ yield. This is useful if an initial request is needed to start the
+ stream.
+ should_recover (Callable[[Exception], bool]): A function that returns
+ True if the stream should be recovered. This will be called
+ whenever an error is encountered on the stream.
+ """
+ def __init__(self, start_rpc, should_recover, initial_request=None):
+ super(ResumableBidiRpc, self).__init__(start_rpc, initial_request)
+ self._should_recover = should_recover
+ self._operational_lock = threading.RLock()
+ self._finalized = False
+ self._finalize_lock = threading.Lock()
+
+ def _finalize(self, result):
+ with self._finalize_lock:
+ if self._finalized:
+ return
+
+ for callback in self._callbacks:
+ callback(result)
+
+ self._finalized = True
+
+ def _on_call_done(self, future):
+ # Unlike the base class, we only execute the callbacks on a terminal
+ # error, not for errors that we can recover from. Note that grpc's
+ # "future" here is also a grpc.RpcError.
+ with self._operational_lock:
+ if not self._should_recover(future):
+ self._finalize(future)
+ else:
+ _LOGGER.debug('Re-opening stream from gRPC callback.')
+ self._reopen()
+
+ def _reopen(self):
+ with self._operational_lock:
+ # Another thread already managed to re-open this stream.
+ if self.call is not None and self.call.is_active():
+ _LOGGER.debug('Stream was already re-established.')
+ return
+
+ self.call = None
+ # Request generator should exit cleanly since the RPC its bound to
+ # has exited.
+ self.request_generator = None
+
+ # Note: we do not currently do any sort of backoff here. The
+ # assumption is that re-establishing the stream under normal
+ # circumstances will happen in intervals greater than 60s.
+ # However, it is possible in a degenerative case that the server
+ # closes the stream rapidly which would lead to thrashing here,
+ # but hopefully in those cases the server would return a non-
+ # retryable error.
+
+ try:
+ self.open()
+ # If re-opening or re-calling the method fails for any reason,
+ # consider it a terminal error and finalize the stream.
+ except Exception as exc:
+ _LOGGER.debug('Failed to re-open stream due to %s', exc)
+ self._finalize(exc)
+ raise
+
+ _LOGGER.info('Re-established stream')
+
+ def _recoverable(self, method, *args, **kwargs):
+ """Wraps a method to recover the stream and retry on error.
+
+ If a retryable error occurs while making the call, then the stream will
+ be re-opened and the method will be retried. This happens indefinitely
+ so long as the error is a retryable one. If an error occurs while
+ re-opening the stream, then this method will raise immediately and
+ trigger finalization of this object.
+
+ Args:
+ method (Callable[..., Any]): The method to call.
+ args: The args to pass to the method.
+ kwargs: The kwargs to pass to the method.
+ """
+ while True:
+ try:
+ return method(*args, **kwargs)
+
+ except Exception as exc:
+ with self._operational_lock:
+ _LOGGER.debug(
+ 'Call to retryable %r caused %s.', method, exc)
+
+ if not self._should_recover(exc):
+ self.close()
+ _LOGGER.debug(
+ 'Not retrying %r due to %s.', method, exc)
+ self._finalize(exc)
+ raise exc
+
+ _LOGGER.debug(
+ 'Re-opening stream from retryable %r.', method)
+ self._reopen()
+
+ def _send(self, request):
+ # Grab a reference to the RPC call. Because another thread (notably
+ # the gRPC error thread) can modify self.call (by invoking reopen),
+ # we should ensure our reference can not change underneath us.
+ # If self.call is modified (such as replaced with a new RPC call) then
+ # this will use the "old" RPC, which should result in the same
+ # exception passed into gRPC's error handler being raised here, which
+ # will be handled by the usual error handling in retryable.
+ with self._operational_lock:
+ call = self.call
+
+ if call is None:
+ raise ValueError(
+ 'Can not send() on an RPC that has never been open()ed.')
+
+ # Don't use self.is_active(), as ResumableBidiRpc will overload it
+ # to mean something semantically different.
+ if call.is_active():
+ self._request_queue.put(request)
+ pass
+ else:
+ # calling next should cause the call to raise.
+ next(call)
+
+ def send(self, request):
+ return self._recoverable(self._send, request)
+
+ def _recv(self):
+ with self._operational_lock:
+ call = self.call
+
+ if call is None:
+ raise ValueError(
+ 'Can not recv() on an RPC that has never been open()ed.')
+
+ return next(call)
+
+ def recv(self):
+ return self._recoverable(self._recv)
+
+ @property
+ def is_active(self):
+ """bool: True if this stream is currently open and active."""
+ # Use the operational lock. It's entirely possible for something
+ # to check the active state *while* the RPC is being retried.
+ # Also, use finalized to track the actual terminal state here.
+ # This is because if the stream is re-established by the gRPC thread
+ # it's technically possible to check this between when gRPC marks the
+ # RPC as inactive and when gRPC executes our callback that re-opens
+ # the stream.
+ with self._operational_lock:
+ return self.call is not None and not self._finalized
+
+
+class BackgroundConsumer(object):
+ """A bi-directional stream consumer that runs in a separate thread.
+
+ This maps the consumption of a stream into a callback-based model. It also
+ provides :func:`pause` and :func:`resume` to allow for flow-control.
+
+ Example::
+
+ def should_recover(exc):
+ return (
+ isinstance(exc, grpc.RpcError) and
+ exc.code() == grpc.StatusCode.UNVAILABLE)
+
+ initial_request = example_pb2.StreamingRpcRequest(
+ setting='example')
+
+ rpc = ResumeableBidiRpc(
+ stub.StreamingRpc,
+ initial_request=initial_request,
+ should_recover=should_recover)
+
+ def on_response(response):
+ print(response)
+
+ consumer = BackgroundConsumer(rpc, on_response)
+ consume.start()
+
+ Note that error handling *must* be done by using the provided
+ ``bidi_rpc``'s ``add_done_callback``. This helper will automatically exit
+ whenever the RPC itself exits and will not provide any error details.
+
+ Args:
+ bidi_rpc (BidiRpc): The RPC to consume. Should not have been
+ ``open()``ed yet.
+ on_response (Callable[[protobuf.Message], None]): The callback to
+ be called for every response on the stream.
+ """
+ def __init__(self, bidi_rpc, on_response):
+ self._bidi_rpc = bidi_rpc
+ self._on_response = on_response
+ self._paused = False
+ self._wake = threading.Condition()
+ self._thread = None
+ self._operational_lock = threading.Lock()
+
+ def _on_call_done(self, future):
+ # Resume the thread if it's paused, this prevents blocking forever
+ # when the RPC has terminated.
+ self.resume()
+
+ def _thread_main(self):
+ try:
+ self._bidi_rpc.add_done_callback(self._on_call_done)
+ self._bidi_rpc.open()
+
+ while self._bidi_rpc.is_active:
+ # Do not allow the paused status to change at all during this
+ # section. There is a condition where we could be resumed
+ # between checking if we are paused and calling wake.wait(),
+ # which means that we will miss the notification to wake up
+ # (oops!) and wait for a notification that will never come.
+ # Keeping the lock throughout avoids that.
+ # In the future, we could use `Condition.wait_for` if we drop
+ # Python 2.7.
+ with self._wake:
+ if self._paused:
+ _LOGGER.debug('paused, waiting for waking.')
+ self._wake.wait()
+ _LOGGER.debug('woken.')
+
+ _LOGGER.debug('waiting for recv.')
+ response = self._bidi_rpc.recv()
+ _LOGGER.debug('recved response.')
+ self._on_response(response)
+
+ except exceptions.GoogleAPICallError as exc:
+ _LOGGER.debug(
+ '%s caught error %s and will exit. Generally this is due to '
+ 'the RPC itself being cancelled and the error will be '
+ 'surfaced to the calling code.',
+ _BIDIRECTIONAL_CONSUMER_NAME, exc, exc_info=True)
+
+ except Exception as exc:
+ _LOGGER.exception(
+ '%s caught unexpected exception %s and will exit.',
+ _BIDIRECTIONAL_CONSUMER_NAME, exc)
+
+ else:
+ _LOGGER.error(
+ 'The bidirectional RPC exited.')
+
+ _LOGGER.info('%s exiting', _BIDIRECTIONAL_CONSUMER_NAME)
+
+ def start(self):
+ """Start the background thread and begin consuming the thread."""
+ with self._operational_lock:
+ thread = threading.Thread(
+ name=_BIDIRECTIONAL_CONSUMER_NAME,
+ target=self._thread_main)
+ thread.daemon = True
+ thread.start()
+ self._thread = thread
+ _LOGGER.debug('Started helper thread %s', thread.name)
+
+ def stop(self):
+ """Stop consuming the stream and shutdown the background thread."""
+ with self._operational_lock:
+ self._bidi_rpc.close()
+
+ if self._thread is not None:
+ # Resume the thread to wake it up in case it is sleeping.
+ self.resume()
+ self._thread.join()
+
+ self._thread = None
+
+ @property
+ def is_active(self):
+ """bool: True if the background thread is active."""
+ return self._thread is not None and self._thread.is_alive()
+
+ def pause(self):
+ """Pauses the response stream.
+
+ This does *not* pause the request stream.
+ """
+ with self._wake:
+ self._paused = True
+
+ def resume(self):
+ """Resumes the response stream."""
+ with self._wake:
+ self._paused = False
+ self._wake.notifyAll()
+
+ @property
+ def is_paused(self):
+ """bool: True if the response stream is paused."""
+ return self._paused
diff --git a/tests/unit/test_bidi.py b/tests/unit/test_bidi.py
new file mode 100644
index 0000000..a377706
--- /dev/null
+++ b/tests/unit/test_bidi.py
@@ -0,0 +1,650 @@
+# Copyright 2018, Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import threading
+
+import grpc
+import mock
+import pytest
+from six.moves import queue
+
+from google.api_core import exceptions
+from google.api_core import bidi
+
+
+class Test_RequestQueueGenerator(object):
+
+ def test_bounded_consume(self):
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = True
+
+ def queue_generator(rpc):
+ yield mock.sentinel.A
+ yield queue.Empty()
+ yield mock.sentinel.B
+ rpc.is_active.return_value = False
+ yield mock.sentinel.C
+
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = queue_generator(call)
+
+ generator = bidi._RequestQueueGenerator(q)
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == [mock.sentinel.A, mock.sentinel.B]
+
+ def test_yield_initial_and_exit(self):
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = queue.Empty()
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = False
+
+ generator = bidi._RequestQueueGenerator(
+ q, initial_request=mock.sentinel.A)
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == [mock.sentinel.A]
+
+ def test_yield_initial_callable_and_exit(self):
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = queue.Empty()
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = False
+
+ generator = bidi._RequestQueueGenerator(
+ q, initial_request=lambda: mock.sentinel.A)
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == [mock.sentinel.A]
+
+ def test_exit_when_inactive_with_item(self):
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = [mock.sentinel.A, queue.Empty()]
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = False
+
+ generator = bidi._RequestQueueGenerator(q)
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == []
+ # Make sure it put the item back.
+ q.put.assert_called_once_with(mock.sentinel.A)
+
+ def test_exit_when_inactive_empty(self):
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = queue.Empty()
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = False
+
+ generator = bidi._RequestQueueGenerator(q)
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == []
+
+ def test_exit_with_stop(self):
+ q = mock.create_autospec(queue.Queue, instance=True)
+ q.get.side_effect = [None, queue.Empty()]
+ call = mock.create_autospec(grpc.Call, instance=True)
+ call.is_active.return_value = True
+
+ generator = bidi._RequestQueueGenerator(q)
+ generator.call = call
+
+ items = list(generator)
+
+ assert items == []
+
+
+class _CallAndFuture(grpc.Call, grpc.Future):
+ pass
+
+
+def make_rpc():
+ """Makes a mock RPC used to test Bidi classes."""
+ call = mock.create_autospec(_CallAndFuture, instance=True)
+ rpc = mock.create_autospec(grpc.StreamStreamMultiCallable, instance=True)
+
+ def rpc_side_effect(request):
+ call.is_active.return_value = True
+ call.request = request
+ return call
+
+ rpc.side_effect = rpc_side_effect
+
+ def cancel_side_effect():
+ call.is_active.return_value = False
+
+ call.cancel.side_effect = cancel_side_effect
+
+ return rpc, call
+
+
+class ClosedCall(object):
+ # NOTE: This is needed because defining `.next` on an **instance**
+ # rather than the **class** will not be iterable in Python 2.
+ # This is problematic since a `Mock` just sets members.
+
+ def __init__(self, exception):
+ self.exception = exception
+
+ def __next__(self):
+ raise self.exception
+
+ next = __next__ # Python 2
+
+ def is_active(self):
+ return False
+
+
+class TestBidiRpc(object):
+ def test_initial_state(self):
+ bidi_rpc = bidi.BidiRpc(None)
+
+ assert bidi_rpc.is_active is False
+
+ def test_done_callbacks(self):
+ bidi_rpc = bidi.BidiRpc(None)
+ callback = mock.Mock(spec=['__call__'])
+
+ bidi_rpc.add_done_callback(callback)
+ bidi_rpc._on_call_done(mock.sentinel.future)
+
+ callback.assert_called_once_with(mock.sentinel.future)
+
+ def test_open(self):
+ rpc, call = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+
+ bidi_rpc.open()
+
+ assert bidi_rpc.call == call
+ assert bidi_rpc.is_active
+ call.add_done_callback.assert_called_once_with(bidi_rpc._on_call_done)
+
+ def test_open_error_already_open(self):
+ rpc, _ = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+
+ bidi_rpc.open()
+
+ with pytest.raises(ValueError):
+ bidi_rpc.open()
+
+ def test_close(self):
+ rpc, call = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+ bidi_rpc.open()
+
+ bidi_rpc.close()
+
+ call.cancel.assert_called_once()
+ assert bidi_rpc.call == call
+ assert bidi_rpc.is_active is False
+ # ensure the request queue was signaled to stop.
+ assert bidi_rpc.pending_requests == 1
+ assert bidi_rpc._request_queue.get() is None
+
+ def test_close_no_rpc(self):
+ bidi_rpc = bidi.BidiRpc(None)
+ bidi_rpc.close()
+
+ def test_send(self):
+ rpc, call = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+ bidi_rpc.open()
+
+ bidi_rpc.send(mock.sentinel.request)
+
+ assert bidi_rpc.pending_requests == 1
+ assert bidi_rpc._request_queue.get() is mock.sentinel.request
+
+ def test_send_not_open(self):
+ rpc, call = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+
+ with pytest.raises(ValueError):
+ bidi_rpc.send(mock.sentinel.request)
+
+ def test_send_dead_rpc(self):
+ error = ValueError()
+ bidi_rpc = bidi.BidiRpc(None)
+ bidi_rpc.call = ClosedCall(error)
+
+ with pytest.raises(ValueError) as exc_info:
+ bidi_rpc.send(mock.sentinel.request)
+
+ assert exc_info.value == error
+
+ def test_recv(self):
+ bidi_rpc = bidi.BidiRpc(None)
+ bidi_rpc.call = iter([mock.sentinel.response])
+
+ response = bidi_rpc.recv()
+
+ assert response == mock.sentinel.response
+
+ def test_recv_not_open(self):
+ rpc, call = make_rpc()
+ bidi_rpc = bidi.BidiRpc(rpc)
+
+ with pytest.raises(ValueError):
+ bidi_rpc.recv()
+
+
+class CallStub(object):
+ def __init__(self, values, active=True):
+ self.values = iter(values)
+ self._is_active = active
+ self.cancelled = False
+
+ def __next__(self):
+ item = next(self.values)
+ if isinstance(item, Exception):
+ self._is_active = False
+ raise item
+ return item
+
+ next = __next__ # Python 2
+
+ def is_active(self):
+ return self._is_active
+
+ def add_done_callback(self, callback):
+ pass
+
+ def cancel(self):
+ self.cancelled = True
+
+
+class TestResumableBidiRpc(object):
+ def test_initial_state(self):
+ bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: True)
+
+ assert bidi_rpc.is_active is False
+
+ def test_done_callbacks_recoverable(self):
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable, instance=True)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True)
+ callback = mock.Mock(spec=['__call__'])
+
+ bidi_rpc.add_done_callback(callback)
+ bidi_rpc._on_call_done(mock.sentinel.future)
+
+ callback.assert_not_called()
+ start_rpc.assert_called_once()
+ assert bidi_rpc.is_active
+
+ def test_done_callbacks_non_recoverable(self):
+ bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
+ callback = mock.Mock(spec=['__call__'])
+
+ bidi_rpc.add_done_callback(callback)
+ bidi_rpc._on_call_done(mock.sentinel.future)
+
+ callback.assert_called_once_with(mock.sentinel.future)
+
+ def test_send_recover(self):
+ error = ValueError()
+ call_1 = CallStub([error], active=False)
+ call_2 = CallStub([])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable,
+ instance=True,
+ side_effect=[call_1, call_2])
+ should_recover = mock.Mock(spec=['__call__'], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+ bidi_rpc.open()
+
+ bidi_rpc.send(mock.sentinel.request)
+
+ assert bidi_rpc.pending_requests == 1
+ assert bidi_rpc._request_queue.get() is mock.sentinel.request
+
+ should_recover.assert_called_once_with(error)
+ assert bidi_rpc.call == call_2
+ assert bidi_rpc.is_active is True
+
+ def test_send_failure(self):
+ error = ValueError()
+ call = CallStub([error], active=False)
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable,
+ instance=True,
+ return_value=call)
+ should_recover = mock.Mock(spec=['__call__'], return_value=False)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+ bidi_rpc.open()
+
+ with pytest.raises(ValueError) as exc_info:
+ bidi_rpc.send(mock.sentinel.request)
+
+ assert exc_info.value == error
+ should_recover.assert_called_once_with(error)
+ assert bidi_rpc.call == call
+ assert bidi_rpc.is_active is False
+ assert call.cancelled is True
+ assert bidi_rpc.pending_requests == 1
+ assert bidi_rpc._request_queue.get() is None
+
+ def test_recv_recover(self):
+ error = ValueError()
+ call_1 = CallStub([1, error])
+ call_2 = CallStub([2, 3])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable,
+ instance=True,
+ side_effect=[call_1, call_2])
+ should_recover = mock.Mock(spec=['__call__'], return_value=True)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+ bidi_rpc.open()
+
+ values = []
+ for n in range(3):
+ values.append(bidi_rpc.recv())
+
+ assert values == [1, 2, 3]
+ should_recover.assert_called_once_with(error)
+ assert bidi_rpc.call == call_2
+ assert bidi_rpc.is_active is True
+
+ def test_recv_recover_already_recovered(self):
+ call_1 = CallStub([])
+ call_2 = CallStub([])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable,
+ instance=True,
+ side_effect=[call_1, call_2])
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True)
+
+ bidi_rpc.open()
+
+ bidi_rpc._reopen()
+
+ assert bidi_rpc.call is call_1
+ assert bidi_rpc.is_active is True
+
+ def test_recv_failure(self):
+ error = ValueError()
+ call = CallStub([error])
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable,
+ instance=True,
+ return_value=call)
+ should_recover = mock.Mock(spec=['__call__'], return_value=False)
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+
+ bidi_rpc.open()
+
+ with pytest.raises(ValueError) as exc_info:
+ bidi_rpc.recv()
+
+ assert exc_info.value == error
+ should_recover.assert_called_once_with(error)
+ assert bidi_rpc.call == call
+ assert bidi_rpc.is_active is False
+ assert call.cancelled is True
+
+ def test_reopen_failure_on_rpc_restart(self):
+ error1 = ValueError('1')
+ error2 = ValueError('2')
+ call = CallStub([error1])
+ # Invoking start RPC a second time will trigger an error.
+ start_rpc = mock.create_autospec(
+ grpc.StreamStreamMultiCallable,
+ instance=True,
+ side_effect=[call, error2])
+ should_recover = mock.Mock(spec=['__call__'], return_value=True)
+ callback = mock.Mock(spec=['__call__'])
+
+ bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
+ bidi_rpc.add_done_callback(callback)
+
+ bidi_rpc.open()
+
+ with pytest.raises(ValueError) as exc_info:
+ bidi_rpc.recv()
+
+ assert exc_info.value == error2
+ should_recover.assert_called_once_with(error1)
+ assert bidi_rpc.call is None
+ assert bidi_rpc.is_active is False
+ callback.assert_called_once_with(error2)
+
+ def test_send_not_open(self):
+ bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
+
+ with pytest.raises(ValueError):
+ bidi_rpc.send(mock.sentinel.request)
+
+ def test_recv_not_open(self):
+ bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)
+
+ with pytest.raises(ValueError):
+ bidi_rpc.recv()
+
+ def test_finalize_idempotent(self):
+ error1 = ValueError('1')
+ error2 = ValueError('2')
+ callback = mock.Mock(spec=['__call__'])
+ should_recover = mock.Mock(spec=['__call__'], return_value=False)
+
+ bidi_rpc = bidi.ResumableBidiRpc(
+ mock.sentinel.start_rpc, should_recover)
+
+ bidi_rpc.add_done_callback(callback)
+
+ bidi_rpc._on_call_done(error1)
+ bidi_rpc._on_call_done(error2)
+
+ callback.assert_called_once_with(error1)
+
+
+class TestBackgroundConsumer(object):
+ def test_consume_once_then_exit(self):
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+ bidi_rpc.recv.side_effect = [mock.sentinel.response_1]
+ recved = threading.Event()
+
+ def on_response(response):
+ assert response == mock.sentinel.response_1
+ bidi_rpc.is_active = False
+ recved.set()
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+ consumer.start()
+
+ recved.wait()
+
+ bidi_rpc.recv.assert_called_once()
+ assert bidi_rpc.is_active is False
+
+ consumer.stop()
+
+ bidi_rpc.close.assert_called_once()
+ assert consumer.is_active is False
+
+ def test_pause_resume_and_close(self):
+ # This test is relatively complex. It attempts to start the consumer,
+ # consume one item, pause the consumer, check the state of the world,
+ # then resume the consumer. Doing this in a deterministic fashion
+ # requires a bit more mocking and patching than usual.
+
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+
+ def close_side_effect():
+ bidi_rpc.is_active = False
+
+ bidi_rpc.close.side_effect = close_side_effect
+
+ # These are used to coordinate the two threads to ensure deterministic
+ # execution.
+ should_continue = threading.Event()
+ responses_and_events = {
+ mock.sentinel.response_1: threading.Event(),
+ mock.sentinel.response_2: threading.Event()
+ }
+ bidi_rpc.recv.side_effect = [
+ mock.sentinel.response_1, mock.sentinel.response_2]
+
+ recved_responses = []
+ consumer = None
+
+ def on_response(response):
+ if response == mock.sentinel.response_1:
+ consumer.pause()
+
+ recved_responses.append(response)
+ responses_and_events[response].set()
+ should_continue.wait()
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+ consumer.start()
+
+ # Wait for the first response to be recved.
+ responses_and_events[mock.sentinel.response_1].wait()
+
+ # Ensure only one item has been recved and that the consumer is paused.
+ assert recved_responses == [mock.sentinel.response_1]
+ assert consumer.is_paused is True
+ assert consumer.is_active is True
+
+ # Unpause the consumer, wait for the second item, then close the
+ # consumer.
+ should_continue.set()
+ consumer.resume()
+
+ responses_and_events[mock.sentinel.response_2].wait()
+
+ assert recved_responses == [
+ mock.sentinel.response_1, mock.sentinel.response_2]
+
+ consumer.stop()
+
+ assert consumer.is_active is False
+
+ def test_wake_on_error(self):
+ should_continue = threading.Event()
+
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+ bidi_rpc.add_done_callback.side_effect = (
+ lambda _: should_continue.set())
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, mock.sentinel.on_response)
+
+ # Start the consumer paused, which should immediately put it into wait
+ # state.
+ consumer.pause()
+ consumer.start()
+
+ # Wait for add_done_callback to be called
+ should_continue.wait()
+ bidi_rpc.add_done_callback.assert_called_once_with(
+ consumer._on_call_done)
+
+ # The consumer should now be blocked on waiting to be unpaused.
+ assert consumer.is_active
+ assert consumer.is_paused
+
+ # Trigger the done callback, it should unpause the consumer and cause
+ # it to exit.
+ bidi_rpc.is_active = False
+ consumer._on_call_done(bidi_rpc)
+
+ # It may take a few cycles for the thread to exit.
+ while consumer.is_active:
+ pass
+
+ def test_consumer_expected_error(self, caplog):
+ caplog.set_level(logging.DEBUG)
+
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+ bidi_rpc.recv.side_effect = exceptions.ServiceUnavailable('Gone away')
+
+ on_response = mock.Mock(spec=['__call__'])
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+ consumer.start()
+
+ # Wait for the consumer's thread to exit.
+ while consumer.is_active:
+ pass
+
+ on_response.assert_not_called()
+ bidi_rpc.recv.assert_called_once()
+ assert 'caught error' in caplog.text
+
+ def test_consumer_unexpected_error(self, caplog):
+ caplog.set_level(logging.DEBUG)
+
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+ bidi_rpc.recv.side_effect = ValueError()
+
+ on_response = mock.Mock(spec=['__call__'])
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+ consumer.start()
+
+ # Wait for the consumer's thread to exit.
+ while consumer.is_active:
+ pass
+
+ on_response.assert_not_called()
+ bidi_rpc.recv.assert_called_once()
+ assert 'caught unexpected exception' in caplog.text
+
+ def test_double_stop(self, caplog):
+ caplog.set_level(logging.DEBUG)
+ bidi_rpc = mock.create_autospec(bidi.BidiRpc, instance=True)
+ bidi_rpc.is_active = True
+ on_response = mock.Mock(spec=['__call__'])
+
+ def close_side_effect():
+ bidi_rpc.is_active = False
+
+ bidi_rpc.close.side_effect = close_side_effect
+
+ consumer = bidi.BackgroundConsumer(bidi_rpc, on_response)
+
+ consumer.start()
+ assert consumer.is_active is True
+
+ consumer.stop()
+ assert consumer.is_active is False
+
+ # calling stop twice should not result in an error.
+ consumer.stop()