pw_rpc: Update Python client; support streaming

- Update the Python client to reflect the pw_rpc protocol updates.
- Add support for server streaming RPCs to the Python client.
- Update and simplify the ClientImpl interface. The new structure keeps
  most logic in the generic client code while giving the ClientImpl full
  control over method invocation semantics.
- Revamp the "SimpleClient" ClientImpl implementation and move it to the
  callback_client.py module.

Change-Id: Ie610a45bf50e550e941a2713108be329573a4d24
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/14120
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_rpc/py/callback_client_test.py b/pw_rpc/py/callback_client_test.py
new file mode 100755
index 0000000..bb46031
--- /dev/null
+++ b/pw_rpc/py/callback_client_test.py
@@ -0,0 +1,304 @@
+#!/usr/bin/env python3
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Tests using the callback client for pw_rpc."""
+
+import unittest
+from unittest import mock
+from typing import List, Tuple
+
+from pw_protobuf_compiler import python_protos
+from pw_rpc import callback_client, client, packets
+from pw_status import Status
+
+TEST_PROTO_1 = """\
+syntax = "proto3";
+
+package pw.call.test1;
+
+message SomeMessage {
+  uint32 magic_number = 1;
+}
+
+message AnotherMessage {
+  enum Result {
+    FAILED = 0;
+    FAILED_MISERABLY = 1;
+    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
+  }
+
+  Result result = 1;
+  string payload = 2;
+}
+
+service PublicService {
+  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
+  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
+  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
+  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
+}
+"""
+
+
+class CallbackClientImplTest(unittest.TestCase):
+    """Tests the callback_client as used within a pw_rpc Client."""
+    def setUp(self):
+        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
+
+        self._client = client.Client.from_modules(
+            callback_client.Impl(), [client.Channel(1, self._handle_request)],
+            self._protos.modules())
+
+        self._last_request: packets.RpcPacket = None
+        self._next_packets: List[Tuple[bytes, bool]] = []
+
+    def _enqueue_response(self,
+                          channel_id: int,
+                          method=None,
+                          status: Status = Status.OK,
+                          response=b'',
+                          *,
+                          ids: Tuple[int, int] = None,
+                          valid=True):
+        if method:
+            assert ids is None
+            service_id, method_id = method.service.id, method.id
+        else:
+            assert ids is not None and method is None
+            service_id, method_id = ids
+
+        if isinstance(response, bytes):
+            payload = response
+        else:
+            payload = response.SerializeToString()
+
+        self._next_packets.append(
+            (packets.RpcPacket(channel_id=channel_id,
+                               service_id=service_id,
+                               method_id=method_id,
+                               status=status.value,
+                               payload=payload).SerializeToString(), valid))
+
+    def _enqueue_stream_end(self,
+                            channel_id: int,
+                            method,
+                            status: Status = Status.OK,
+                            valid=True):
+        self._next_packets.append(
+            (packets.RpcPacket(type=packets.PacketType.STREAM_END,
+                               channel_id=channel_id,
+                               service_id=method.service.id,
+                               method_id=method.id,
+                               status=status.value).SerializeToString(),
+             valid))
+
+    def _handle_request(self, data: bytes):
+        self._last_request = packets.decode(data)
+
+        for packet, valid in self._next_packets:
+            self.assertEqual(valid, self._client.process_packet(packet))
+
+        self._next_packets.clear()
+
+    def _sent_payload(self, message_type):
+        self.assertIsNotNone(self._last_request)
+        message = message_type()
+        message.ParseFromString(self._last_request.payload)
+        return message
+
+    def test_invoke_unary_rpc(self):
+        stub = self._client.channel(1).call.PublicService
+        method = stub.SomeUnary.method
+
+        for _ in range(3):
+            self._enqueue_response(1, method, Status.ABORTED,
+                                   method.response_type(payload='0_o'))
+
+            status, response = stub.SomeUnary(magic_number=6)
+
+            self.assertEqual(
+                6,
+                self._sent_payload(method.request_type).magic_number)
+
+            self.assertIs(Status.ABORTED, status)
+            self.assertEqual('0_o', response.payload)
+
+    def test_invoke_unary_rpc_with_callback(self):
+        stub = self._client.channel(1).call.PublicService
+        method = stub.SomeUnary.method
+
+        for _ in range(3):
+            self._enqueue_response(1, method, Status.ABORTED,
+                                   method.response_type(payload='0_o'))
+
+            callback = mock.Mock()
+            stub.SomeUnary.with_callback(callback, magic_number=5)
+
+            callback.assert_called_once_with(
+                Status.ABORTED, method.response_type(payload='0_o'))
+
+            self.assertEqual(
+                5,
+                self._sent_payload(method.request_type).magic_number)
+
+    def test_invoke_unary_rpc_callback_errors_suppressed(self):
+        stub = self._client.channel(1).call.PublicService.SomeUnary
+
+        self._enqueue_response(1, stub.method)
+        exception_msg = 'YOU BROKE IT O-]-<'
+
+        with self.assertLogs(callback_client.__name__, 'ERROR') as logs:
+            stub.with_callback(mock.Mock(side_effect=Exception(exception_msg)))
+
+        self.assertIn(exception_msg, ''.join(logs.output))
+
+        # Make sure we can still invoke the RPC.
+        self._enqueue_response(1, stub.method, Status.UNKNOWN)
+        status, _ = stub()
+        self.assertIs(status, Status.UNKNOWN)
+
+    def test_invoke_unary_rpc_with_callback_cancel(self):
+        stub = self._client.channel(1).call.PublicService
+        callback = mock.Mock()
+
+        for _ in range(3):
+            call = stub.SomeUnary.with_callback(callback, magic_number=55)
+
+            self.assertIsNotNone(self._last_request)
+            self._last_request = None
+
+            # Try to call the RPC again before cancelling, which is an error.
+            with self.assertRaises(client.Error):
+                stub.SomeUnary.with_callback(callback, magic_number=56)
+
+            self.assertTrue(call.cancel())
+            self.assertFalse(call.cancel())  # Already cancelled, returns False
+
+            # Unary RPCs do not send a cancel request to the server.
+            self.assertIsNone(self._last_request)
+
+        callback.assert_not_called()
+
+    def test_invoke_server_streaming(self):
+        stub = self._client.channel(1).call.PublicService
+        method = stub.SomeServerStreaming.method
+
+        rep1 = method.response_type(payload='!!!')
+        rep2 = method.response_type(payload='?')
+
+        for _ in range(3):
+            self._enqueue_response(1, method, response=rep1)
+            self._enqueue_response(1, method, response=rep2)
+            self._enqueue_stream_end(1, method, Status.ABORTED)
+
+            self.assertEqual([rep1, rep2],
+                             list(stub.SomeServerStreaming(magic_number=4)))
+
+            self.assertEqual(
+                4,
+                self._sent_payload(method.request_type).magic_number)
+
+    def test_invoke_server_streaming_with_callback(self):
+        stub = self._client.channel(1).call.PublicService
+        method = stub.SomeServerStreaming.method
+
+        rep1 = method.response_type(payload='!!!')
+        rep2 = method.response_type(payload='?')
+
+        for _ in range(3):
+            self._enqueue_response(1, method, response=rep1)
+            self._enqueue_response(1, method, response=rep2)
+            self._enqueue_stream_end(1, method, Status.ABORTED)
+
+            callback = mock.Mock()
+            stub.SomeServerStreaming.with_callback(callback, magic_number=3)
+
+            callback.assert_has_calls([
+                mock.call(None, method.response_type(payload='!!!')),
+                mock.call(None, method.response_type(payload='?')),
+                mock.call(Status.ABORTED, None),
+            ])
+
+            self.assertEqual(
+                3,
+                self._sent_payload(method.request_type).magic_number)
+
+    def test_invoke_server_streaming_with_callback_cancel(self):
+        stub = self._client.channel(1).call.PublicService.SomeServerStreaming
+
+        resp = stub.method.response_type(payload='!!!')
+        self._enqueue_response(1, stub.method, response=resp)
+
+        callback = mock.Mock()
+        call = stub.with_callback(callback, magic_number=3)
+        callback.assert_called_once_with(
+            None, stub.method.response_type(payload='!!!'))
+
+        callback.reset_mock()
+
+        call.cancel()
+
+        self.assertIs(self._last_request.type, packets.PacketType.CANCEL)
+
+        # Ensure the RPC can be called after being cancelled.
+        self._enqueue_response(1, stub.method, response=resp)
+        self._enqueue_stream_end(1, stub.method, Status.OK)
+
+        call = stub.with_callback(callback, magic_number=3)
+
+        callback.assert_has_calls([
+            mock.call(None, stub.method.response_type(payload='!!!')),
+            mock.call(Status.OK, None),
+        ])
+
+    def test_ignore_bad_packets_with_pending_rpc(self):
+        rpcs = self._client.channel(1).call
+        method = rpcs.PublicService.SomeUnary.method
+        service_id = method.service.id
+
+        # Unknown channel
+        self._enqueue_response(999, method, valid=False)
+        # Bad service
+        self._enqueue_response(1, ids=(999, method.id), valid=False)
+        # Bad method
+        self._enqueue_response(1, ids=(service_id, 999), valid=False)
+        # For RPC not pending (valid=True because the packet is processed)
+        self._enqueue_response(
+            1,
+            ids=(service_id, rpcs.PublicService.SomeBidiStreaming.method.id),
+            valid=True)
+
+        self._enqueue_response(1, method, valid=True)
+
+        status, response = rpcs.PublicService.SomeUnary(magic_number=6)
+        self.assertIs(Status.OK, status)
+        self.assertEqual('', response.payload)
+
+    def test_pass_none_if_payload_fails_to_decode(self):
+        rpcs = self._client.channel(1).call
+        method = rpcs.PublicService.SomeUnary.method
+
+        self._enqueue_response(1,
+                               method,
+                               Status.OK,
+                               b'INVALID DATA!!!',
+                               valid=True)
+
+        status, response = rpcs.PublicService.SomeUnary(magic_number=6)
+        self.assertIs(status, Status.OK)
+        self.assertIsNone(response)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/pw_rpc/py/client_test.py b/pw_rpc/py/client_test.py
index 7569cf0..a2486b6 100755
--- a/pw_rpc/py/client_test.py
+++ b/pw_rpc/py/client_test.py
@@ -15,15 +15,12 @@
 """Tests creating pw_rpc client."""
 
 import unittest
-import tempfile
-from pathlib import Path
-from typing import List, Tuple
 
 from pw_protobuf_compiler import python_protos
-from pw_rpc import client, ids, packets
-from pw_status import Status
+from pw_rpc import callback_client, client, packets
+import pw_rpc.ids
 
-TEST_PROTO_1 = b"""\
+TEST_PROTO_1 = """\
 syntax = "proto3";
 
 package pw.call.test1;
@@ -51,7 +48,7 @@
 }
 """
 
