api_core: Add ChannelStub to grpc_helpers (#4705)

diff --git a/google/api_core/grpc_helpers.py b/google/api_core/grpc_helpers.py
index 784acf6..7d81c75 100644
--- a/google/api_core/grpc_helpers.py
+++ b/google/api_core/grpc_helpers.py
@@ -14,6 +14,8 @@
 
 """Helpers for :mod:`grpc`."""
 
+import collections
+
 import grpc
 import six
 
@@ -136,3 +138,185 @@
 
     return google.auth.transport.grpc.secure_authorized_channel(
         credentials, request, target, **kwargs)
+
+
+_MethodCall = collections.namedtuple(
+    '_MethodCall', ('request', 'timeout', 'metadata', 'credentials'))
+
+_ChannelRequest = collections.namedtuple(
+    '_ChannelRequest', ('method', 'request'))
+
+
+class _CallableStub(object):
+    """Stub for the grpc.*MultiCallable interfaces."""
+
+    def __init__(self, method, channel):
+        self._method = method
+        self._channel = channel
+        self.response = None
+        """Union[protobuf.Message, Callable[protobuf.Message], exception]:
+        The response to give when invoking this callable. If this is a
+        callable, it will be invoked with the request protobuf. If it's an
+        exception, the exception will be raised when this is invoked.
+        """
+        self.responses = None
+        """Iterator[
+            Union[protobuf.Message, Callable[protobuf.Message], exception]]:
+        An iterator of responses. If specified, self.response will be populated
+        on each invocation by calling ``next(self.responses)``."""
+        self.requests = []
+        """List[protobuf.Message]: All requests sent to this callable."""
+        self.calls = []
+        """List[Tuple]: All invocations of this callable. Each tuple is the
+        request, timeout, metadata, and credentials."""
+
+    def __call__(self, request, timeout=None, metadata=None, credentials=None):
+        self._channel.requests.append(
+            _ChannelRequest(self._method, request))
+        self.calls.append(
+            _MethodCall(request, timeout, metadata, credentials))
+        self.requests.append(request)
+
+        response = self.response
+        if self.responses is not None:
+            if response is None:
+                response = next(self.responses)
+            else:
+                raise ValueError(
+                    '{method}.response and {method}.responses are mutually '
+                    'exclusive.'.format(method=self._method))
+
+        if callable(response):
+            return response(request)
+
+        if isinstance(response, Exception):
+            raise response
+
+        if response is not None:
+            return response
+
+        raise ValueError(
+            'Method stub for "{}" has no response.'.format(self._method))
+
+
+def _simplify_method_name(method):
+    """Simplifies a gRPC method name.
+
+    When gRPC invokes the channel to create a callable, it gives a full
+    method name like "/google.pubsub.v1.Publisher/CreateTopic". This
+    returns just the name of the method, in this case "CreateTopic".
+
+    Args:
+        method (str): The name of the method.
+
+    Returns:
+        str: The simplified name of the method.
+    """
+    return method.rsplit('/', 1).pop()
+
+
+class ChannelStub(grpc.Channel):
+    """A testing stub for the grpc.Channel interface.
+
+    This can be used to test any client that eventually uses a gRPC channel
+    to communicate. By passing in a channel stub, you can configure which
+    responses are returned and track which requests are made.
+
+    For example:
+
+    .. code-block:: python
+
+        channel_stub = grpc_helpers.ChannelStub()
+        client = FooClient(channel=channel_stub)
+
+        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
+
+        foo = client.get_foo(labels=['baz'])
+
+        assert foo.name == 'bar'
+        assert channel_stub.GetFoo.requests[0].labels = ['baz']
+
+    Each method on the stub can be accessed and configured on the channel.
+    Here's some examples of various configurations:
+
+    .. code-block:: python
+
+        # Return a basic response:
+
+        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
+        assert client.get_foo().name == 'bar'
+
+        # Raise an exception:
+        channel_stub.GetFoo.response = NotFound('...')
+
+        with pytest.raises(NotFound):
+            client.get_foo()
+
+        # Use a sequence of responses:
+        channel_stub.GetFoo.responses = iter([
+            foo_pb2.Foo(name='bar'),
+            foo_pb2.Foo(name='baz'),
+        ])
+
+        assert client.get_foo().name == 'bar'
+        assert client.get_foo().name == 'baz'
+
+        # Use a callable
+
+        def on_get_foo(request):
+            return foo_pb2.Foo(name='bar' + request.id)
+
+        channel_stub.GetFoo.response = on_get_foo
+
+        assert client.get_foo(id='123').name == 'bar123'
+    """
+
+    def __init__(self, responses=[]):
+        self.requests = []
+        """Sequence[Tuple[str, protobuf.Message]]: A list of all requests made
+        on this channel in order. The tuple is of method name, request
+        message."""
+        self._method_stubs = {}
+
+    def _stub_for_method(self, method):
+        method = _simplify_method_name(method)
+        self._method_stubs[method] = _CallableStub(method, self)
+        return self._method_stubs[method]
+
+    def __getattr__(self, key):
+        try:
+            return self._method_stubs[key]
+        except KeyError:
+            raise AttributeError
+
+    def unary_unary(
+            self, method,
+            request_serializer=None, response_deserializer=None):
+        """grpc.Channel.unary_unary implementation."""
+        return self._stub_for_method(method)
+
+    def unary_stream(
+            self, method,
+            request_serializer=None, response_deserializer=None):
+        """grpc.Channel.unary_stream implementation."""
+        return self._stub_for_method(method)
+
+    def stream_unary(
+            self, method,
+            request_serializer=None, response_deserializer=None):
+        """grpc.Channel.stream_unary implementation."""
+        return self._stub_for_method(method)
+
+    def stream_stream(
+            self, method,
+            request_serializer=None, response_deserializer=None):
+        """grpc.Channel.stream_stream implementation."""
+        return self._stub_for_method(method)
+
+    def subscribe(self, callback, try_to_connect=False):
+        """grpc.Channel.subscribe implementation."""
+        pass
+
+    def unsubscribe(self, callback):
+        """grpc.Channel.unsubscribe implementation."""
+        pass
diff --git a/tests/unit/operations_v1/test_operations_client.py b/tests/unit/operations_v1/test_operations_client.py
index 1b6e6d9..69d4dfc 100644
--- a/tests/unit/operations_v1/test_operations_client.py
+++ b/tests/unit/operations_v1/test_operations_client.py
@@ -12,90 +12,64 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import mock
-
+from google.api_core import grpc_helpers
 from google.api_core import operations_v1
 from google.api_core import page_iterator
 from google.longrunning import operations_pb2
+from google.protobuf import empty_pb2
 
 
-def make_operations_stub(channel):
-    return mock.Mock(
-        spec=[
-            'GetOperation', 'DeleteOperation', 'ListOperations',
-            'CancelOperation'])
-
-
-operations_stub_patch = mock.patch(
-    'google.longrunning.operations_pb2.OperationsStub',
-    autospec=True,
-    side_effect=make_operations_stub)
-
-
-@operations_stub_patch
-def test_constructor(operations_stub):
-    stub = make_operations_stub(None)
-    operations_stub.side_effect = None
-    operations_stub.return_value = stub
-
-    client = operations_v1.OperationsClient(mock.sentinel.channel)
-
-    assert client.operations_stub == stub
-    operations_stub.assert_called_once_with(mock.sentinel.channel)
-
-
-@operations_stub_patch
-def test_get_operation(operations_stub):
-    client = operations_v1.OperationsClient(mock.sentinel.channel)
-    client.operations_stub.GetOperation.return_value = mock.sentinel.operation
+def test_get_operation():
+    channel = grpc_helpers.ChannelStub()
+    client = operations_v1.OperationsClient(channel)
+    channel.GetOperation.response = operations_pb2.Operation(name='meep')
 
     response = client.get_operation('name')
 
-    request = client.operations_stub.GetOperation.call_args[0][0]
-    assert isinstance(request, operations_pb2.GetOperationRequest)
-    assert request.name == 'name'
-
-    assert response == mock.sentinel.operation
+    assert len(channel.GetOperation.requests) == 1
+    assert channel.GetOperation.requests[0].name == 'name'
+    assert response == channel.GetOperation.response
 
 
-@operations_stub_patch
-def test_list_operations(operations_stub):
-    client = operations_v1.OperationsClient(mock.sentinel.channel)
+def test_list_operations():
+    channel = grpc_helpers.ChannelStub()
+    client = operations_v1.OperationsClient(channel)
     operations = [
         operations_pb2.Operation(name='1'),
         operations_pb2.Operation(name='2')]
     list_response = operations_pb2.ListOperationsResponse(
         operations=operations)
-    client.operations_stub.ListOperations.return_value = list_response
+    channel.ListOperations.response = list_response
 
     response = client.list_operations('name', 'filter')
 
     assert isinstance(response, page_iterator.Iterator)
     assert list(response) == operations
 
-    request = client.operations_stub.ListOperations.call_args[0][0]
+    assert len(channel.ListOperations.requests) == 1
+    request = channel.ListOperations.requests[0]
     assert isinstance(request, operations_pb2.ListOperationsRequest)
     assert request.name == 'name'
     assert request.filter == 'filter'
 
 
-@operations_stub_patch
-def test_delete_operation(operations_stub):
-    client = operations_v1.OperationsClient(mock.sentinel.channel)
+def test_delete_operation():
+    channel = grpc_helpers.ChannelStub()
+    client = operations_v1.OperationsClient(channel)
+    channel.DeleteOperation.response = empty_pb2.Empty()
 
     client.delete_operation('name')
 
-    request = client.operations_stub.DeleteOperation.call_args[0][0]
-    assert isinstance(request, operations_pb2.DeleteOperationRequest)
-    assert request.name == 'name'
+    assert len(channel.DeleteOperation.requests) == 1
+    assert channel.DeleteOperation.requests[0].name == 'name'
 
 
-@operations_stub_patch
-def test_cancel_operation(operations_stub):
-    client = operations_v1.OperationsClient(mock.sentinel.channel)
+def test_cancel_operation():
+    channel = grpc_helpers.ChannelStub()
+    client = operations_v1.OperationsClient(channel)
+    channel.CancelOperation.response = empty_pb2.Empty()
 
     client.cancel_operation('name')
 
-    request = client.operations_stub.CancelOperation.call_args[0][0]
-    assert isinstance(request, operations_pb2.CancelOperationRequest)
-    assert request.name == 'name'
+    assert len(channel.CancelOperation.requests) == 1
+    assert channel.CancelOperation.requests[0].name == 'name'
diff --git a/tests/unit/test_grpc_helpers.py b/tests/unit/test_grpc_helpers.py
index 6ee4062..de093e5 100644
--- a/tests/unit/test_grpc_helpers.py
+++ b/tests/unit/test_grpc_helpers.py
@@ -19,6 +19,7 @@
 from google.api_core import exceptions
 from google.api_core import grpc_helpers
 import google.auth.credentials
+from google.longrunning import operations_pb2
 
 
 def test__patch_callable_name():
@@ -186,3 +187,147 @@
         scopes=scopes)
 
     credentials.with_scopes.assert_called_once_with(scopes)
