#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests using the callback client for pw_rpc."""

import unittest
from unittest import mock
from typing import List, Optional, Tuple

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, packets
from pw_rpc.internal import packet_pb2

# Proto source for the messages and the PublicService RPCs used by the tests.
# PublicService declares one method of each RPC kind: unary, server streaming,
# client streaming and bidirectional streaming.
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""


def _rpc(method_stub):
    """Builds the client.PendingRpc for a method stub.

    The PendingRpc is constructed from the stub's channel, the service that
    owns the stub's method, and the method itself.
    """
    method = method_stub.method
    return client.PendingRpc(method_stub.channel, method.service, method)


class CallbackClientImplTest(unittest.TestCase):
    """Tests the callback_client as used within a pw_rpc Client.

    The client under test has one channel (ID 1) whose output function is
    _handle_request. That function records the outgoing packet and then feeds
    every packet queued by the _enqueue_* helpers back into the client, so
    tests typically queue the server's packets first and invoke the RPC
    afterwards.
    """
    def setUp(self) -> None:
        """Creates the proto library, the client, and the fake transport."""
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        self._client = client.Client.from_modules(
            callback_client.Impl(), [client.Channel(1, self._handle_request)],
            self._protos.modules())
        self._service = self._client.channel(1).rpcs.pw.test1.PublicService

        # Most recent packet sent by the client; None until one is sent.
        self._last_request: Optional[packet_pb2.RpcPacket] = None
        # (serialized packet, expected process_packet() status) pairs that are
        # delivered the next time the client sends a packet.
        self._next_packets: List[Tuple[bytes, Status]] = []
        # Cleared while queued packets are delivered; see _handle_request.
        self._send_responses_on_request = True

    def _enqueue_response(self,
                          channel_id: int,
                          method=None,
                          status: Status = Status.OK,
                          response=b'',
                          *,
                          ids: Optional[Tuple[int, int]] = None,
                          process_status: Status = Status.OK) -> None:
        """Queues a RESPONSE packet for delivery on the next sent request.

        Args:
          channel_id: channel ID to put in the packet
          method: method whose service and method IDs go in the packet;
              mutually exclusive with ids
          status: RPC status to put in the packet
          response: payload, either raw bytes (used as-is) or a message object
              that is serialized with SerializeToString()
          ids: explicit (service_id, method_id) pair, for packets that refer to
              IDs with no corresponding method object; required if method is
              not provided
          process_status: Status that Client.process_packet is expected to
              return when this packet is delivered
        """
        if method:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None and method is None
            service_id, method_id = ids

        if isinstance(response, bytes):
            payload = response
        else:
            payload = response.SerializeToString()

        self._next_packets.append(
            (packet_pb2.RpcPacket(type=packet_pb2.PacketType.RESPONSE,
                                  channel_id=channel_id,
                                  service_id=service_id,
                                  method_id=method_id,
                                  status=status.value,
                                  payload=payload).SerializeToString(),
             process_status))

    def _enqueue_status_packet(self, packet_type, channel_id: int, method,
                               status: Status,
                               process_status: Status) -> None:
        """Queues a packet of packet_type that carries a status, no payload."""
        self._next_packets.append(
            (packet_pb2.RpcPacket(type=packet_type,
                                  channel_id=channel_id,
                                  service_id=method.service.id,
                                  method_id=method.id,
                                  status=status.value).SerializeToString(),
             process_status))

    def _enqueue_stream_end(self,
                            channel_id: int,
                            method,
                            status: Status = Status.OK,
                            process_status: Status = Status.OK) -> None:
        """Queues a SERVER_STREAM_END packet for method with status."""
        self._enqueue_status_packet(packet_pb2.PacketType.SERVER_STREAM_END,
                                    channel_id, method, status,
                                    process_status)

    def _enqueue_error(self,
                       channel_id: int,
                       method,
                       status: Status,
                       process_status: Status = Status.OK) -> None:
        """Queues a SERVER_ERROR packet for method with status."""
        self._enqueue_status_packet(packet_pb2.PacketType.SERVER_ERROR,
                                    channel_id, method, status,
                                    process_status)

    def _handle_request(self, data: bytes) -> None:
        """Channel output: records the request, then delivers queued packets.

        Each queued packet is passed to Client.process_packet and the returned
        status is checked against the status the packet was queued with.
        """
        # Disable this method to prevent infinite recursion if processing the
        # packet happens to send another packet.
        if not self._send_responses_on_request:
            return

        self._send_responses_on_request = False

        try:
            self._last_request = packets.decode(data)

            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            # Reset the queue and the guard even if decoding or an assertion
            # failed, so this helper is never left disabled.
            self._next_packets.clear()
            self._send_responses_on_request = True

    def _last_sent_request(self) -> packet_pb2.RpcPacket:
        """Returns the last packet the client sent; fails if there is none."""
        self.assertIsNotNone(self._last_request)
        assert self._last_request is not None  # Narrows the Optional type.
        return self._last_request

    def _sent_payload(self, message_type):
        """Parses the last sent packet's payload as a message_type message."""
        message = message_type()
        message.ParseFromString(self._last_sent_request().payload)
        return message

    def test_invoke_unary_rpc(self) -> None:
        method = self._service.SomeUnary.method

        # Repeat to check that the RPC can be invoked again after completing.
        for _ in range(3):
            self._enqueue_response(1, method, Status.ABORTED,
                                   method.response_type(payload='0_o'))

            status, response = self._service.SomeUnary(
                method.request_type(magic_number=6))

            self.assertEqual(
                6,
                self._sent_payload(method.request_type).magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_invoke_unary_rpc_keep_open(self) -> None:
        method = self._service.SomeUnary.method

        payload_1 = method.response_type(payload='-_-')
        payload_2 = method.response_type(payload='0_o')

        self._enqueue_response(1, method, Status.ABORTED, payload_1)

        replies: list = []

        # Used as both the response and the completion callback; collects the
        # second argument of every call in order.
        def enqueue_replies(_, reply) -> None:
            replies.append(reply)

        self._service.SomeUnary.invoke(method.request_type(magic_number=6),
                                       enqueue_replies,
                                       enqueue_replies,
                                       keep_open=True)

        self.assertEqual([payload_1, Status.ABORTED], replies)

        # Send another packet and make sure it is processed even though the RPC
        # terminated.
        self._client.process_packet(
            packet_pb2.RpcPacket(
                type=packet_pb2.PacketType.RESPONSE,
                channel_id=1,
                service_id=method.service.id,
                method_id=method.id,
                status=Status.OK.value,
                payload=payload_2.SerializeToString()).SerializeToString())

        self.assertEqual([payload_1, Status.ABORTED, payload_2, Status.OK],
                         replies)

    def test_invoke_unary_rpc_with_callback(self) -> None:
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_response(1, method, Status.ABORTED,
                                   method.response_type(payload='0_o'))

            callback = mock.Mock()
            self._service.SomeUnary.invoke(self._request(magic_number=5),
                                           callback, callback)

            # The same mock receives the response and then the final status.
            callback.assert_has_calls([
                mock.call(_rpc(self._service.SomeUnary),
                          method.response_type(payload='0_o')),
                mock.call(_rpc(self._service.SomeUnary), Status.ABORTED)
            ])

            self.assertEqual(
                5,
                self._sent_payload(method.request_type).magic_number)

    def test_unary_rpc_server_error(self) -> None:
        method = self._service.SomeUnary.method

        for _ in range(3):
            self._enqueue_error(1, method, Status.NOT_FOUND)

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(method.request_type(magic_number=6))

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_invoke_unary_rpc_callback_exceptions_suppressed(self) -> None:
        stub = self._service.SomeUnary

        self._enqueue_response(1, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        # The callback's exception must be logged rather than propagated.
        with self.assertLogs(callback_client.__name__, 'ERROR') as logs:
            stub.invoke(self._request(),
                        mock.Mock(side_effect=Exception(exception_msg)))

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(1, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_invoke_unary_rpc_with_callback_cancel(self) -> None:
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback)

            self.assertIsNotNone(self._last_request)
            self._last_request = None

            # Try to invoke the RPC again before cancelling, without overriding
            # pending RPCs.
            with self.assertRaises(client.Error):
                self._service.SomeUnary.invoke(self._request(magic_number=56),
                                               callback,
                                               override_pending=False)

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            # Unary RPCs do not send a cancel request to the server.
            self.assertIsNone(self._last_request)

        callback.assert_not_called()

    def test_reinvoke_unary_rpc(self) -> None:
        for _ in range(3):
            self._last_request = None
            self._service.SomeUnary.invoke(self._request(magic_number=55),
                                           override_pending=True)
            self.assertEqual(self._last_sent_request().type,
                             packet_pb2.PacketType.REQUEST)

    def test_invoke_server_streaming(self) -> None:
        method = self._service.SomeServerStreaming.method

        rep1 = method.response_type(payload='!!!')
        rep2 = method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_response(1, method, response=rep1)
            self._enqueue_response(1, method, response=rep2)
            self._enqueue_stream_end(1, method, Status.ABORTED)

            self.assertEqual(
                [rep1, rep2],
                list(self._service.SomeServerStreaming(magic_number=4)))

            self.assertEqual(
                4,
                self._sent_payload(method.request_type).magic_number)

    def test_invoke_server_streaming_with_callbacks(self) -> None:
        method = self._service.SomeServerStreaming.method

        rep1 = method.response_type(payload='!!!')
        rep2 = method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_response(1, method, response=rep1)
            self._enqueue_response(1, method, response=rep2)
            self._enqueue_stream_end(1, method, Status.ABORTED)

            callback = mock.Mock()
            self._service.SomeServerStreaming.invoke(
                self._request(magic_number=3), callback, callback)

            rpc = _rpc(self._service.SomeServerStreaming)
            callback.assert_has_calls([
                mock.call(rpc, method.response_type(payload='!!!')),
                mock.call(rpc, method.response_type(payload='?')),
                mock.call(rpc, Status.ABORTED),
            ])

            self.assertEqual(
                3,
                self._sent_payload(method.request_type).magic_number)

    def test_invoke_server_streaming_with_callback_cancel(self) -> None:
        stub = self._service.SomeServerStreaming

        resp = stub.method.response_type(payload='!!!')
        self._enqueue_response(1, stub.method, response=resp)

        callback = mock.Mock()
        call = stub.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            _rpc(stub), stub.method.response_type(payload='!!!'))

        callback.reset_mock()

        call.cancel()

        self.assertEqual(self._last_sent_request().type,
                         packet_pb2.PacketType.CANCEL)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_response(1, stub.method, response=resp)
        self._enqueue_stream_end(1, stub.method, Status.OK)

        call = stub.invoke(self._request(magic_number=3), callback, callback)

        callback.assert_has_calls([
            mock.call(_rpc(stub), stub.method.response_type(payload='!!!')),
            mock.call(_rpc(stub), Status.OK),
        ])

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(1,
                               ids=(999, method.id),
                               process_status=Status.OK)
        # Bad method
        self._enqueue_response(1,
                               ids=(service_id, 999),
                               process_status=Status.OK)
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(1,
                               ids=(service_id,
                                    self._service.SomeBidiStreaming.method.id),
                               process_status=Status.OK)

        # The one valid response, which completes the pending unary call.
        self._enqueue_response(1, method, process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_pass_none_if_payload_fails_to_decode(self) -> None:
        method = self._service.SomeUnary.method

        self._enqueue_response(1,
                               method,
                               Status.OK,
                               b'INVALID DATA!!!',
                               process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(status, Status.OK)
        self.assertIsNone(response)

    def test_rpc_help_contains_method_name(self) -> None:
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        # The channel output discards everything; no packets need to be sent.
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(1, lambda *a, **b: None)], self._protos.modules())
        rpcs = rpc_client.channel(1).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100)

    def test_timeout_unary(self) -> None:
        # No response is queued, so the call can only end by timing out.
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_timeout_unary_set_default(self) -> None:
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_timeout_server_streaming_iteration(self) -> None:
        responses = self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in responses:
                pass

    def test_timeout_server_streaming_responses(self) -> None:
        responses = self._service.SomeServerStreaming()
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in responses.responses(timeout_s=0.0001):
                pass


if __name__ == '__main__':
    # Run the tests in this module when the file is executed as a script.
    unittest.main()