api_core: Add ChannelStub to grpc_helpers (#4705)
diff --git a/google/api_core/grpc_helpers.py b/google/api_core/grpc_helpers.py
index 784acf6..7d81c75 100644
--- a/google/api_core/grpc_helpers.py
+++ b/google/api_core/grpc_helpers.py
@@ -14,6 +14,8 @@
"""Helpers for :mod:`grpc`."""
+import collections
+
import grpc
import six
@@ -136,3 +138,185 @@
return google.auth.transport.grpc.secure_authorized_channel(
credentials, request, target, **kwargs)
+
+
+_MethodCall = collections.namedtuple(
+ '_MethodCall', ('request', 'timeout', 'metadata', 'credentials'))
+
+_ChannelRequest = collections.namedtuple(
+ '_ChannelRequest', ('method', 'request'))
+
+
+class _CallableStub(object):
+ """Stub for the grpc.*MultiCallable interfaces."""
+
+ def __init__(self, method, channel):
+ self._method = method
+ self._channel = channel
+ self.response = None
+ """Union[protobuf.Message, Callable[protobuf.Message], exception]:
+ The response to give when invoking this callable. If this is a
+ callable, it will be invoked with the request protobuf. If it's an
+ exception, the exception will be raised when this is invoked.
+ """
+ self.responses = None
+ """Iterator[
+ Union[protobuf.Message, Callable[protobuf.Message], exception]]:
+ An iterator of responses. If specified, self.response will be populated
+ on each invocation by calling ``next(self.responses)``."""
+ self.requests = []
+ """List[protobuf.Message]: All requests sent to this callable."""
+ self.calls = []
+ """List[Tuple]: All invocations of this callable. Each tuple is the
+ request, timeout, metadata, and credentials."""
+
+ def __call__(self, request, timeout=None, metadata=None, credentials=None):
+ self._channel.requests.append(
+ _ChannelRequest(self._method, request))
+ self.calls.append(
+ _MethodCall(request, timeout, metadata, credentials))
+ self.requests.append(request)
+
+ response = self.response
+ if self.responses is not None:
+ if response is None:
+ response = next(self.responses)
+ else:
+ raise ValueError(
+ '{method}.response and {method}.responses are mutually '
+ 'exclusive.'.format(method=self._method))
+
+ if callable(response):
+ return response(request)
+
+ if isinstance(response, Exception):
+ raise response
+
+ if response is not None:
+ return response
+
+ raise ValueError(
+ 'Method stub for "{}" has no response.'.format(self._method))
+
+
+def _simplify_method_name(method):
+ """Simplifies a gRPC method name.
+
+ When gRPC invokes the channel to create a callable, it gives a full
+ method name like "/google.pubsub.v1.Publisher/CreateTopic". This
+ returns just the name of the method, in this case "CreateTopic".
+
+ Args:
+ method (str): The name of the method.
+
+ Returns:
+ str: The simplified name of the method.
+ """
+ return method.rsplit('/', 1).pop()
+
+
+class ChannelStub(grpc.Channel):
+ """A testing stub for the grpc.Channel interface.
+
+ This can be used to test any client that eventually uses a gRPC channel
+ to communicate. By passing in a channel stub, you can configure which
+ responses are returned and track which requests are made.
+
+ For example:
+
+ .. code-block:: python
+
+ channel_stub = grpc_helpers.ChannelStub()
+ client = FooClient(channel=channel_stub)
+
+ channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
+
+ foo = client.get_foo(labels=['baz'])
+
+ assert foo.name == 'bar'
+ assert channel_stub.GetFoo.requests[0].labels = ['baz']
+
+ Each method on the stub can be accessed and configured on the channel.
+ Here's some examples of various configurations:
+
+ .. code-block:: python
+
+ # Return a basic response:
+
+ channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
+ assert client.get_foo().name == 'bar'
+
+ # Raise an exception:
+ channel_stub.GetFoo.response = NotFound('...')
+
+ with pytest.raises(NotFound):
+ client.get_foo()
+
+ # Use a sequence of responses:
+ channel_stub.GetFoo.responses = iter([
+ foo_pb2.Foo(name='bar'),
+ foo_pb2.Foo(name='baz'),
+ ])
+
+ assert client.get_foo().name == 'bar'
+ assert client.get_foo().name == 'baz'
+
+ # Use a callable
+
+ def on_get_foo(request):
+ return foo_pb2.Foo(name='bar' + request.id)
+
+ channel_stub.GetFoo.response = on_get_foo
+
+ assert client.get_foo(id='123').name == 'bar123'
+ """
+
+ def __init__(self, responses=[]):
+ self.requests = []
+ """Sequence[Tuple[str, protobuf.Message]]: A list of all requests made
+ on this channel in order. The tuple is of method name, request
+ message."""
+ self._method_stubs = {}
+
+ def _stub_for_method(self, method):
+ method = _simplify_method_name(method)
+ self._method_stubs[method] = _CallableStub(method, self)
+ return self._method_stubs[method]
+
+ def __getattr__(self, key):
+ try:
+ return self._method_stubs[key]
+ except KeyError:
+ raise AttributeError
+
+ def unary_unary(
+ self, method,
+ request_serializer=None, response_deserializer=None):
+ """grpc.Channel.unary_unary implementation."""
+ return self._stub_for_method(method)
+
+ def unary_stream(
+ self, method,
+ request_serializer=None, response_deserializer=None):
+ """grpc.Channel.unary_stream implementation."""
+ return self._stub_for_method(method)
+
+ def stream_unary(
+ self, method,
+ request_serializer=None, response_deserializer=None):
+ """grpc.Channel.stream_unary implementation."""
+ return self._stub_for_method(method)
+
+ def stream_stream(
+ self, method,
+ request_serializer=None, response_deserializer=None):
+ """grpc.Channel.stream_stream implementation."""
+ return self._stub_for_method(method)
+
+ def subscribe(self, callback, try_to_connect=False):
+ """grpc.Channel.subscribe implementation."""
+ pass
+
+ def unsubscribe(self, callback):
+ """grpc.Channel.unsubscribe implementation."""
+ pass
diff --git a/tests/unit/operations_v1/test_operations_client.py b/tests/unit/operations_v1/test_operations_client.py
index 1b6e6d9..69d4dfc 100644
--- a/tests/unit/operations_v1/test_operations_client.py
+++ b/tests/unit/operations_v1/test_operations_client.py
@@ -12,90 +12,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import mock
-
+from google.api_core import grpc_helpers
from google.api_core import operations_v1
from google.api_core import page_iterator
from google.longrunning import operations_pb2
+from google.protobuf import empty_pb2
-def make_operations_stub(channel):
- return mock.Mock(
- spec=[
- 'GetOperation', 'DeleteOperation', 'ListOperations',
- 'CancelOperation'])
-
-
-operations_stub_patch = mock.patch(
- 'google.longrunning.operations_pb2.OperationsStub',
- autospec=True,
- side_effect=make_operations_stub)
-
-
-@operations_stub_patch
-def test_constructor(operations_stub):
- stub = make_operations_stub(None)
- operations_stub.side_effect = None
- operations_stub.return_value = stub
-
- client = operations_v1.OperationsClient(mock.sentinel.channel)
-
- assert client.operations_stub == stub
- operations_stub.assert_called_once_with(mock.sentinel.channel)
-
-
-@operations_stub_patch
-def test_get_operation(operations_stub):
- client = operations_v1.OperationsClient(mock.sentinel.channel)
- client.operations_stub.GetOperation.return_value = mock.sentinel.operation
+def test_get_operation():
+ channel = grpc_helpers.ChannelStub()
+ client = operations_v1.OperationsClient(channel)
+ channel.GetOperation.response = operations_pb2.Operation(name='meep')
response = client.get_operation('name')
- request = client.operations_stub.GetOperation.call_args[0][0]
- assert isinstance(request, operations_pb2.GetOperationRequest)
- assert request.name == 'name'
-
- assert response == mock.sentinel.operation
+ assert len(channel.GetOperation.requests) == 1
+ assert channel.GetOperation.requests[0].name == 'name'
+ assert response == channel.GetOperation.response
-@operations_stub_patch
-def test_list_operations(operations_stub):
- client = operations_v1.OperationsClient(mock.sentinel.channel)
+def test_list_operations():
+ channel = grpc_helpers.ChannelStub()
+ client = operations_v1.OperationsClient(channel)
operations = [
operations_pb2.Operation(name='1'),
operations_pb2.Operation(name='2')]
list_response = operations_pb2.ListOperationsResponse(
operations=operations)
- client.operations_stub.ListOperations.return_value = list_response
+ channel.ListOperations.response = list_response
response = client.list_operations('name', 'filter')
assert isinstance(response, page_iterator.Iterator)
assert list(response) == operations
- request = client.operations_stub.ListOperations.call_args[0][0]
+ assert len(channel.ListOperations.requests) == 1
+ request = channel.ListOperations.requests[0]
assert isinstance(request, operations_pb2.ListOperationsRequest)
assert request.name == 'name'
assert request.filter == 'filter'
-@operations_stub_patch
-def test_delete_operation(operations_stub):
- client = operations_v1.OperationsClient(mock.sentinel.channel)
+def test_delete_operation():
+ channel = grpc_helpers.ChannelStub()
+ client = operations_v1.OperationsClient(channel)
+ channel.DeleteOperation.response = empty_pb2.Empty()
client.delete_operation('name')
- request = client.operations_stub.DeleteOperation.call_args[0][0]
- assert isinstance(request, operations_pb2.DeleteOperationRequest)
- assert request.name == 'name'
+ assert len(channel.DeleteOperation.requests) == 1
+ assert channel.DeleteOperation.requests[0].name == 'name'
-@operations_stub_patch
-def test_cancel_operation(operations_stub):
- client = operations_v1.OperationsClient(mock.sentinel.channel)
+def test_cancel_operation():
+ channel = grpc_helpers.ChannelStub()
+ client = operations_v1.OperationsClient(channel)
+ channel.CancelOperation.response = empty_pb2.Empty()
client.cancel_operation('name')
- request = client.operations_stub.CancelOperation.call_args[0][0]
- assert isinstance(request, operations_pb2.CancelOperationRequest)
- assert request.name == 'name'
+ assert len(channel.CancelOperation.requests) == 1
+ assert channel.CancelOperation.requests[0].name == 'name'
diff --git a/tests/unit/test_grpc_helpers.py b/tests/unit/test_grpc_helpers.py
index 6ee4062..de093e5 100644
--- a/tests/unit/test_grpc_helpers.py
+++ b/tests/unit/test_grpc_helpers.py
@@ -19,6 +19,7 @@
from google.api_core import exceptions
from google.api_core import grpc_helpers
import google.auth.credentials
+from google.longrunning import operations_pb2
def test__patch_callable_name():
@@ -186,3 +187,147 @@
scopes=scopes)
credentials.with_scopes.assert_called_once_with(scopes)
+
+
+class TestChannelStub(object):
+
+ def test_single_response(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name='meep')
+ expected_response = operations_pb2.Operation(name='moop')
+
+ channel.GetOperation.response = expected_response
+
+ response = stub.GetOperation(expected_request)
+
+ assert response == expected_response
+ assert channel.requests == [('GetOperation', expected_request)]
+ assert channel.GetOperation.requests == [expected_request]
+
+ def test_no_response(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name='meep')
+
+ with pytest.raises(ValueError) as exc_info:
+ stub.GetOperation(expected_request)
+
+ assert exc_info.match('GetOperation')
+
+ def test_missing_method(self):
+ channel = grpc_helpers.ChannelStub()
+
+ with pytest.raises(AttributeError):
+ channel.DoesNotExist.response
+
+ def test_exception_response(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name='meep')
+
+ channel.GetOperation.response = RuntimeError()
+
+ with pytest.raises(RuntimeError):
+ stub.GetOperation(expected_request)
+
+ def test_callable_response(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name='meep')
+ expected_response = operations_pb2.Operation(name='moop')
+
+ on_get_operation = mock.Mock(
+ spec=('__call__',), return_value=expected_response)
+
+ channel.GetOperation.response = on_get_operation
+
+ response = stub.GetOperation(expected_request)
+
+ assert response == expected_response
+ on_get_operation.assert_called_once_with(expected_request)
+
+ def test_multiple_responses(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name='meep')
+ expected_responses = [
+ operations_pb2.Operation(name='foo'),
+ operations_pb2.Operation(name='bar'),
+ operations_pb2.Operation(name='baz'),
+ ]
+
+ channel.GetOperation.responses = iter(expected_responses)
+
+ response1 = stub.GetOperation(expected_request)
+ response2 = stub.GetOperation(expected_request)
+ response3 = stub.GetOperation(expected_request)
+
+ assert response1 == expected_responses[0]
+ assert response2 == expected_responses[1]
+ assert response3 == expected_responses[2]
+ assert channel.requests == [('GetOperation', expected_request)] * 3
+ assert channel.GetOperation.requests == [expected_request] * 3
+
+ with pytest.raises(StopIteration):
+ stub.GetOperation(expected_request)
+
+ def test_multiple_responses_and_single_response_error(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ channel.GetOperation.responses = []
+ channel.GetOperation.response = mock.sentinel.response
+
+ with pytest.raises(ValueError):
+ stub.GetOperation(operations_pb2.GetOperationRequest())
+
+ def test_call_info(self):
+ channel = grpc_helpers.ChannelStub()
+ stub = operations_pb2.OperationsStub(channel)
+ expected_request = operations_pb2.GetOperationRequest(name='meep')
+ expected_response = operations_pb2.Operation(name='moop')
+ expected_metadata = [('red', 'blue'), ('two', 'shoe')]
+ expected_credentials = mock.sentinel.credentials
+ channel.GetOperation.response = expected_response
+
+ response = stub.GetOperation(
+ expected_request, timeout=42, metadata=expected_metadata,
+ credentials=expected_credentials)
+
+ assert response == expected_response
+ assert channel.requests == [('GetOperation', expected_request)]
+ assert channel.GetOperation.calls == [
+ (expected_request, 42, expected_metadata, expected_credentials)]
+
+ def test_unary_unary(self):
+ channel = grpc_helpers.ChannelStub()
+ method_name = 'GetOperation'
+ callable_stub = channel.unary_unary(method_name)
+ assert callable_stub._method == method_name
+ assert callable_stub._channel == channel
+
+ def test_unary_stream(self):
+ channel = grpc_helpers.ChannelStub()
+ method_name = 'GetOperation'
+ callable_stub = channel.unary_stream(method_name)
+ assert callable_stub._method == method_name
+ assert callable_stub._channel == channel
+
+ def test_stream_unary(self):
+ channel = grpc_helpers.ChannelStub()
+ method_name = 'GetOperation'
+ callable_stub = channel.stream_unary(method_name)
+ assert callable_stub._method == method_name
+ assert callable_stub._channel == channel
+
+ def test_stream_stream(self):
+ channel = grpc_helpers.ChannelStub()
+ method_name = 'GetOperation'
+ callable_stub = channel.stream_stream(method_name)
+ assert callable_stub._method == method_name
+ assert callable_stub._channel == channel
+
+ def test_subscribe_unsubscribe(self):
+ channel = grpc_helpers.ChannelStub()
+ assert channel.subscribe(None) is None
+ assert channel.unsubscribe(None) is None