+
+
+class TestChannelStub(object):
+
+    def test_single_response(self):
+        channel = grpc_helpers.ChannelStub()
+        stub = operations_pb2.OperationsStub(channel)
+        expected_request = operations_pb2.GetOperationRequest(name='meep')
+        expected_response = operations_pb2.Operation(name='moop')
+
+        channel.GetOperation.response = expected_response
+
+        response = stub.GetOperation(expected_request)
+
+        assert response == expected_response
+        assert channel.requests == [('GetOperation', expected_request)]
+        assert channel.GetOperation.requests == [expected_request]
+
+    def test_no_response(self):
+        channel = grpc_helpers.ChannelStub()
+        stub = operations_pb2.OperationsStub(channel)
+        expected_request = operations_pb2.GetOperationRequest(name='meep')
+
+        with pytest.raises(ValueError) as exc_info:
+            stub.GetOperation(expected_request)
+
+        assert exc_info.match('GetOperation')
+
+    def test_missing_method(self):
+        channel = grpc_helpers.ChannelStub()
+
+        with pytest.raises(AttributeError):
+            channel.DoesNotExist.response
+
+    def test_exception_response(self):
+        channel = grpc_helpers.ChannelStub()
+        stub = operations_pb2.OperationsStub(channel)
+        expected_request = operations_pb2.GetOperationRequest(name='meep')
+
+        channel.GetOperation.response = RuntimeError()
+
+        with pytest.raises(RuntimeError):
+            stub.GetOperation(expected_request)
+
+    def test_callable_response(self):
+        channel = grpc_helpers.ChannelStub()
+        stub = operations_pb2.OperationsStub(channel)
+        expected_request = operations_pb2.GetOperationRequest(name='meep')
+        expected_response = operations_pb2.Operation(name='moop')
+
+        on_get_operation = mock.Mock(
+            spec=('__call__',), return_value=expected_response)
+
+        channel.GetOperation.response = on_get_operation
+
+        response = stub.GetOperation(expected_request)
+
+        assert response == expected_response
+        on_get_operation.assert_called_once_with(expected_request)
+
+    def test_multiple_responses(self):
+        channel = grpc_helpers.ChannelStub()
+        stub = operations_pb2.OperationsStub(channel)
+        expected_request = operations_pb2.GetOperationRequest(name='meep')
+        expected_responses = [
+            operations_pb2.Operation(name='foo'),
+            operations_pb2.Operation(name='bar'),
+            operations_pb2.Operation(name='baz'),
+        ]
+
+        channel.GetOperation.responses = iter(expected_responses)
+
+        response1 = stub.GetOperation(expected_request)
+        response2 = stub.GetOperation(expected_request)
+        response3 = stub.GetOperation(expected_request)
+
+        assert response1 == expected_responses[0]
+        assert response2 == expected_responses[1]
+        assert response3 == expected_responses[2]
+        assert channel.requests == [('GetOperation', expected_request)] * 3
+        assert channel.GetOperation.requests == [expected_request] * 3
+
+        with pytest.raises(StopIteration):
+            stub.GetOperation(expected_request)
+
+    def test_multiple_responses_and_single_response_error(self):
+        channel = grpc_helpers.ChannelStub()
+        stub = operations_pb2.OperationsStub(channel)
+        channel.GetOperation.responses = []
+        channel.GetOperation.response = mock.sentinel.response
+
+        with pytest.raises(ValueError):
+            stub.GetOperation(operations_pb2.GetOperationRequest())
+
+    def test_call_info(self):
+        channel = grpc_helpers.ChannelStub()
+        stub = operations_pb2.OperationsStub(channel)
+        expected_request = operations_pb2.GetOperationRequest(name='meep')
+        expected_response = operations_pb2.Operation(name='moop')
+        expected_metadata = [('red', 'blue'), ('two', 'shoe')]
+        expected_credentials = mock.sentinel.credentials
+        channel.GetOperation.response = expected_response
+
+        response = stub.GetOperation(
+            expected_request, timeout=42, metadata=expected_metadata,
+            credentials=expected_credentials)
+
+        assert response == expected_response
+        assert channel.requests == [('GetOperation', expected_request)]
+        assert channel.GetOperation.calls == [
+            (expected_request, 42, expected_metadata, expected_credentials)]
+
+    def test_unary_unary(self):
+        channel = grpc_helpers.ChannelStub()
+        method_name = 'GetOperation'
+        callable_stub = channel.unary_unary(method_name)
+        assert callable_stub._method == method_name
+        assert callable_stub._channel == channel
+
+    def test_unary_stream(self):
+        channel = grpc_helpers.ChannelStub()
+        method_name = 'GetOperation'
+        callable_stub = channel.unary_stream(method_name)
+        assert callable_stub._method == method_name
+        assert callable_stub._channel == channel
+
+    def test_stream_unary(self):
+        channel = grpc_helpers.ChannelStub()
+        method_name = 'GetOperation'
+        callable_stub = channel.stream_unary(method_name)
+        assert callable_stub._method == method_name
+        assert callable_stub._channel == channel
+
+    def test_stream_stream(self):
+        channel = grpc_helpers.ChannelStub()
+        method_name = 'GetOperation'
+        callable_stub = channel.stream_stream(method_name)
+        assert callable_stub._method == method_name
+        assert callable_stub._channel == channel
+
+    def test_subscribe_unsubscribe(self):
+        channel = grpc_helpers.ChannelStub()
+        assert channel.subscribe(None) is None
+        assert channel.unsubscribe(None) is None