-TEST_PROTO_2 = b"""\
+TEST_PROTO_2 = """\
 syntax = "proto2";
 
 package pw.call.test2;
@@ -74,60 +71,14 @@
 
 
 class ClientTest(unittest.TestCase):
-    """Tests the pw_rpc Python client."""
+    """Tests the pw_rpc Client independently of the ClientImpl."""
     def setUp(self):
-        self._proto_dir = tempfile.TemporaryDirectory(prefix='proto_test')
-        protos = []
+        self._protos = python_protos.Library.from_strings(
+            [TEST_PROTO_1, TEST_PROTO_2])
 
-        for i, contents in enumerate([TEST_PROTO_1, TEST_PROTO_2]):
-            protos.append(Path(self._proto_dir.name, f'test_{i}.proto'))
-            protos[-1].write_bytes(contents)
-
-        self._protos = python_protos.Library(
-            python_protos.compile_and_import(protos))
-
-        self._impl = client.SimpleSynchronousClient()
-        self._client = client.Client.from_modules(
-            self._impl, [client.Channel(1, self._handle_request)],
-            self._protos.modules())
-
-        self._last_request: packets.RpcPacket = None
-        self._next_packets: List[Tuple[bytes, bool]] = []
-
-    def tearDown(self):
-        self._proto_dir.cleanup()
-
-    def _enqueue_response(self,
-                          channel_id: int,
-                          service_id: int,
-                          method_id: int,
-                          status: Status = Status.OK,
-                          response=b'',
-                          valid=True):
-        if isinstance(response, bytes):
-            payload = response
-        else:
-            payload = response.SerializeToString()
-
-        self._next_packets.append(
-            (packets.RpcPacket(channel_id=channel_id,
-                               service_id=service_id,
-                               method_id=method_id,
-                               status=status.value,
-                               payload=payload).SerializeToString(), valid))
-
-    def _handle_request(self, data: bytes):
-        self._last_request = packets.decode(data)
-
-        self.assertTrue(self._next_packets)
-        for packet, valid in self._next_packets:
-            self.assertEqual(valid, self._client.process_packet(packet))
-
-    def _last_payload(self, message_type):
-        self.assertIsNotNone(self._last_request)
-        message = message_type()
-        message.ParseFromString(self._last_request.payload)
-        return message
+        self._client = client.Client.from_modules(callback_client.Impl(),
+                                                  [client.Channel(1, None)],
+                                                  self._protos.modules())
 
     def test_access_service_client_as_attribute_or_index(self):
         self.assertIs(
@@ -135,7 +86,8 @@
             self._client.channel(1).call['PublicService'])
         self.assertIs(
             self._client.channel(1).call.PublicService,
-            self._client.channel(1).call[ids.calculate('PublicService')])
+            self._client.channel(1).call[pw_rpc.ids.calculate(
+                'PublicService')])
 
     def test_access_method_client_as_attribute_or_index(self):
         self.assertIs(
@@ -143,94 +95,42 @@
             self._client.channel(1).call['Alpha']['Unary'])
         self.assertIs(
             self._client.channel(1).call.Alpha.Unary,
-            self._client.channel(1).call['Alpha'][ids.calculate('Unary')])
+            self._client.channel(1).call['Alpha'][pw_rpc.ids.calculate(
+                'Unary')])
 
     def test_check_for_presence_of_services(self):
         self.assertIn('PublicService', self._client.channel(1).call)
