pw_rpc: Update for protocol changes
- Update to handle STREAM_END and CANCEL packet semantics.
- Detect and handle packet decoding errors.
- Send ERROR packets when a malformed or unexpected packet is received.
- Switch from uint32 to fixed32 for service and method IDs. fixed32
stores evenly distributed uint32s (like hashes) more efficiently than
uint32.
- Use a void return type for server streaming RPCs, since the status is
sent from the ServerWriter.
- Move response sending from the Server to the Method. Previously, the
server always sent a response. However, only unary RPCs should send an
immediate response. This structure also gives the Method class more
control over the semantics for RPCs.
Change-Id: Ie88e4a24baf4e039463adf411f7dad7a0c64a710
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/13960
Commit-Queue: Wyatt Hepler <hepler@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
diff --git a/pw_protobuf/py/pw_protobuf/proto_tree.py b/pw_protobuf/py/pw_protobuf/proto_tree.py
index 9313689..e2a2c8a 100644
--- a/pw_protobuf/py/pw_protobuf/proto_tree.py
+++ b/pw_protobuf/py/pw_protobuf/proto_tree.py
@@ -337,6 +337,14 @@
def type(self) -> Type:
return self._type
+ def server_streaming(self) -> bool:
+ return (self._type is self.Type.SERVER_STREAMING
+ or self._type is self.Type.BIDIRECTIONAL_STREAMING)
+
+ def client_streaming(self) -> bool:
+ return (self._type is self.Type.CLIENT_STREAMING
+ or self._type is self.Type.BIDIRECTIONAL_STREAMING)
+
def request_type(self) -> ProtoNode:
return self._request_type
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn
index d57a898..13b88cb 100644
--- a/pw_rpc/BUILD.gn
+++ b/pw_rpc/BUILD.gn
@@ -71,6 +71,7 @@
":$_target_name",
":common",
dir_pw_span,
+ dir_pw_unit_test,
]
visibility = [ "./*" ]
}
diff --git a/pw_rpc/base_server_writer.cc b/pw_rpc/base_server_writer.cc
index 67abc6f..6ddd136 100644
--- a/pw_rpc/base_server_writer.cc
+++ b/pw_rpc/base_server_writer.cc
@@ -45,7 +45,7 @@
uint32_t BaseServerWriter::method_id() const { return call_.method().id(); }
-void BaseServerWriter::Finish() {
+void BaseServerWriter::Finish(Status status) {
if (!open()) {
return;
}
@@ -53,13 +53,13 @@
call_.server().RemoveWriter(*this);
state_ = kClosed;
- // Send a control packet indicating that the stream has terminated.
- auto response = call_.channel().AcquireBuffer();
- call_.channel().Send(response,
- Packet(PacketType::CANCEL,
+ // Send a control packet indicating that the stream (and RPC) has terminated.
+ call_.channel().Send(Packet(PacketType::STREAM_END,
call_.channel().id(),
call_.service().id(),
- method().id()));
+ method().id(),
+ {},
+ status));
}
std::span<std::byte> BaseServerWriter::AcquirePayloadBuffer() {
diff --git a/pw_rpc/base_server_writer_test.cc b/pw_rpc/base_server_writer_test.cc
index 04d37db..ce9e806 100644
--- a/pw_rpc/base_server_writer_test.cc
+++ b/pw_rpc/base_server_writer_test.cc
@@ -112,8 +112,8 @@
writer.Finish();
- Packet packet = Packet::FromBuffer(context.output().sent_packet());
- EXPECT_EQ(packet.type(), PacketType::CANCEL);
+ const Packet& packet = context.output().sent_packet();
+ EXPECT_EQ(packet.type(), PacketType::STREAM_END);
EXPECT_EQ(packet.channel_id(), context.kChannelId);
EXPECT_EQ(packet.service_id(), context.kServiceId);
EXPECT_EQ(packet.method_id(), context.get().method().id());
@@ -141,10 +141,9 @@
auto sws = context.packet(data).Encode(encoded);
ASSERT_EQ(Status::OK, sws.status());
- EXPECT_EQ(sws.size(), context.output().sent_packet().size());
+ EXPECT_EQ(sws.size(), context.output().sent_data().size());
EXPECT_EQ(
- 0,
- std::memcmp(encoded, context.output().sent_packet().data(), sws.size()));
+ 0, std::memcmp(encoded, context.output().sent_data().data(), sws.size()));
}
TEST(ServerWriter, Closed_IgnoresPacket) {
diff --git a/pw_rpc/channel_test.cc b/pw_rpc/channel_test.cc
index b63535b..c2938dc 100644
--- a/pw_rpc/channel_test.cc
+++ b/pw_rpc/channel_test.cc
@@ -36,8 +36,8 @@
}
constexpr Packet kTestPacket(PacketType::RPC, 1, 42, 100);
-const size_t kReservedSize = 2 /* type */ + 2 /* channel */ + 2 /* service */ +
- 2 /* method */ + 2 /* payload key */ +
+const size_t kReservedSize = 2 /* type */ + 2 /* channel */ + 5 /* service */ +
+ 5 /* method */ + 2 /* payload key */ +
2 /* status */;
TEST(Channel, TestPacket_ReservedSizeMatchesMinEncodedSizeBytes) {
@@ -79,13 +79,11 @@
TestOutput<kReservedSize> output;
internal::Channel channel(100, &output);
- Channel::OutputBuffer output_buffer(channel.AcquireBuffer());
-
Packet packet = kTestPacket;
byte data[1] = {};
packet.set_payload(data);
- EXPECT_EQ(Status::INTERNAL, channel.Send(output_buffer, packet));
+ EXPECT_EQ(Status::INTERNAL, channel.Send(packet));
}
TEST(Channel, OutputBuffer_ExtraRoom) {
diff --git a/pw_rpc/docs.rst b/pw_rpc/docs.rst
index aee90bb..81efb4f 100644
--- a/pw_rpc/docs.rst
+++ b/pw_rpc/docs.rst
@@ -91,6 +91,8 @@
* ``FAILED_PRECONDITION`` -- Attempted to cancel an RPC that is not pending.
* ``RESOURCE_EXHAUSTED`` -- The request came on a new channel, but a channel
could not be allocated for it.
+* ``INTERNAL`` -- The server was unable to respond to an RPC due to an
+ unrecoverable internal error.
Inovking a service method
-------------------------
diff --git a/pw_rpc/nanopb/method.cc b/pw_rpc/nanopb/method.cc
index 3d91070..6c96d0a 100644
--- a/pw_rpc/nanopb/method.cc
+++ b/pw_rpc/nanopb/method.cc
@@ -16,6 +16,8 @@
#include "pb_decode.h"
#include "pb_encode.h"
+#include "pw_log/log.h"
+#include "pw_rpc/internal/packet.h"
namespace pw::rpc::internal {
namespace {
@@ -37,15 +39,6 @@
using std::byte;
-Status Method::DecodeRequest(std::span<const byte> buffer,
- void* proto_struct) const {
- auto input = pb_istream_from_buffer(
- reinterpret_cast<const pb_byte_t*>(buffer.data()), buffer.size());
- return pb_decode(&input, static_cast<Fields>(request_fields_), proto_struct)
- ? Status::OK
- : Status::INTERNAL;
-}
-
StatusWithSize Method::EncodeResponse(const void* proto_struct,
std::span<byte> buffer) const {
auto output = pb_ostream_from_buffer(
@@ -56,37 +49,68 @@
return StatusWithSize::INTERNAL;
}
-StatusWithSize Method::CallUnary(ServerCall& call,
- std::span<const byte> request_buffer,
- std::span<byte> response_buffer,
- void* request_struct,
- void* response_struct) const {
- Status status = DecodeRequest(request_buffer, request_struct);
- if (!status.ok()) {
- return StatusWithSize(status, 0);
+void Method::CallUnary(ServerCall& call,
+ const Packet& request,
+ void* request_struct,
+ void* response_struct) const {
+ if (!DecodeRequest(call.channel(), request, request_struct)) {
+ return;
}
- status = function_.unary(call.context(), request_struct, response_struct);
-
- StatusWithSize encoded = EncodeResponse(response_struct, response_buffer);
- if (encoded.ok()) {
- return StatusWithSize(status, encoded.size());
- }
- return encoded;
+ const Status status =
+ function_.unary(call.context(), request_struct, response_struct);
+ SendResponse(call.channel(), request, response_struct, status);
}
-StatusWithSize Method::CallServerStreaming(ServerCall& call,
- std::span<const byte> request_buffer,
- void* request_struct) const {
- Status status = DecodeRequest(request_buffer, request_struct);
- if (!status.ok()) {
- return StatusWithSize(status, 0);
+void Method::CallServerStreaming(ServerCall& call,
+ const Packet& request,
+ void* request_struct) const {
+ if (!DecodeRequest(call.channel(), request, request_struct)) {
+ return;
}
internal::BaseServerWriter server_writer(call);
- return StatusWithSize(
- function_.server_streaming(call.context(), request_struct, server_writer),
- 0);
+ function_.server_streaming(call.context(), request_struct, server_writer);
+}
+
+bool Method::DecodeRequest(Channel& channel,
+ const Packet& request,
+ void* proto_struct) const {
+ auto input = pb_istream_from_buffer(
+ reinterpret_cast<const pb_byte_t*>(request.payload().data()),
+ request.payload().size());
+ if (pb_decode(&input, static_cast<Fields>(request_fields_), proto_struct)) {
+ return true;
+ }
+
+ PW_LOG_WARN("Failed to decode request payload from channel %u",
+ unsigned(channel.id()));
+ channel.Send(Packet::Error(request, Status::DATA_LOSS));
+ return false;
+}
+
+void Method::SendResponse(Channel& channel,
+ const Packet& request,
+ const void* response_struct,
+ Status status) const {
+ Channel::OutputBuffer response_buffer = channel.AcquireBuffer();
+ std::span payload_buffer = response_buffer.payload(request);
+
+ StatusWithSize encoded = EncodeResponse(response_struct, payload_buffer);
+
+ if (encoded.ok()) {
+ Packet response = Packet::Response(request);
+
+ response.set_payload(payload_buffer.first(encoded.size()));
+ response.set_status(status);
+ if (channel.Send(response_buffer, response).ok()) {
+ return;
+ }
+ }
+
+ PW_LOG_WARN("Failed to encode response packet for channel %u",
+ unsigned(channel.id()));
+ channel.Send(response_buffer, Packet::Error(request, Status::INTERNAL));
}
} // namespace pw::rpc::internal
diff --git a/pw_rpc/nanopb/method_test.cc b/pw_rpc/nanopb/method_test.cc
index 80fd216..b90406f 100644
--- a/pw_rpc/nanopb/method_test.cc
+++ b/pw_rpc/nanopb/method_test.cc
@@ -59,9 +59,9 @@
const pw_rpc_test_TestRequest&,
pw_rpc_test_TestResponse&);
- static Status StartStream(ServerContext&,
- const pw_rpc_test_TestRequest&,
- ServerWriter<pw_rpc_test_TestResponse>&);
+ static void StartStream(ServerContext&,
+ const pw_rpc_test_TestRequest&,
+ ServerWriter<pw_rpc_test_TestResponse>&);
static constexpr std::array<Method, 3> kMethods = {
Method::Unary<DoNothing>(
@@ -87,78 +87,91 @@
Status FakeGeneratedService::DoNothing(ServerContext&,
const pw_rpc_test_Empty&,
pw_rpc_test_Empty&) {
- return Status::NOT_FOUND;
+ return Status::UNKNOWN;
}
-Status FakeGeneratedService::StartStream(
+void FakeGeneratedService::StartStream(
ServerContext&,
const pw_rpc_test_TestRequest& request,
ServerWriter<pw_rpc_test_TestResponse>& writer) {
last_request = request;
last_writer = std::move(writer);
- return Status::UNAVAILABLE;
-}
-
-TEST(Method, UnaryRpc_DoesNothing) {
- ENCODE_PB(pw_rpc_test_Empty, {}, request);
- byte response[128] = {};
-
- const Method& method = std::get<0>(FakeGeneratedService::kMethods);
- ServerContextForTest<FakeGeneratedService> context(method);
- StatusWithSize result = method.Invoke(context.get(), request, response);
- EXPECT_EQ(Status::NOT_FOUND, result.status());
}
TEST(Method, UnaryRpc_SendsResponse) {
ENCODE_PB(pw_rpc_test_TestRequest, {.integer = 123}, request);
- byte response[128] = {};
const Method& method = std::get<1>(FakeGeneratedService::kMethods);
ServerContextForTest<FakeGeneratedService> context(method);
- StatusWithSize result = method.Invoke(context.get(), request, response);
- EXPECT_EQ(Status::UNAUTHENTICATED, result.status());
+ method.Invoke(context.get(), context.packet(request));
+
+ const Packet& response = context.output().sent_packet();
+ EXPECT_EQ(response.status(), Status::UNAUTHENTICATED);
// Field 1 (encoded as 1 << 3) with 128 as the value.
constexpr std::byte expected[]{
std::byte{0x08}, std::byte{0x80}, std::byte{0x01}};
- EXPECT_EQ(sizeof(expected), result.size());
- EXPECT_EQ(0, std::memcmp(expected, response, sizeof(expected)));
+ EXPECT_EQ(sizeof(expected), response.payload().size());
+ EXPECT_EQ(0,
+ std::memcmp(expected, response.payload().data(), sizeof(expected)));
EXPECT_EQ(123, last_request.integer);
}
-TEST(Method, UnaryRpc_BufferTooSmallForResponse_InternalError) {
- ENCODE_PB(pw_rpc_test_TestRequest, {.integer = 123}, request);
- byte response[2] = {}; // Too small for the response
+TEST(Method, UnaryRpc_InvalidPayload_SendsError) {
+ std::array<byte, 8> bad_payload{byte{0xFF}, byte{0xAA}, byte{0xDD}};
+
+ const Method& method = std::get<0>(FakeGeneratedService::kMethods);
+ ServerContextForTest<FakeGeneratedService> context(method);
+ method.Invoke(context.get(), context.packet(bad_payload));
+
+ const Packet& packet = context.output().sent_packet();
+ EXPECT_EQ(PacketType::ERROR, packet.type());
+ EXPECT_EQ(Status::DATA_LOSS, packet.status());
+ EXPECT_EQ(context.kServiceId, packet.service_id());
+ EXPECT_EQ(method.id(), packet.method_id());
+}
+
+TEST(Method, UnaryRpc_BufferTooSmallForResponse_SendsInternalError) {
+ constexpr int64_t value = 0x7FFFFFFF'FFFFFF00ll;
+ ENCODE_PB(pw_rpc_test_TestRequest, {.integer = value}, request);
const Method& method = std::get<1>(FakeGeneratedService::kMethods);
- ServerContextForTest<FakeGeneratedService> context(method);
+ // Output buffer is too small for the response, but can fit an error packet.
+ ServerContextForTest<FakeGeneratedService, 22> context(method);
+ ASSERT_LT(context.output().buffer_size(),
+ context.packet(request).MinEncodedSizeBytes() + request.size() + 1);
- StatusWithSize result = method.Invoke(context.get(), request, response);
- EXPECT_EQ(Status::INTERNAL, result.status());
- EXPECT_EQ(0u, result.size());
- EXPECT_EQ(123, last_request.integer);
+ method.Invoke(context.get(), context.packet(request));
+
+ const Packet& packet = context.output().sent_packet();
+ EXPECT_EQ(PacketType::ERROR, packet.type());
+ EXPECT_EQ(Status::INTERNAL, packet.status());
+ EXPECT_EQ(context.kServiceId, packet.service_id());
+ EXPECT_EQ(method.id(), packet.method_id());
+
+ EXPECT_EQ(value, last_request.integer);
}
-TEST(Method, ServerStreamingRpc) {
+TEST(Method, ServerStreamingRpc_SendsNothingWhenInitiallyCalled) {
ENCODE_PB(pw_rpc_test_TestRequest, {.integer = 555}, request);
const Method& method = std::get<2>(FakeGeneratedService::kMethods);
ServerContextForTest<FakeGeneratedService> context(method);
- StatusWithSize result = method.Invoke(context.get(), request, {});
- EXPECT_EQ(Status::UNAVAILABLE, result.status());
- EXPECT_EQ(0u, result.size());
+ method.Invoke(context.get(), context.packet(request));
+ EXPECT_EQ(0u, context.output().packet_count());
EXPECT_EQ(555, last_request.integer);
}
TEST(Method, ServerWriter_SendsResponse) {
const Method& method = std::get<2>(FakeGeneratedService::kMethods);
ServerContextForTest<FakeGeneratedService> context(method);
- ASSERT_EQ(Status::UNAVAILABLE, method.Invoke(context.get(), {}, {}).status());
+
+ method.Invoke(context.get(), context.packet({}));
EXPECT_EQ(Status::OK, last_writer.Write({.value = 100}));
@@ -167,25 +180,31 @@
auto encoded = context.packet(payload).Encode(encoded_response);
ASSERT_EQ(Status::OK, encoded.status());
- ASSERT_EQ(encoded.size(), context.output().sent_packet().size());
+ ASSERT_EQ(encoded.size(), context.output().sent_data().size());
EXPECT_EQ(0,
std::memcmp(encoded_response.data(),
- context.output().sent_packet().data(),
+ context.output().sent_data().data(),
encoded.size()));
}
TEST(Method, ServerStreamingRpc_ServerWriterBufferTooSmall_InternalError) {
const Method& method = std::get<2>(FakeGeneratedService::kMethods);
- // Make the buffer barely fit a packet with no payload.
- ServerContextForTest<FakeGeneratedService, 12> context(method);
- ASSERT_EQ(Status::UNAVAILABLE, method.Invoke(context.get(), {}, {}).status());
+ constexpr size_t kNoPayloadPacketSize = 2 /* type */ + 2 /* channel */ +
+ 5 /* service */ + 5 /* method */ +
+ 2 /* payload */ + 2 /* status */;
- // Verify that the encoded size of a packet with an empty payload is 12.
+ // Make the buffer barely fit a packet with no payload.
+ ServerContextForTest<FakeGeneratedService, kNoPayloadPacketSize> context(
+ method);
+
+ // Verify that the encoded size of a packet with an empty payload is correct.
std::array<byte, 128> encoded_response = {};
auto encoded = context.packet({}).Encode(encoded_response);
ASSERT_EQ(Status::OK, encoded.status());
- ASSERT_EQ(12u, encoded.size());
+ ASSERT_EQ(kNoPayloadPacketSize, encoded.size());
+
+ method.Invoke(context.get(), context.packet({}));
EXPECT_EQ(Status::OK, last_writer.Write({})); // Barely fits
EXPECT_EQ(Status::INTERNAL, last_writer.Write({.value = 1})); // Too big
diff --git a/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h b/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
index 19f1497..0a5d2a2 100644
--- a/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
+++ b/pw_rpc/nanopb/public_overrides/pw_rpc/internal/method.h
@@ -44,6 +44,8 @@
namespace internal {
+class Packet;
+
// Use a void* to cover both Nanopb 3's pb_field_s and Nanopb 4's pb_msgdesc_s.
using NanopbMessageDescriptor = const void*;
@@ -61,7 +63,7 @@
// Specialization for server streaming RPCs.
template <typename RequestType, typename ResponseType>
-struct RpcTraits<Status (*)(
+struct RpcTraits<void (*)(
ServerContext&, const RequestType&, ServerWriter<ResponseType>&)> {
using Request = RequestType;
using Response = ResponseType;
@@ -122,10 +124,9 @@
return Method(
{.server_streaming =
[](ServerContext& ctx, const void* req, BaseServerWriter& resp) {
- return method(
- ctx,
- *static_cast<const Request<method>*>(req),
- static_cast<ServerWriter<Response<method>>&>(resp));
+ method(ctx,
+ *static_cast<const Request<method>*>(req),
+ static_cast<ServerWriter<Response<method>>&>(resp));
}},
ServerStreamingInvoker<AllocateSpaceFor<Request<method>>()>,
id,
@@ -136,16 +137,10 @@
// The pw::rpc::Server calls method.Invoke to call a user-defined RPC. Invoke
// calls the invoker function, which encodes and decodes the request and
// response (if any) and calls the user-defined RPC function.
- StatusWithSize Invoke(ServerCall& call,
- std::span<const std::byte> request,
- std::span<std::byte> payload_buffer) const {
- return invoker_(*this, call, request, payload_buffer);
+ void Invoke(ServerCall& call, const Packet& request) const {
+ return invoker_(*this, call, request);
}
- // Decodes a request protobuf with Nanopb to the provided buffer.
- Status DecodeRequest(std::span<const std::byte> buffer,
- void* proto_struct) const;
-
// Encodes a response protobuf with Nanopb to the provided buffer.
StatusWithSize EncodeResponse(const void* proto_struct,
std::span<std::byte> buffer) const;
@@ -163,9 +158,9 @@
//
// Status(ServerContext&, const Request&, ServerWriter<Response>&)
//
- using ServerStreamingFunction = Status (*)(ServerContext&,
- const void* request,
- BaseServerWriter& writer);
+ using ServerStreamingFunction = void (*)(ServerContext&,
+ const void* request,
+ BaseServerWriter& writer);
// The Function union stores a pointer to a generic version of the
// user-defined RPC function. Using a union instead of void* avoids
@@ -184,12 +179,8 @@
}
// The Invoker allocates request/response structs on the stack and calls the
- // RPC according to its type (unary, server streaming, etc.). The Invoker
- // returns the number of bytes written to the response buffer, if any.
- using Invoker = StatusWithSize (&)(const Method&,
- ServerCall&,
- std::span<const std::byte>,
- std::span<std::byte>);
+ // RPC according to its type (unary, server streaming, etc.).
+ using Invoker = void (&)(const Method&, ServerCall&, const Packet&);
constexpr Method(Function function,
Invoker invoker,
@@ -202,15 +193,14 @@
request_fields_(request),
response_fields_(response) {}
- StatusWithSize CallUnary(ServerCall& call,
- std::span<const std::byte> request_buffer,
- std::span<std::byte> response_buffer,
- void* request_struct,
- void* response_struct) const;
+ void CallUnary(ServerCall& call,
+ const Packet& request,
+ void* request_struct,
+ void* response_struct) const;
- StatusWithSize CallServerStreaming(ServerCall& call,
- std::span<const std::byte> request_buffer,
- void* request_struct) const;
+ void CallServerStreaming(ServerCall& call,
+ const Packet& request,
+ void* request_struct) const;
// TODO(hepler): Add CallClientStreaming and CallBidiStreaming
@@ -218,37 +208,42 @@
// size, with maximum alignment, to avoid generating unnecessary copies of
// this function for each request/response type.
template <size_t request_size, size_t response_size>
- static StatusWithSize UnaryInvoker(const Method& method,
- ServerCall& call,
- std::span<const std::byte> request_buffer,
- std::span<std::byte> response_buffer) {
+ static void UnaryInvoker(const Method& method,
+ ServerCall& call,
+ const Packet& request) {
std::aligned_storage_t<request_size, alignof(std::max_align_t)>
request_struct{};
std::aligned_storage_t<response_size, alignof(std::max_align_t)>
response_struct{};
- return method.CallUnary(call,
- request_buffer,
- response_buffer,
- &request_struct,
- &response_struct);
+ method.CallUnary(call, request, &request_struct, &response_struct);
}
// Invoker function for server streaming RPCs. Allocates space for a request
// struct. Ignores the payload buffer since resposnes are sent through the
// ServerWriter.
template <size_t request_size>
- static StatusWithSize ServerStreamingInvoker(
- const Method& method,
- ServerCall& call,
- std::span<const std::byte> request_buffer,
- std::span<std::byte> /* payload not used */) {
+ static void ServerStreamingInvoker(const Method& method,
+ ServerCall& call,
+ const Packet& request) {
std::aligned_storage_t<request_size, alignof(std::max_align_t)>
request_struct{};
- return method.CallServerStreaming(call, request_buffer, &request_struct);
+ method.CallServerStreaming(call, request, &request_struct);
}
+ // Decodes a request protobuf with Nanopb to the provided buffer. Sends an
+ // error packet if the request failed to decode.
+ bool DecodeRequest(Channel& channel,
+ const Packet& request,
+ void* proto_struct) const;
+
+ // Encodes a response and sends it over the provided channel.
+ void SendResponse(Channel& channel,
+ const Packet& request,
+ const void* response_struct,
+ Status status) const;
+
// Allocates memory for the request/response structs and invokes the
// user-defined RPC based on its type (unary, server streaming, etc.).
Invoker invoker_;
diff --git a/pw_rpc/packet.cc b/pw_rpc/packet.cc
index 9a7e34f..dfbdb50 100644
--- a/pw_rpc/packet.cc
+++ b/pw_rpc/packet.cc
@@ -20,52 +20,51 @@
using std::byte;
-Packet Packet::FromBuffer(std::span<const byte> data) {
- PacketType type = PacketType::RPC;
- uint32_t channel_id = 0;
- uint32_t service_id = 0;
- uint32_t method_id = 0;
- std::span<const byte> payload;
- Status status;
+Status Packet::FromBuffer(std::span<const byte> data, Packet& packet) {
+ packet = Packet();
- uint32_t value;
protobuf::Decoder decoder(data);
- while (decoder.Next().ok()) {
+ Status status;
+
+ while ((status = decoder.Next()).ok()) {
RpcPacket::Fields field =
static_cast<RpcPacket::Fields>(decoder.FieldNumber());
- uint32_t proto_value = 0;
switch (field) {
- case RpcPacket::Fields::TYPE:
- decoder.ReadUint32(&proto_value);
- type = static_cast<PacketType>(proto_value);
+ case RpcPacket::Fields::TYPE: {
+ uint32_t value;
+ decoder.ReadUint32(&value);
+ packet.set_type(static_cast<PacketType>(value));
break;
+ }
case RpcPacket::Fields::CHANNEL_ID:
- decoder.ReadUint32(&channel_id);
+ decoder.ReadUint32(&packet.channel_id_);
break;
case RpcPacket::Fields::SERVICE_ID:
- decoder.ReadUint32(&service_id);
+ decoder.ReadFixed32(&packet.service_id_);
break;
case RpcPacket::Fields::METHOD_ID:
- decoder.ReadUint32(&method_id);
+ decoder.ReadFixed32(&packet.method_id_);
break;
case RpcPacket::Fields::PAYLOAD:
- decoder.ReadBytes(&payload);
+ decoder.ReadBytes(&packet.payload_);
break;
- case RpcPacket::Fields::STATUS:
+ case RpcPacket::Fields::STATUS: {
+ uint32_t value;
decoder.ReadUint32(&value);
- status = static_cast<Status::Code>(value);
+ packet.set_status(static_cast<Status::Code>(value));
break;
+ }
}
}
- return Packet(type, channel_id, service_id, method_id, payload, status);
+ return status == Status::DATA_LOSS ? Status::DATA_LOSS : Status::OK;
}
StatusWithSize Packet::Encode(std::span<byte> buffer) const {
@@ -94,10 +93,8 @@
reserved_size += 1; // channel_id key
reserved_size += varint::EncodedSize(channel_id());
- reserved_size += 1; // service_id key
- reserved_size += varint::EncodedSize(service_id());
- reserved_size += 1; // method_id key
- reserved_size += varint::EncodedSize(method_id());
+ reserved_size += 1 + sizeof(uint32_t); // service_id key and fixed32
+ reserved_size += 1 + sizeof(uint32_t); // method_id key and fixed32
// Packet type always takes two bytes to encode (varint key + varint enum).
reserved_size += 2;
diff --git a/pw_rpc/packet_test.cc b/pw_rpc/packet_test.cc
index 87d5c83..e4ac4b6 100644
--- a/pw_rpc/packet_test.cc
+++ b/pw_rpc/packet_test.cc
@@ -43,12 +43,18 @@
byte{1},
// Service ID
- byte{MakeKey(3, protobuf::WireType::kVarint)},
+ byte{MakeKey(3, protobuf::WireType::kFixed32)},
byte{42},
+ byte{0},
+ byte{0},
+ byte{0},
// Method ID
- byte{MakeKey(4, protobuf::WireType::kVarint)},
+ byte{MakeKey(4, protobuf::WireType::kFixed32)},
byte{100},
+ byte{0},
+ byte{0},
+ byte{0},
// Status
byte{MakeKey(6, protobuf::WireType::kVarint)},
@@ -65,8 +71,19 @@
EXPECT_EQ(std::memcmp(kEncoded, buffer, sizeof(kEncoded)), 0);
}
-TEST(Packet, Decode) {
- Packet packet = Packet::FromBuffer(kEncoded);
+TEST(Packet, Encode_BufferTooSmall) {
+ byte buffer[2];
+
+ Packet packet(PacketType::RPC, 1, 42, 100, kPayload);
+
+ auto sws = packet.Encode(buffer);
+ EXPECT_EQ(0u, sws.size());
+ EXPECT_EQ(Status::RESOURCE_EXHAUSTED, sws.status());
+}
+
+TEST(Packet, Decode_ValidPacket) {
+ Packet packet;
+ ASSERT_EQ(Status::OK, Packet::FromBuffer(kEncoded, packet));
EXPECT_EQ(PacketType::RPC, packet.type());
EXPECT_EQ(1u, packet.channel_id());
@@ -77,10 +94,17 @@
std::memcmp(packet.payload().data(), kPayload, sizeof(kPayload)));
}
+TEST(Packet, Decode_InvalidPacket) {
+ byte bad_data[] = {byte{0xFF}, byte{0x00}, byte{0x00}, byte{0xFF}};
+
+ Packet packet;
+ EXPECT_EQ(Status::DATA_LOSS, Packet::FromBuffer(bad_data, packet));
+}
+
TEST(Packet, EncodeDecode) {
constexpr byte payload[]{byte(0x00), byte(0x01), byte(0x02), byte(0x03)};
- Packet packet = Packet(PacketType::RPC);
+ Packet packet;
packet.set_channel_id(12);
packet.set_service_id(0xdeadbeef);
packet.set_method_id(0x03a82921);
@@ -92,7 +116,8 @@
ASSERT_EQ(sws.status(), Status::OK);
std::span<byte> packet_data(buffer, sws.size());
- Packet decoded = Packet::FromBuffer(packet_data);
+ Packet decoded;
+ ASSERT_EQ(Status::OK, Packet::FromBuffer(packet_data, decoded));
EXPECT_EQ(decoded.type(), packet.type());
EXPECT_EQ(decoded.channel_id(), packet.channel_id());
@@ -107,7 +132,7 @@
}
constexpr size_t kReservedSize = 2 /* type */ + 2 /* channel */ +
- 2 /* service */ + 2 /* method */ +
+ 5 /* service */ + 5 /* method */ +
2 /* payload key */ + 2 /* status */;
TEST(Packet, PayloadUsableSpace_ExactFit) {
@@ -116,7 +141,7 @@
}
TEST(Packet, PayloadUsableSpace_LargerVarints) {
- EXPECT_EQ(kReservedSize + 2 + 1 + 1,
+ EXPECT_EQ(kReservedSize + 2 /* channel */, // service and method are Fixed32
Packet(PacketType::RPC, 17000, 200, 200).MinEncodedSizeBytes());
}
diff --git a/pw_rpc/public/pw_rpc/internal/base_server_writer.h b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
index 745b615..2b4b0fa 100644
--- a/pw_rpc/public/pw_rpc/internal/base_server_writer.h
+++ b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
@@ -21,6 +21,7 @@
#include "pw_rpc/internal/call.h"
#include "pw_rpc/internal/channel.h"
#include "pw_rpc/internal/service.h"
+#include "pw_status/status.h"
namespace pw::rpc::internal {
@@ -52,13 +53,15 @@
uint32_t method_id() const;
// Closes the ServerWriter, if it is open.
- void Finish();
+ void Finish(Status status = Status::OK);
protected:
constexpr BaseServerWriter() : state_{kClosed} {}
const Method& method() const { return call_.method(); }
+ const Channel& channel() const { return call_.channel(); }
+
std::span<std::byte> AcquirePayloadBuffer();
Status ReleasePayloadBuffer(std::span<const std::byte> payload);
diff --git a/pw_rpc/public/pw_rpc/internal/channel.h b/pw_rpc/public/pw_rpc/internal/channel.h
index b42c8a0..ba03f07 100644
--- a/pw_rpc/public/pw_rpc/internal/channel.h
+++ b/pw_rpc/public/pw_rpc/internal/channel.h
@@ -65,6 +65,11 @@
return OutputBuffer(output().AcquireBuffer());
}
+ Status Send(const internal::Packet& packet) {
+ OutputBuffer buffer = AcquireBuffer();
+ return Send(buffer, packet);
+ }
+
Status Send(OutputBuffer& output, const internal::Packet& packet);
};
diff --git a/pw_rpc/public/pw_rpc/internal/packet.h b/pw_rpc/public/pw_rpc/internal/packet.h
index 79b18f6..15904f7 100644
--- a/pw_rpc/public/pw_rpc/internal/packet.h
+++ b/pw_rpc/public/pw_rpc/internal/packet.h
@@ -28,12 +28,39 @@
// Parses a packet from a protobuf message. Missing or malformed fields take
// their default values.
- static Packet FromBuffer(std::span<const std::byte> data);
+ static Status FromBuffer(std::span<const std::byte> data, Packet& packet);
+
+ // Creates an RPC packet with the channel, service, and method ID of the
+ // provided packet.
+ static constexpr Packet Response(const Packet& request,
+ Status status = Status::OK) {
+ return Packet(PacketType::RPC,
+ request.channel_id(),
+ request.service_id(),
+ request.method_id(),
+ {},
+ status);
+ }
+
+ // Creates an ERROR packet with the channel, service, and method ID of the
+ // provided packet.
+ static constexpr Packet Error(const Packet& packet, Status status) {
+ return Packet(PacketType::ERROR,
+ packet.channel_id(),
+ packet.service_id(),
+ packet.method_id(),
+ {},
+ status);
+ }
+
+ // Creates an empty packet.
+ constexpr Packet()
+ : Packet(PacketType::RPC, kUnassignedId, kUnassignedId, kUnassignedId) {}
constexpr Packet(PacketType type,
- uint32_t channel_id = kUnassignedId,
- uint32_t service_id = kUnassignedId,
- uint32_t method_id = kUnassignedId,
+ uint32_t channel_id,
+ uint32_t service_id,
+ uint32_t method_id,
std::span<const std::byte> payload = {},
Status status = Status::OK)
: type_(type),
@@ -47,26 +74,31 @@
StatusWithSize Encode(std::span<std::byte> buffer) const;
// Determines the space required to encode the packet proto fields for a
- // response. This may be used to split the buffer into reserved space and
- // available space for the payload.
+ // response, excluding the payload. This may be used to split the buffer into
+ // reserved space and available space for the payload.
size_t MinEncodedSizeBytes() const;
- bool is_control() const { return !is_rpc(); }
- bool is_rpc() const { return type_ == PacketType::RPC; }
+ constexpr PacketType type() const { return type_; }
+ constexpr uint32_t channel_id() const { return channel_id_; }
+ constexpr uint32_t service_id() const { return service_id_; }
+ constexpr uint32_t method_id() const { return method_id_; }
+ constexpr const std::span<const std::byte>& payload() const {
+ return payload_;
+ }
+ constexpr Status status() const { return status_; }
- PacketType type() const { return type_; }
- uint32_t channel_id() const { return channel_id_; }
- uint32_t service_id() const { return service_id_; }
- uint32_t method_id() const { return method_id_; }
- const std::span<const std::byte>& payload() const { return payload_; }
- Status status() const { return status_; }
-
- void set_type(PacketType type) { type_ = type; }
- void set_channel_id(uint32_t channel_id) { channel_id_ = channel_id; }
- void set_service_id(uint32_t service_id) { service_id_ = service_id; }
- void set_method_id(uint32_t method_id) { method_id_ = method_id; }
- void set_payload(std::span<const std::byte> payload) { payload_ = payload; }
- void set_status(Status status) { status_ = status; }
+ constexpr void set_type(PacketType type) { type_ = type; }
+ constexpr void set_channel_id(uint32_t channel_id) {
+ channel_id_ = channel_id;
+ }
+ constexpr void set_service_id(uint32_t service_id) {
+ service_id_ = service_id;
+ }
+ constexpr void set_method_id(uint32_t method_id) { method_id_ = method_id; }
+ constexpr void set_payload(std::span<const std::byte> payload) {
+ payload_ = payload;
+ }
+ constexpr void set_status(Status status) { status_ = status; }
private:
PacketType type_;
diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h
index 9cd5695..5fa5f6a 100644
--- a/pw_rpc/public/pw_rpc/server.h
+++ b/pw_rpc/public/pw_rpc/server.h
@@ -48,15 +48,8 @@
IntrusiveList<internal::BaseServerWriter>& writers() { return writers_; }
private:
- void HandleRpcPacket(const internal::Packet& request,
- internal::Channel& channel);
-
- void HandleCancelPacket(const internal::Packet& request);
-
- void InvokeMethod(const internal::Packet& request,
- Channel& channel,
- internal::Packet& response,
- std::span<std::byte> buffer);
+ void HandleCancelPacket(const internal::Packet& request,
+ internal::Channel& channel);
internal::Channel* FindChannel(uint32_t id) const;
internal::Channel* AssignChannel(uint32_t id, ChannelOutput& interface);
diff --git a/pw_rpc/pw_rpc_private/test_utils.h b/pw_rpc/pw_rpc_private/test_utils.h
index 8ad045b..0bf3272 100644
--- a/pw_rpc/pw_rpc_private/test_utils.h
+++ b/pw_rpc/pw_rpc_private/test_utils.h
@@ -25,25 +25,42 @@
namespace pw::rpc {
-template <size_t buffer_size>
+template <size_t output_buffer_size>
class TestOutput : public ChannelOutput {
public:
+ static constexpr size_t buffer_size() { return output_buffer_size; }
+
constexpr TestOutput(const char* name = "TestOutput")
- : ChannelOutput(name), sent_packet_{} {}
+ : ChannelOutput(name), sent_data_{} {}
std::span<std::byte> AcquireBuffer() override { return buffer_; }
void SendAndReleaseBuffer(size_t size) override {
- sent_packet_ = std::span(buffer_.data(), size);
+ if (size == 0u) {
+ return;
+ }
+
+ packet_count_ += 1;
+ sent_data_ = std::span(buffer_.data(), size);
+ EXPECT_EQ(Status::OK,
+ internal::Packet::FromBuffer(sent_data_, sent_packet_));
}
std::span<const std::byte> buffer() const { return buffer_; }
- const std::span<const std::byte>& sent_packet() const { return sent_packet_; }
+ size_t packet_count() const { return packet_count_; }
+
+ const std::span<const std::byte>& sent_data() const { return sent_data_; }
+ const internal::Packet& sent_packet() const {
+ EXPECT_GT(packet_count_, 0u);
+ return sent_packet_;
+ }
private:
- std::array<std::byte, buffer_size> buffer_;
- std::span<const std::byte> sent_packet_;
+ std::array<std::byte, buffer_size()> buffer_;
+ std::span<const std::byte> sent_data_;
+ internal::Packet sent_packet_;
+ size_t packet_count_ = 0;
};
// Version of the internal::Server with extra methods exposed for testing.
diff --git a/pw_rpc/pw_rpc_protos/packet.proto b/pw_rpc/pw_rpc_protos/packet.proto
index cd03194..df226c7 100644
--- a/pw_rpc/pw_rpc_protos/packet.proto
+++ b/pw_rpc/pw_rpc_protos/packet.proto
@@ -16,34 +16,37 @@
package pw.rpc.internal;
enum PacketType {
- // RPC packets correspond with an RPC request or response.
+ // RPC packets correspond with a request or response for a service method.
RPC = 0;
- // CANCEL packets indicate cancellation of an ongoing streaming RPC. They are
- // sent by the client to request cancellation and by the server to signal that
- // an RPC was cancelled. The server sends a CANCEL packet when an RPC is
- // cancelled server-side and in response to a client's request.
- CANCEL = 1;
+ // STREAM_END packets signal the end of a server or client stream.
+ STREAM_END = 1;
+
+ // CANCEL packets request termination of an ongoing RPC.
+ CANCEL = 2;
+
+ // ERROR packets are sent by the server to indicate that it received an
+ // unexpected or malformed packet.
+ ERROR = 3;
}
message RpcPacket {
- // The type of packet. Either a general RPC packet or a specific control
- // packet. Required.
+ // The type of packet. Determines which other fields are used. Required.
PacketType type = 1;
// Channel through which the packet is sent. Required.
uint32 channel_id = 2;
- // Tokenized fully-qualified name of the service with which this packet is
+ // Hash of the fully-qualified name of the service with which this packet is
// associated. For RPC packets, this is the service that processes the packet.
- uint32 service_id = 3;
+ fixed32 service_id = 3;
- // Tokenized name of the method which should process this packet.
- uint32 method_id = 4;
+ // Hash of the name of the method which should process this packet.
+ fixed32 method_id = 4;
- // The packet's payload.
+ // The packet's payload, which is an encoded protobuf.
bytes payload = 5;
- // RPC response status code.
+ // Status code for the RPC response or error.
uint32 status = 6;
}
diff --git a/pw_rpc/pw_rpc_test_protos/test.proto b/pw_rpc/pw_rpc_test_protos/test.proto
index 3a67f91..16c4022 100644
--- a/pw_rpc/pw_rpc_test_protos/test.proto
+++ b/pw_rpc/pw_rpc_test_protos/test.proto
@@ -16,7 +16,7 @@
package pw.rpc.test;
message TestRequest {
- float integer = 1;
+ int64 integer = 1;
}
message TestResponse {
diff --git a/pw_rpc/py/pw_rpc/codegen_nanopb.py b/pw_rpc/py/pw_rpc/codegen_nanopb.py
index 3ea7d4d..fbfbe1c 100644
--- a/pw_rpc/py/pw_rpc/codegen_nanopb.py
+++ b/pw_rpc/py/pw_rpc/codegen_nanopb.py
@@ -75,18 +75,17 @@
req_type = method.request_type().nanopb_name()
res_type = method.response_type().nanopb_name()
- signature = f'static ::pw::Status {method.name()}'
output.write_line()
if method.type() == ProtoServiceMethod.Type.UNARY:
- output.write_line(f'{signature}(')
+ output.write_line(f'static ::pw::Status {method.name()} (')
with output.indent(4):
output.write_line('ServerContext& ctx,')
output.write_line(f'const {req_type}& request,')
output.write_line(f'{res_type}& response);')
elif method.type() == ProtoServiceMethod.Type.SERVER_STREAMING:
- output.write_line(f'{signature}(')
+ output.write_line(f'static void {method.name()} (')
with output.indent(4):
output.write_line('ServerContext& ctx,')
output.write_line(f'const {req_type}& request,')
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc
index 7c5866b..0760091 100644
--- a/pw_rpc/server.cc
+++ b/pw_rpc/server.cc
@@ -23,12 +23,52 @@
#include "pw_rpc/server_context.h"
namespace pw::rpc {
+namespace {
using std::byte;
using internal::Packet;
using internal::PacketType;
+enum IncludeMethod : bool { kIncludeMethod, kOmitMethod };
+
+void SendError(internal::Channel& channel,
+ Status status,
+ const Packet& packet,
+ IncludeMethod method = kIncludeMethod) {
+ Packet error = Packet::Error(packet, status);
+ if (method == kOmitMethod) {
+ error.set_method_id(Packet::kUnassignedId);
+ }
+ channel.Send(error);
+}
+
+bool DecodePacket(ChannelOutput& interface,
+ std::span<const byte> data,
+ Packet& packet) {
+ if (Status status = Packet::FromBuffer(data, packet); !status.ok()) {
+ PW_LOG_WARN("Failed to decode packet on interface %s", interface.name());
+ return false;
+ }
+
+ // If the packet is malformed, don't try to process it.
+ if (packet.channel_id() == Channel::kUnassignedChannelId ||
+ packet.service_id() == 0 || packet.method_id() == 0) {
+ PW_LOG_WARN("Received incomplete packet on interface %s", interface.name());
+
+ // Only send an ERROR response if a valid channel ID was provided.
+ if (packet.channel_id() != Channel::kUnassignedChannelId) {
+ internal::Channel temp_channel(packet.channel_id(), &interface);
+ SendError(temp_channel, Status::DATA_LOSS, packet);
+ }
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
Server::~Server() {
// Since the writers remove themselves from the server in Finish(), remove the
// first writer until no writers remain.
@@ -40,12 +80,8 @@
void Server::ProcessPacket(std::span<const byte> data,
ChannelOutput& interface) {
// TODO(hepler): Update the packet parsing code to report when decoding fails.
- Packet packet = Packet::FromBuffer(data);
-
- if (packet.channel_id() == Channel::kUnassignedChannelId ||
- packet.service_id() == 0 || packet.method_id() == 0) {
- // Malformed packet; don't even try to process it.
- PW_LOG_WARN("Received incomplete packet on interface %s", interface.name());
+ Packet packet;
+ if (!DecodePacket(interface, data, packet)) {
return;
}
@@ -54,94 +90,64 @@
// If the requested channel doesn't exist, try to dynamically assign one.
channel = AssignChannel(packet.channel_id(), interface);
if (channel == nullptr) {
- // If a channel can't be assigned, send back a response indicating that
- // the server cannot process the request. The channel_id in the response
- // is not set, to allow clients to detect this error case.
+ // If a channel can't be assigned, send a RESOURCE_EXHAUSTED error.
internal::Channel temp_channel(packet.channel_id(), &interface);
-
- // TODO(hepler): Add a new PacketType for errors like this, rather than
- // using PacketType::RPC.
- Packet response(PacketType::RPC);
- response.set_status(Status::RESOURCE_EXHAUSTED);
- auto response_buffer = temp_channel.AcquireBuffer();
- temp_channel.Send(response_buffer, response);
+ SendError(temp_channel, Status::RESOURCE_EXHAUSTED, packet);
return;
}
}
- switch (packet.type()) {
- case PacketType::RPC:
- HandleRpcPacket(packet, *channel);
- break;
- case PacketType::CANCEL:
- HandleCancelPacket(packet);
- break;
- }
-}
-
-void Server::HandleRpcPacket(const internal::Packet& request,
- internal::Channel& channel) {
- Packet response(PacketType::RPC);
-
- response.set_channel_id(channel.id());
- auto response_buffer = channel.AcquireBuffer();
-
- // Invoke the method with matching service and method IDs, if any.
- InvokeMethod(request, channel, response, response_buffer.payload(response));
- channel.Send(response_buffer, response);
-}
-
-void Server::HandleCancelPacket(const internal::Packet& request) {
- auto writer = std::find_if(writers_.begin(), writers_.end(), [&](auto& w) {
- return w.channel_id() == request.channel_id() &&
- w.service_id() == request.service_id() &&
- w.method_id() == request.method_id();
- });
-
- if (writer == writers_.end()) {
- PW_LOG_WARN("Received CANCEL packet for unknown method");
- } else {
- writer->Finish();
- }
-}
-
-void Server::InvokeMethod(const Packet& request,
- Channel& channel,
- internal::Packet& response,
- std::span<std::byte> payload_buffer) {
+ // Packets always include service and method IDs.
auto service = std::find_if(services_.begin(), services_.end(), [&](auto& s) {
- return s.id() == request.service_id();
+ return s.id() == packet.service_id();
});
if (service == services_.end()) {
- // Couldn't find the requested service. Reply with a NOT_FOUND response
- // without the service_id field set.
- response.set_status(Status::NOT_FOUND);
+ SendError(*channel, Status::NOT_FOUND, packet, kOmitMethod);
return;
}
- response.set_service_id(service->id());
-
- const internal::Method* method = service->FindMethod(request.method_id());
+ const internal::Method* method = service->FindMethod(packet.method_id());
if (method == nullptr) {
- // Couldn't find the requested method. Reply with a NOT_FOUND response
- // without the method_id field set.
- response.set_status(Status::NOT_FOUND);
+ SendError(*channel, Status::NOT_FOUND, packet);
return;
}
- response.set_method_id(method->id());
+ switch (packet.type()) {
+ case PacketType::RPC: {
+ internal::ServerCall call(
+ static_cast<internal::Server&>(*this), *channel, *service, *method);
+ method->Invoke(call, packet);
+ return;
+ }
+ case PacketType::STREAM_END:
+ // TODO(hepler): Support client streaming RPCs.
+ break;
+ case PacketType::CANCEL:
+ HandleCancelPacket(packet, *channel);
+ return;
+ case PacketType::ERROR:
+ break;
+ }
+ SendError(*channel, Status::UNIMPLEMENTED, packet);
+ PW_LOG_WARN("Unable to handle packet of type %u", unsigned(packet.type()));
+}
- internal::ServerCall call(static_cast<internal::Server&>(*this),
- static_cast<internal::Channel&>(channel),
- *service,
- *method);
- StatusWithSize result =
- method->Invoke(call, request.payload(), payload_buffer);
+void Server::HandleCancelPacket(const Packet& packet,
+ internal::Channel& channel) {
+ auto writer = std::find_if(writers_.begin(), writers_.end(), [&](auto& w) {
+ return w.channel_id() == packet.channel_id() &&
+ w.service_id() == packet.service_id() &&
+ w.method_id() == packet.method_id();
+ });
- response.set_status(result.status());
- response.set_payload(payload_buffer.first(result.size()));
+ if (writer == writers_.end()) {
+ SendError(channel, Status::FAILED_PRECONDITION, packet);
+ PW_LOG_WARN("Received CANCEL packet for method that is not pending");
+ } else {
+ writer->Finish(Status::CANCELLED);
+ }
}
internal::Channel* Server::FindChannel(uint32_t id) const {
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc
index 47b62e5..a78118c 100644
--- a/pw_rpc/server_test.cc
+++ b/pw_rpc/server_test.cc
@@ -97,41 +97,13 @@
const Method& method = service_.method(100);
EXPECT_EQ(1u, method.last_channel_id());
- EXPECT_EQ(sizeof(kDefaultPayload), method.last_request().size());
+ ASSERT_EQ(sizeof(kDefaultPayload), method.last_request().payload().size());
EXPECT_EQ(std::memcmp(kDefaultPayload,
- method.last_request().data(),
- method.last_request().size()),
+ method.last_request().payload().data(),
+ method.last_request().payload().size()),
0);
}
-TEST_F(BasicServer, ProcessPacket_ValidMethod_SendsOkResponse) {
- server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 100), output_);
-
- Packet packet = Packet::FromBuffer(output_.sent_packet());
- EXPECT_EQ(packet.type(), PacketType::RPC);
- EXPECT_EQ(packet.channel_id(), 1u);
- EXPECT_EQ(packet.service_id(), 42u);
- EXPECT_EQ(packet.method_id(), 100u);
- EXPECT_TRUE(packet.payload().empty());
- EXPECT_EQ(packet.status(), Status::OK);
-}
-
-TEST_F(BasicServer, ProcessPacket_ValidMethod_SendsErrorResponse) {
- constexpr byte resp[] = {byte{0xf0}, byte{0x0d}};
- service_.method(200).set_response(resp);
- service_.method(200).set_status(Status::FAILED_PRECONDITION);
-
- server_.ProcessPacket(EncodeRequest(PacketType::RPC, 2, 42, 200), output_);
-
- Packet packet = Packet::FromBuffer(output_.sent_packet());
- EXPECT_EQ(packet.channel_id(), 2u);
- EXPECT_EQ(packet.service_id(), 42u);
- EXPECT_EQ(packet.method_id(), 200u);
- EXPECT_EQ(packet.status(), Status::FAILED_PRECONDITION);
- ASSERT_EQ(sizeof(resp), packet.payload().size());
- EXPECT_EQ(std::memcmp(packet.payload().data(), resp, sizeof(resp)), 0);
-}
-
TEST_F(BasicServer, ProcessPacket_IncompletePacket_NothingIsInvoked) {
server_.ProcessPacket(EncodeRequest(PacketType::RPC, 0, 42, 101), output_);
server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 0, 101), output_);
@@ -141,6 +113,26 @@
EXPECT_EQ(0u, service_.method(200).last_channel_id());
}
+TEST_F(BasicServer, ProcessPacket_NoChannel_SendsNothing) {
+ server_.ProcessPacket(EncodeRequest(PacketType::RPC, 0, 42, 101), output_);
+
+ EXPECT_EQ(output_.packet_count(), 0u);
+}
+
+TEST_F(BasicServer, ProcessPacket_NoService_SendsDataLoss) {
+ server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 0, 101), output_);
+
+ EXPECT_EQ(output_.sent_packet().type(), PacketType::ERROR);
+ EXPECT_EQ(output_.sent_packet().status(), Status::DATA_LOSS);
+}
+
+TEST_F(BasicServer, ProcessPacket_NoMethod_SendsDataLoss) {
+ server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 0), output_);
+
+ EXPECT_EQ(output_.sent_packet().type(), PacketType::ERROR);
+ EXPECT_EQ(output_.sent_packet().status(), Status::DATA_LOSS);
+}
+
TEST_F(BasicServer, ProcessPacket_InvalidMethod_NothingIsInvoked) {
server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 101), output_);
@@ -148,38 +140,34 @@
EXPECT_EQ(0u, service_.method(200).last_channel_id());
}
-TEST_F(BasicServer, ProcessPacket_InvalidMethod_SendsNotFound) {
+TEST_F(BasicServer, ProcessPacket_InvalidMethod_SendsError) {
server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 42, 27), output_);
- Packet packet = Packet::FromBuffer(output_.sent_packet());
- EXPECT_EQ(packet.type(), PacketType::RPC);
+ const Packet& packet = output_.sent_packet();
+ EXPECT_EQ(packet.type(), PacketType::ERROR);
EXPECT_EQ(packet.channel_id(), 1u);
EXPECT_EQ(packet.service_id(), 42u);
- EXPECT_EQ(packet.method_id(), 0u); // No method ID 27
+ EXPECT_EQ(packet.method_id(), 27u); // No method ID 27
EXPECT_EQ(packet.status(), Status::NOT_FOUND);
}
-TEST_F(BasicServer, ProcessPacket_InvalidService_SendsNotFound) {
+TEST_F(BasicServer, ProcessPacket_InvalidService_SendsError) {
server_.ProcessPacket(EncodeRequest(PacketType::RPC, 1, 43, 27), output_);
- Packet packet = Packet::FromBuffer(output_.sent_packet());
- EXPECT_EQ(packet.status(), Status::NOT_FOUND);
+ const Packet& packet = output_.sent_packet();
+ EXPECT_EQ(packet.type(), PacketType::ERROR);
EXPECT_EQ(packet.channel_id(), 1u);
- EXPECT_EQ(packet.service_id(), 0u);
+ EXPECT_EQ(packet.service_id(), 43u);
+ EXPECT_EQ(packet.method_id(), 0u); // No method since service not found
+ EXPECT_EQ(packet.status(), Status::NOT_FOUND);
}
-TEST_F(BasicServer, ProcessPacket_UnassignedChannel_AssignsToAvalableSlot) {
+TEST_F(BasicServer, ProcessPacket_UnassignedChannel_AssignsToAvailableSlot) {
TestOutput<128> unassigned_output;
server_.ProcessPacket(
- EncodeRequest(PacketType::RPC, /*channel_id=*/99, 42, 27),
+ EncodeRequest(PacketType::RPC, /*channel_id=*/99, 42, 100),
unassigned_output);
- ASSERT_EQ(channels_[2].id(), 99u);
-
- Packet packet = Packet::FromBuffer(unassigned_output.sent_packet());
- EXPECT_EQ(packet.channel_id(), 99u);
- EXPECT_EQ(packet.service_id(), 42u);
- EXPECT_EQ(packet.method_id(), 0u); // No method ID 27
- EXPECT_EQ(packet.status(), Status::NOT_FOUND);
+ EXPECT_EQ(channels_[2].id(), 99u);
}
TEST_F(BasicServer,
@@ -189,50 +177,81 @@
server_.ProcessPacket(
EncodeRequest(PacketType::RPC, /*channel_id=*/99, 42, 27), output_);
- Packet packet = Packet::FromBuffer(output_.sent_packet());
+ const Packet& packet = output_.sent_packet();
EXPECT_EQ(packet.status(), Status::RESOURCE_EXHAUSTED);
- EXPECT_EQ(packet.channel_id(), 0u);
- EXPECT_EQ(packet.service_id(), 0u);
- EXPECT_EQ(packet.method_id(), 0u);
+ EXPECT_EQ(packet.channel_id(), 99u);
+ EXPECT_EQ(packet.service_id(), 42u);
+ EXPECT_EQ(packet.method_id(), 27u);
}
-TEST_F(BasicServer, ProcessPacket_Cancel_ClosesServerWriter) {
+TEST_F(BasicServer, ProcessPacket_Cancel_MethodNotActive_SendsError) {
// Set up a fake ServerWriter representing an ongoing RPC.
- internal::ServerCall call(static_cast<internal::Server&>(server_),
- static_cast<internal::Channel&>(channels_[0]),
- service_,
- service_.method(100));
- internal::BaseServerWriter writer(call);
- ASSERT_TRUE(writer.open());
-
server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 42, 100), output_);
- EXPECT_FALSE(writer.open());
+ const Packet& packet = output_.sent_packet();
+ EXPECT_EQ(packet.type(), PacketType::ERROR);
+ EXPECT_EQ(packet.channel_id(), 1u);
+ EXPECT_EQ(packet.service_id(), 42u);
+ EXPECT_EQ(packet.method_id(), 100u);
+ EXPECT_EQ(packet.status(), Status::FAILED_PRECONDITION);
+}
- Packet packet = Packet::FromBuffer(output_.sent_packet());
- EXPECT_EQ(packet.type(), PacketType::CANCEL);
+class MethodPending : public BasicServer {
+ protected:
+ MethodPending()
+ : call_(static_cast<internal::Server&>(server_),
+ static_cast<internal::Channel&>(channels_[0]),
+ service_,
+ service_.method(100)),
+ writer_(call_) {
+ ASSERT_TRUE(writer_.open());
+ }
+
+ internal::ServerCall call_;
+ internal::BaseServerWriter writer_;
+};
+
+TEST_F(MethodPending, ProcessPacket_Cancel_ClosesServerWriter) {
+ server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 42, 100), output_);
+
+ EXPECT_FALSE(writer_.open());
+}
+
+TEST_F(MethodPending, ProcessPacket_Cancel_SendsStreamEndPacket) {
+ server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 42, 100), output_);
+
+ const Packet& packet = output_.sent_packet();
+ EXPECT_EQ(packet.type(), PacketType::STREAM_END);
EXPECT_EQ(packet.channel_id(), 1u);
EXPECT_EQ(packet.service_id(), 42u);
EXPECT_EQ(packet.method_id(), 100u);
EXPECT_TRUE(packet.payload().empty());
- EXPECT_EQ(packet.status(), Status::OK);
+ EXPECT_EQ(packet.status(), Status::CANCELLED);
}
-TEST_F(BasicServer, ProcessPacket_Cancel_UnknownIdIsIgnored) {
- internal::ServerCall call(static_cast<internal::Server&>(server_),
- static_cast<internal::Channel&>(channels_[0]),
- service_,
- service_.method(100));
- internal::BaseServerWriter writer(call);
- ASSERT_TRUE(writer.open());
-
- // Send packets with incorrect channel, service, and method ID.
+TEST_F(MethodPending, ProcessPacket_Cancel_IncorrectChannel) {
server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 2, 42, 100), output_);
- server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 43, 100), output_);
- server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 42, 101), output_);
- EXPECT_TRUE(writer.open());
- EXPECT_TRUE(output_.sent_packet().empty());
+ EXPECT_EQ(output_.sent_packet().type(), PacketType::ERROR);
+ EXPECT_EQ(output_.sent_packet().status(), Status::FAILED_PRECONDITION);
+ EXPECT_TRUE(writer_.open());
+}
+
+TEST_F(MethodPending, ProcessPacket_Cancel_IncorrectService) {
+ server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 43, 100), output_);
+
+ EXPECT_EQ(output_.sent_packet().type(), PacketType::ERROR);
+ EXPECT_EQ(output_.sent_packet().status(), Status::NOT_FOUND);
+ EXPECT_EQ(output_.sent_packet().service_id(), 43u);
+ EXPECT_EQ(output_.sent_packet().method_id(), 0u);
+ EXPECT_TRUE(writer_.open());
+}
+
+TEST_F(MethodPending, ProcessPacket_CancelIncorrectMethod) {
+ server_.ProcessPacket(EncodeRequest(PacketType::CANCEL, 1, 42, 101), output_);
+ EXPECT_EQ(output_.sent_packet().type(), PacketType::ERROR);
+ EXPECT_EQ(output_.sent_packet().status(), Status::NOT_FOUND);
+ EXPECT_TRUE(writer_.open());
}
} // namespace
diff --git a/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h b/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h
index 297f14e..41ed781 100644
--- a/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h
+++ b/pw_rpc/test_impl/public_overrides/pw_rpc/internal/method.h
@@ -18,6 +18,7 @@
#include <span>
#include "pw_rpc/internal/base_method.h"
+#include "pw_rpc/internal/packet.h"
#include "pw_rpc/server_context.h"
#include "pw_status/status_with_size.h"
@@ -29,24 +30,13 @@
public:
constexpr Method(uint32_t id) : BaseMethod(id), last_channel_id_(0) {}
- StatusWithSize Invoke(ServerCall& call,
- std::span<const std::byte> request,
- std::span<std::byte> payload_buffer) const {
+ void Invoke(ServerCall& call, const Packet& request) const {
last_channel_id_ = call.channel().id();
last_request_ = request;
- last_payload_buffer_ = payload_buffer;
-
- std::memcpy(payload_buffer.data(),
- response_.data(),
- std::min(response_.size(), payload_buffer.size()));
- return StatusWithSize(response_status_, response_.size());
}
uint32_t last_channel_id() const { return last_channel_id_; }
- std::span<const std::byte> last_request() const { return last_request_; }
- std::span<std::byte> last_payload_buffer() const {
- return last_payload_buffer_;
- }
+ const Packet& last_request() const { return last_request_; }
void set_response(std::span<const std::byte> payload) { response_ = payload; }
void set_status(Status status) { response_status_ = status; }
@@ -56,8 +46,7 @@
// The Method class is used exclusively in tests. Having these members mutable
// allows tests to verify that the Method is invoked correctly.
mutable uint32_t last_channel_id_;
- mutable std::span<const std::byte> last_request_;
- mutable std::span<std::byte> last_payload_buffer_;
+ mutable Packet last_request_;
std::span<const std::byte> response_;
Status response_status_;