-        self.assertIn(ids.calculate('PublicService'),
+        self.assertIn(pw_rpc.ids.calculate('PublicService'),
                       self._client.channel(1).call)
         self.assertNotIn('NotAService', self._client.channel(1).call)
         self.assertNotIn(-1213, self._client.channel(1).call)
 
     def test_check_for_presence_of_methods(self):
         self.assertIn('SomeUnary', self._client.channel(1).call.PublicService)
-        self.assertIn(ids.calculate('SomeUnary'),
+        self.assertIn(pw_rpc.ids.calculate('SomeUnary'),
                       self._client.channel(1).call.PublicService)
 
         self.assertNotIn('Unary', self._client.channel(1).call.PublicService)
         self.assertNotIn(12345, self._client.channel(1).call.PublicService)
 
-    def test_invoke_unary_rpc(self):
-        rpcs = self._client.channel(1).call
-        method = rpcs.PublicService.SomeUnary.method
-
-        for _ in range(3):
-            self._enqueue_response(1, method.service.id, method.id,
-                                   Status.ABORTED,
-                                   method.response_type(payload='0_o'))
-
-            status, response = rpcs.PublicService.SomeUnary(magic_number=6)
-
-            self.assertEqual(
-                6,
-                self._last_payload(method.request_type).magic_number)
-
-            self.assertIs(Status.ABORTED, status)
-            self.assertEqual('0_o', response.payload)
-
-    def test_ignore_bad_packets_with_pending_rpc(self):
-        rpcs = self._client.channel(1).call
-        method = rpcs.PublicService.SomeUnary.method
-        service_id = method.service.id
-
-        # Unknown channel
-        self._enqueue_response(999, service_id, method.id, valid=False)
-        # Bad service
-        self._enqueue_response(1, 999, method.id, valid=False)
-        # Bad method
-        self._enqueue_response(1, service_id, 999, valid=False)
-        # For RPC not pending (valid=True because the packet is processed)
-        self._enqueue_response(1,
-                               service_id,
-                               rpcs.PublicService.SomeBidiStreaming.method.id,
-                               valid=True)
-
-        self._enqueue_response(1, service_id, method.id, valid=True)
-
-        status, response = rpcs.PublicService.SomeUnary(magic_number=6)
-        self.assertIs(Status.OK, status)
-        self.assertEqual('', response.payload)
-
-    def test_pass_none_if_payload_fails_to_decode(self):
-        rpcs = self._client.channel(1).call
-        method = rpcs.PublicService.SomeUnary.method
-
-        self._enqueue_response(1,
-                               method.service.id,
-                               method.id,
-                               Status.OK,
-                               b'INVALID DATA!!!',
-                               valid=True)
-
-        status, response = rpcs.PublicService.SomeUnary(magic_number=6)
-        self.assertIs(status, Status.OK)
-        self.assertIsNone(response)
-
-    def test_call_method_with_both_message_and_kwargs(self):
+    def test_method_get_request_with_both_message_and_kwargs(self):
         req = self._client.services['Alpha'].methods['Unary'].request_type()
 
         with self.assertRaisesRegex(TypeError, r'either'):
-            self._client.channel(1).call.Alpha.Unary(req, magic_number=1.0)
+            self._client.services['Alpha'].methods['Unary'].get_request(
+                req, {'magic_number': 1.0})
 
-    def test_call_method_with_wrong_type(self):
+    def test_method_get_request_with_wrong_type(self):
         with self.assertRaisesRegex(TypeError, r'pw\.call\.test2\.Request'):
-            self._client.channel(1).call.Alpha.Unary('This is str!')
+            self._client.services['Alpha'].methods['Unary'].get_request(
+                'str!', {})
 
-    def test_call_method_with_incorrect_message_type(self):
+    def test_method_get_with_incorrect_message_type(self):
         msg = self._protos.packages.pw.call.test1.AnotherMessage()
         with self.assertRaisesRegex(TypeError,
                                     r'pw\.call\.test1\.SomeMessage'):
-            self._client.channel(1).call.PublicService.SomeUnary(msg)
+            self._client.services['PublicService'].methods[
+                'SomeUnary'].get_request(msg, {})
 
     def test_process_packet_invalid_proto_data(self):
         self.assertFalse(self._client.process_packet(b'NOT a packet!'))
@@ -238,20 +138,23 @@
     def test_process_packet_unrecognized_channel(self):
         self.assertFalse(
             self._client.process_packet(
-                packets.encode((123, 456, 789),
-                               self._protos.packages.pw.call.test2.Request())))
+                packets.encode_request(
+                    (123, 456, 789),
+                    self._protos.packages.pw.call.test2.Request())))
 
     def test_process_packet_unrecognized_service(self):
         self.assertFalse(
             self._client.process_packet(
-                packets.encode((1, 456, 789),
-                               self._protos.packages.pw.call.test2.Request())))
+                packets.encode_request(
+                    (1, 456, 789),
+                    self._protos.packages.pw.call.test2.Request())))
 
     def test_process_packet_unrecognized_method(self):
         self.assertFalse(
             self._client.process_packet(
-                packets.encode((1, next(iter(self._client.services)).id, 789),
-                               self._protos.packages.pw.call.test2.Request())))
+                packets.encode_request(
+                    (1, next(iter(self._client.services)).id, 789),
+                    self._protos.packages.pw.call.test2.Request())))
 
 
 if __name__ == '__main__':
diff --git a/pw_rpc/py/packets_test.py b/pw_rpc/py/packets_test.py
index 4b54caf..216af86 100755
--- a/pw_rpc/py/packets_test.py
+++ b/pw_rpc/py/packets_test.py
@@ -18,7 +18,8 @@
 
 from pw_rpc import packets
 
-_TEST_PACKET = packets.RpcPacket(
+_TEST_REQUEST = packets.RpcPacket(
+    type=packets.PacketType.RPC,
     channel_id=1,
     service_id=2,
     method_id=3,
@@ -26,16 +27,29 @@
 
 
 class PacketsTest(unittest.TestCase):
-    def test_encode(self):
-        data = packets.encode((1, 2, 3), packets.RpcPacket(status=321))
+    def test_encode_request(self):
+        data = packets.encode_request((1, 2, 3), packets.RpcPacket(status=321))
         packet = packets.RpcPacket()
         packet.ParseFromString(data)
 
-        self.assertEqual(_TEST_PACKET, packet)
+        self.assertEqual(_TEST_REQUEST, packet)
+
+    def test_encode_cancel(self):
+        data = packets.encode_cancel((9, 8, 7))
+
+        packet = packets.RpcPacket()
+        packet.ParseFromString(data)
+
+        self.assertEqual(
+            packet,
+            packets.RpcPacket(type=packets.PacketType.CANCEL,
+                              channel_id=9,
+                              service_id=8,
+                              method_id=7))
 
     def test_decode(self):
-        self.assertEqual(_TEST_PACKET,
-                         packets.decode(_TEST_PACKET.SerializeToString()))
+        self.assertEqual(_TEST_REQUEST,
+                         packets.decode(_TEST_REQUEST.SerializeToString()))
 
 
 if __name__ == '__main__':
diff --git a/pw_rpc/py/pw_rpc/callback_client.py b/pw_rpc/py/pw_rpc/callback_client.py
new file mode 100644
index 0000000..a2b2d89
--- /dev/null
+++ b/pw_rpc/py/pw_rpc/callback_client.py
@@ -0,0 +1,191 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Defines a callback-based RPC ClientImpl to use with pw_rpc.client.Client.
+
+callback_client.Impl supports invoking RPCs synchronously or asynchronously.
+Asynchronous invocations use a callback.
+
+Synchronous invocations look like a function call:
+
+  status, response = client.channel(1).call.MyServer.MyUnary(some_field=123)
+
+  # Streaming calls return an iterable of responses
+  for reply in client.channel(1).call.MyService.MyServerStreaming(request):
+      pass
+
+Asynchronous invocations pass a callback in addition to the request. The
+callback must be a callable that accepts a status and a payload, either of
+which may be None. The Status is only set when the RPC is completed.
+
+  callback = lambda status, payload: print('Response:', status, payload)
+
+  call = client.channel(1).call.MyServer.MyUnary.with_callback(
+      callback, some_field=123)
+
+  call = client.channel(1).call.MyService.MyServerStreaming.with_callback(
+      callback, request):
+
+When invoking a method, requests may be provided as a message object or as
+kwargs for the message fields (but not both).
+"""
+
+import logging
+import queue
+from typing import Any, Callable, Optional, Tuple
+
+from pw_rpc import client
+from pw_rpc.descriptors import Channel, Method, Service
+from pw_status import Status
+
+_LOG = logging.getLogger(__name__)
+
+UnaryCallback = Callable[[Status, Any], Any]
+Callback = Callable[[Optional[Status], Any], Any]
+
+
+class _MethodClient:
+    """A method that can be invoked for a particular channel."""
+    def __init__(self, client_impl: 'Impl', rpcs: client.PendingRpcs,
+                 channel: Channel, method: Method):
+        self._impl = client_impl
+        self._rpcs = rpcs
+        self._rpc = client.PendingRpc(channel, method.service, method)
+
+    @property
+    def channel(self) -> Channel:
+        return self._rpc.channel
+
+    @property
+    def method(self) -> Method:
+        return self._rpc.method
+
+    @property
+    def service(self) -> Service:
+        return self._rpc.service
+
+
+class _AsyncCall:
+    """Represents an ongoing callback-based call."""
+
+    # TODO(hepler): Consider alternatives (futures) and/or expand functionality.
+
+    def __init__(self, rpcs: client.PendingRpcs, rpc: client.PendingRpc):
+        self.rpc = rpc
+        self._rpcs = rpcs
+
+    def cancel(self) -> bool:
+        return self._rpcs.cancel(self.rpc)
+
+    def __enter__(self) -> '_AsyncCall':
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        self.cancel()
+
+
+class _StreamingResponses:
+    """Used to iterate over a queue.SimpleQueue."""
+    def __init__(self, responses: queue.SimpleQueue):
+        self._queue = responses
+        self.status: Optional[Status] = None
+
+    def get(self, block: bool = True, timeout_s: float = None):
+        while True:
+            self.status, response = self._queue.get(block, timeout_s)
+            if self.status is not None:
+                return
+
+            yield response
+
+    def __iter__(self):
+        return self.get()
+
+
+class UnaryMethodClient(_MethodClient):
+    def __call__(self, _request=None, **request_fields) -> Tuple[Status, Any]:
+        responses: queue.SimpleQueue = queue.SimpleQueue()
+        self.with_callback(
+            lambda status, payload: responses.put((status, payload)), _request,
+            **request_fields)
+        return responses.get()
+
+    def with_callback(self,
+                      callback: UnaryCallback,
+                      _request=None,
+                      **request_fields):
+        self._rpcs.invoke(self._rpc, callback, _request, **request_fields)
+        return _AsyncCall(self._rpcs, self._rpc)
+
+
+class ServerStreamingMethodClient(_MethodClient):
+    def __call__(self, _request=None, **request_fields) -> _StreamingResponses:
+        responses: queue.SimpleQueue = queue.SimpleQueue()
+        self.with_callback(
+            lambda status, payload: responses.put((status, payload)), _request,
+            **request_fields)
+        return _StreamingResponses(responses)
+
+    def with_callback(self,
+                      callback: Callback,
+                      _request=None,
+                      **request_fields):
+        self._rpcs.invoke(self._rpc, callback, _request, **request_fields)
+        return _AsyncCall(self._rpcs, self._rpc)
+
+
+class ClientStreamingMethodClient(_MethodClient):
+    def __call__(self):
+        raise NotImplementedError
+
+    def with_callback(self, callback: Callback):
+        raise NotImplementedError
+
+
+class BidirectionalStreamingMethodClient(_MethodClient):
+    def __call__(self):
+        raise NotImplementedError
+
+    def with_callback(self, callback: Callback):
+        raise NotImplementedError
+
+
+class Impl(client.ClientImpl):
+    """Callback-based client.ClientImpl."""
+    def method_client(self, rpcs: client.PendingRpcs, channel: Channel,
+                      method: Method) -> _MethodClient:
+        """Returns an object that invokes a method using the given chanel."""
+
+        if method.type is Method.Type.UNARY:
+            return UnaryMethodClient(self, rpcs, channel, method)
+
+        if method.type is Method.Type.SERVER_STREAMING:
+            return ServerStreamingMethodClient(self, rpcs, channel, method)
+
+        if method.type is Method.Type.CLIENT_STREAMING:
+            return ClientStreamingMethodClient(self, rpcs, channel, method)
+
+        if method.type is Method.Type.BIDI_STREAMING:
+            return BidirectionalStreamingMethodClient(self, rpcs, channel,
+                                                      method)
+
+        raise AssertionError(f'Unknown method type {method.type}')
+
+    def process_response(self, rpcs: client.PendingRpcs,
+                         rpc: client.PendingRpc, context,
+                         status: Optional[Status], payload) -> None:
+        try:
+            context(status, payload)
+        except:  # pylint: disable=bare-except
+            rpcs.cancel(rpc)
+            _LOG.exception('Callback %s for %s raised exception', context, rpc)
diff --git a/pw_rpc/py/pw_rpc/client.py b/pw_rpc/py/pw_rpc/client.py
index d69a57f..cb4b360 100644
--- a/pw_rpc/py/pw_rpc/client.py
+++ b/pw_rpc/py/pw_rpc/client.py
@@ -14,149 +14,105 @@
 """Creates an RPC client."""
 
 import abc
-from collections import defaultdict
 from dataclasses import dataclass
 import logging
-from queue import SimpleQueue
-from typing import Any, Collection, Dict, Iterable
+from typing import Collection, Dict, Iterable, List, NamedTuple, Optional
 
 from pw_rpc import descriptors, packets
-from pw_rpc.descriptors import Channel, Service, Method, PendingRpc
+from pw_rpc.descriptors import Channel, Service, Method
 from pw_status import Status
 
 _LOG = logging.getLogger(__name__)
 
 
+class Error(Exception):
+    """Error from incorrectly using the RPC client classes."""
+
+
+class PendingRpc(NamedTuple):
+    """Uniquely identifies an RPC call."""
+    channel: Channel
+    service: Service
+    method: Method
+
+
+class PendingRpcs:
+    """Internal object for tracking whether an RPC is pending."""
+    def __init__(self):
+        self._pending: Dict[PendingRpc, List] = {}
+
+    # Use underscores to prevent potential conflicts with request field names.
+    def invoke(self,
+               _rpc: PendingRpc,
+               _context,
+               _request=None,
+               **request_fields):
+        # Ensure that every context is a unique object by wrapping it in a list.
+        context = [_context]
+
+        # Check that the context was added; if not, the RPC was already pending.
+        if self._pending.setdefault(_rpc, context) is not context:
+            raise Error(f'Sent request for {_rpc}, but it is already pending! '
+                        'Cancel the RPC before invoking it again')
+
+        _LOG.debug('Starting %s', _rpc)
+        request = _rpc.method.get_request(_request, request_fields)
+        _rpc.channel.output(packets.encode_request(_rpc, request))
+
+    def cancel(self, rpc: PendingRpc) -> bool:
+        """Cancels the RPC, including sending a CANCEL packet.
+
+        Returns:
+          True if the RPC was cancelled; False if it was not pending
+        """
+        try:
+            _LOG.debug('Cancelling %s', rpc)
+            del self._pending[rpc]
+        except KeyError:
+            return False
+
+        if rpc.method.type is not Method.Type.UNARY:
+            rpc.channel.output(packets.encode_cancel(rpc))
+
+        return True
+
+    def get_pending(self, rpc: PendingRpc, status: Optional[Status]):
+        if status is None:
+            return self._pending[rpc][0]  # Unwrap the context from the list
+
+        _LOG.debug('Finishing %s with status %s', rpc, status)
+        return self._pending.pop(rpc)[0]
+
+
 class ClientImpl(abc.ABC):
     """The internal interface of the RPC client.
 
     This interface defines the semantics for invoking an RPC on a particular
-    client. The return values can objects that provide for synchronous or
-    asynchronous behavior.
+    client.
     """
     @abc.abstractmethod
-    def invoke_unary(self, rpc: PendingRpc, request) -> Any:
-        """Invokes a unary RPC."""
+    def method_client(self, rpcs: PendingRpcs, channel: Channel,
+                      method: Method):
+        """Returns an object that invokes a method using the given channel."""
 
     @abc.abstractmethod
-    def invoke_server_streaming(self, rpc: PendingRpc, request) -> Any:
-        """Invokes a server streaming RPC."""
+    def process_response(self, rpcs: PendingRpcs, rpc: PendingRpc, context,
+                         status: Optional[Status], payload) -> None:
+        """Processes a response from the RPC server.
 
-    @abc.abstractmethod
-    def invoke_client_streaming(self, rpc: PendingRpc) -> Any:
-        """Invokes a client streaming streaming RPC."""
-
-    @abc.abstractmethod
-    def invoke_bidirectional_streaming(self, rpc: PendingRpc) -> Any:
-        """Invokes a bidirectional streaming streaming RPC."""
-
-    @abc.abstractmethod
-    def process_response(self, rpc: PendingRpc, payload,
-                         status: Status) -> None:
-        """Processes a response from the RPC server."""
-
-
-class SimpleSynchronousClient(ClientImpl):
-    """A client that blocks until a response is received for unary RPCs."""
-    def __init__(self):
-        self._responses: Dict[PendingRpc,
-                              SimpleQueue] = defaultdict(SimpleQueue)
-        self._pending: Dict[PendingRpc, bool] = defaultdict(bool)
-
-    def invoke_unary(self, rpc: PendingRpc, request: packets.Message):
-        queue = self._responses[rpc]
-
-        assert not self._pending[rpc], f'{rpc} is already pending!'
-        self._pending[rpc] = True
-
-        try:
-            rpc.channel.output(packets.encode(rpc, request))
-            result = queue.get()
-        finally:
-            self._pending[rpc] = False
-        return result
-
-    def invoke_server_streaming(self, rpc: PendingRpc, request):
-        raise NotImplementedError
-
-    def invoke_client_streaming(self, rpc: PendingRpc):
-        raise NotImplementedError
-
-    def invoke_bidirectional_streaming(self, rpc: PendingRpc):
-        raise NotImplementedError
-
-    def process_response(self, rpc: PendingRpc, payload,
-                         status: Status) -> None:
-        if not self._pending[rpc]:
-            _LOG.warning('Discarding packet for %s', rpc)
-            return
-
-        self._responses[rpc].put((status, payload))
-
-
-class _MethodClient:
-    """A method that can be invoked for a particular channel."""
-    @classmethod
-    def create(cls, client_impl: ClientImpl, channel: Channel, method: Method):
-        """Instantiates a _MethodClient according to the RPC type."""
-        if method.type is Method.Type.UNARY:
-            return UnaryMethodClient(client_impl, channel, method)
-
-        raise NotImplementedError('Streaming methods are not yet supported')
-
-    def __init__(self, client_impl: ClientImpl, channel: Channel,
-                 method: Method):
-        self._client_impl = client_impl
-        self.channel = channel
-        self.method = method
-
-    def _get_request(self, proto: packets.Message,
-                     kwargs: dict) -> packets.Message:
-        if proto and kwargs:
-            raise TypeError(
-                'Requests must be provided either as a message object or a '
-                'series of keyword args, but both were provided')
-
-        if proto is None:
-            return self.method.request_type(**kwargs)
-
-        if not isinstance(proto, self.method.request_type):
-            try:
-                bad_type = proto.DESCRIPTOR.full_name
-            except AttributeError:
-                bad_type = type(proto).__name__
-
-            raise TypeError(
-                f'Expected a message of type '
-                f'{self.method.request_type.DESCRIPTOR.full_name}, '
-                f'got {bad_type}')
-
-        return proto
-
-
-class UnaryMethodClient(_MethodClient):
-    # TODO(hepler): This function should make _request a positional-only
-    #     argument, to avoid confusion with keyword-specified protobuf fields.
-    #     However, yapf does not yet support Python 3.8's grammar, and
-    #     positional-only arguments crash it.
-    def __call__(self, _request=None, **request_fields):
-        """Invokes this unary method using its associated channel.
-
-        The request can be provided as either a message object or as keyword
-        arguments for the message's fields (but not both).
+        Args:
+          status: If set, this is the last packet for this RPC. None otherwise.
+          payload: A protobuf message, if present. None otherwise.
         """
-        return self._client_impl.invoke_unary(
-            PendingRpc(self.channel, self.method.service, self.method),
-            self._get_request(_request, request_fields))
 
 
-class _MethodClients(descriptors.ServiceAccessor[_MethodClient]):
+class _MethodClients(descriptors.ServiceAccessor):
     """Navigates the methods in a service provided by a ChannelClient."""
-    def __init__(self, client_impl: ClientImpl, channel: Channel,
-                 methods: Collection[Method]):
+    def __init__(self, rpcs: PendingRpcs, client_impl: ClientImpl,
+                 channel: Channel, methods: Collection[Method]):
         super().__init__(
             {
-                method.name: _MethodClient.create(client_impl, channel, method)
+                method.name: client_impl.method_client(rpcs, channel, method)
                 for method in methods
             },
             as_attrs=True)
@@ -164,16 +120,41 @@
 
 class _ServiceClients(descriptors.ServiceAccessor[_MethodClients]):
     """Navigates the services provided by a ChannelClient."""
-    def __init__(self, client_impl, channel: Channel,
+    def __init__(self, rpcs: PendingRpcs, client_impl, channel: Channel,
                  services: Collection[Service]):
         super().__init__(
             {
-                s.name: _MethodClients(client_impl, channel, s.methods)
+                s.name: _MethodClients(rpcs, client_impl, channel, s.methods)
                 for s in services
             },
             as_attrs=True)
 
 
+def _decode_status(rpc: PendingRpc, packet) -> Optional[Status]:
+    # STREAM_END and non-streaming RPC packets have a status.
+    if (packet.type is packets.PacketType.STREAM_END
+            or (packet.type is packets.PacketType.RPC
+                and not rpc.method.server_streaming)):
+        try:
+            return Status(packet.status)
+        except ValueError:
+            _LOG.warning('Illegal status code %d for %s', packet.status, rpc)
+            return None
+
+    return None
+
+
+def _decode_payload(rpc: PendingRpc, packet):
+    if packet.type is packets.PacketType.RPC:
+        try:
+            return packets.decode_payload(packet, rpc.method.response_type)
+        except packets.DecodeError as err:
+            _LOG.warning('Failed to decode %s response for %s: %s',
+                         rpc.method.response_type.DESCRIPTOR.full_name,
+                         rpc.method.full_name, err)
+    return None
+
+
 @dataclass(frozen=True, eq=False)
 class ChannelClient:
     """RPC services and methods bound to a particular channel.
@@ -208,12 +189,16 @@
 
     def __init__(self, impl: ClientImpl, channels: Iterable[Channel],
                  services: Iterable[Service]):
-        self.services = descriptors.Services(services)
         self._impl = impl
+        self.services = descriptors.Services(services)
+
+        self._rpcs = PendingRpcs()
+
         self._channels_by_id = {
-            channel.id:
-            ChannelClient(channel,
-                          _ServiceClients(self._impl, channel, self.services))
+            channel.id: ChannelClient(
+                channel,
+                _ServiceClients(self._rpcs, self._impl, channel,
+                                self.services))
             for channel in channels
         }
 
@@ -238,23 +223,24 @@
             return False
 
         try:
-            rpc = self._lookup_packet(packet)
+            rpc = self._lookup_rpc(packet)
         except ValueError as err:
             _LOG.warning('Unable to process packet: %s', err)
             return False
 
-        try:
-            response = packets.decode_payload(packet, rpc.method.response_type)
-        except packets.DecodeError as err:
-            response = None
-            _LOG.warning('Failed to decode %s response for %s: %s',
-                         rpc.method.response_type.DESCRIPTOR.full_name,
-                         rpc.method.full_name, err)
+        status = _decode_status(rpc, packet)
+        payload = _decode_payload(rpc, packet)
 
-        self._impl.process_response(rpc, response, Status(packet.status))
+        try:
+            context = self._rpcs.get_pending(rpc, status)
+        except KeyError:
+            _LOG.debug('Discarding response for %s, which is not pending', rpc)
+            return True  # Handled packet, even though it was invalid
+
+        self._impl.process_response(self._rpcs, rpc, context, status, payload)
         return True
 
-    def _lookup_packet(self, packet: packets.RpcPacket) -> PendingRpc:
+    def _lookup_rpc(self, packet: packets.RpcPacket) -> PendingRpc:
         try:
             channel_client = self._channels_by_id[packet.channel_id]
         except KeyError:
diff --git a/pw_rpc/py/pw_rpc/descriptors.py b/pw_rpc/py/pw_rpc/descriptors.py
index 1299e0f..95b2d36 100644
--- a/pw_rpc/py/pw_rpc/descriptors.py
+++ b/pw_rpc/py/pw_rpc/descriptors.py
@@ -15,9 +15,10 @@
 
 from dataclasses import dataclass
 import enum
-from typing import Any, Callable, Collection, Iterable, Iterator, NamedTuple
+from typing import Any, Callable, Collection, Dict, Iterable, Iterator, Tuple
 from typing import TypeVar, Union
 
+from google.protobuf import descriptor_pb2
 from pw_rpc import ids
 
 
@@ -52,6 +53,19 @@
         return f'Service({self.name!r})'
 
 
+def _streaming_attributes(method) -> Tuple[bool, bool]:
+    # TODO(hepler): Investigate adding server_streaming and client_streaming
+    #     attributes to the generated protobuf code. As a workaround,
+    #     deserialize the FileDescriptorProto to get that information.
+    service = method.containing_service
+
+    file_pb = descriptor_pb2.FileDescriptorProto()
+    file_pb.MergeFromString(service.file.serialized_pb)
+
+    method_pb = file_pb.service[service.index].method[method.index]  # pylint: disable=no-member
+    return method_pb.server_streaming, method_pb.client_streaming
+
+
 @dataclass(frozen=True, eq=False)
 class Method:
     """Describes a method in a service."""
@@ -59,49 +73,77 @@
     service: Service
     name: str
     id: int
-    type: 'Method.Type'
+    server_streaming: bool
+    client_streaming: bool
     request_type: Any
     response_type: Any
 
-    @property
-    def full_name(self) -> str:
-        return f'{self.service.name}.{self.name}'
-
-    class Type(enum.Enum):
-        UNARY = 0
-        SERVER_STREAMING = 1
-        CLIENT_STREAMING = 2
-        BIDI_STREAMING = 3
-
-        @classmethod
-        def from_descriptor(cls, unused_descriptor) -> 'Method.Type':
-            # TODO(hepler): Add server_streaming and client_streaming to
-            #     protobuf generated code, or access these attributes by
-            #     deserializing the FileDescriptor.
-            return cls.UNARY
-
     @classmethod
     def from_descriptor(cls, module, descriptor, service: Service):
         return Method(
             service,
             descriptor.name,
             ids.calculate(descriptor.name),
-            cls.Type.from_descriptor(descriptor),
+            *_streaming_attributes(descriptor),
             getattr(module, descriptor.input_type.name),
             getattr(module, descriptor.output_type.name),
         )
 
+    class Type(enum.Enum):
+        UNARY = 0
+        SERVER_STREAMING = 1
+        CLIENT_STREAMING = 2
+        BIDI_STREAMING = 3
+
+    @property
+    def full_name(self) -> str:
+        return f'{self.service.name}.{self.name}'
+
+    @property
+    def type(self) -> 'Method.Type':
+        if self.server_streaming and self.client_streaming:
+            return self.Type.BIDI_STREAMING
+
+        if self.server_streaming:
+            return self.Type.SERVER_STREAMING
+
+        if self.client_streaming:
+            return self.Type.CLIENT_STREAMING
+
+        return self.Type.UNARY
+
+    def get_request(self, proto, proto_kwargs: Dict[str, Any]):
+        """Returns a request_type protobuf message.
+
+        The client implementation may use this to support providing a request
+        as either a message object or as keyword arguments for the message's
+        fields (but not both).
+        """
+        if proto and proto_kwargs:
+            raise TypeError(
+                'Requests must be provided either as a message object or a '
+                'series of keyword args, but both were provided '
+                f'({proto!r} and {proto_kwargs!r})')
+
+        if proto is None:
+            return self.request_type(**proto_kwargs)
+
+        if not isinstance(proto, self.request_type):
+            try:
+                bad_type = proto.DESCRIPTOR.full_name
+            except AttributeError:
+                bad_type = type(proto).__name__
+
+            raise TypeError(f'Expected a message of type '
+                            f'{self.request_type.DESCRIPTOR.full_name}, '
+                            f'got {bad_type}')
+
+        return proto
+
     def __repr__(self) -> str:
         return f'Method({self.name!r})'
 
 
-class PendingRpc(NamedTuple):
-    """Uniquely identifies an RPC call."""
-    channel: Channel
-    service: Service
-    method: Method
-
-
 T = TypeVar('T')
 
 
diff --git a/pw_rpc/py/pw_rpc/packets.py b/pw_rpc/py/pw_rpc/packets.py
index 152d837..1b1c6f5 100644
--- a/pw_rpc/py/pw_rpc/packets.py
+++ b/pw_rpc/py/pw_rpc/packets.py
@@ -21,6 +21,7 @@
 packet_pb2 = python_protos.compile_and_import_file(
     os.path.join(__file__, '..', '..', '..', 'pw_rpc_protos', 'packet.proto'))
 
+PacketType = packet_pb2.PacketType
 RpcPacket = packet_pb2.RpcPacket
 
 DecodeError = message.DecodeError
@@ -39,12 +40,23 @@
     return payload
 
 
-def encode(rpc: tuple, request: message.Message) -> bytes:
-    channel, service, method = rpc
+def _ids(rpc: tuple) -> tuple:
+    return tuple(item if isinstance(item, int) else item.id for item in rpc)
 
-    return packet_pb2.RpcPacket(
-        type=packet_pb2.PacketType.RPC,
-        channel_id=channel if isinstance(channel, int) else channel.id,
-        service_id=service if isinstance(service, int) else service.id,
-        method_id=method if isinstance(method, int) else method.id,
-        payload=request.SerializeToString()).SerializeToString()
+
+def encode_request(rpc: tuple, request: message.Message) -> bytes:
+    channel, service, method = _ids(rpc)
+
+    return RpcPacket(type=PacketType.RPC,
+                     channel_id=channel,
+                     service_id=service,
+                     method_id=method,
+                     payload=request.SerializeToString()).SerializeToString()
+
+
+def encode_cancel(rpc: tuple) -> bytes:
+    channel, service, method = _ids(rpc)
+    return RpcPacket(type=PacketType.CANCEL,
+                     channel_id=channel,
+                     service_id=service,
+                     method_id=method).SerializeToString()
diff --git a/pw_rpc/py/setup.py b/pw_rpc/py/setup.py
index 44779c9..0d792bb 100644
--- a/pw_rpc/py/setup.py
+++ b/pw_rpc/py/setup.py
@@ -26,5 +26,6 @@
     install_requires=[
         'protobuf',
         'pw_protobuf',
+        'pw_protobuf_compiler',
     ],
 )