pw_rpc: Expand server-side packet processing
- Handle various error cases with incoming packets.
- Dynamically assign a channel when a packet's ID doesn't exist.
- Reserve space in the response buffer for packet "header" fields.
Change-Id: Ibdce99c8ff1d37aa46bb4e400a4d8f8e646a8ac7
diff --git a/pw_protobuf/encoder.cc b/pw_protobuf/encoder.cc
index 05ce88a..c20b093 100644
--- a/pw_protobuf/encoder.cc
+++ b/pw_protobuf/encoder.cc
@@ -56,7 +56,10 @@
return encode_status_;
}
- memcpy(cursor_, ptr, size);
+ // Memmove the value into place as it's possible that it shares the encode
+ // buffer on a memory-constrained system.
+ std::memmove(cursor_, ptr, size);
+
cursor_ += size;
return Status::OK;
}
@@ -167,7 +170,7 @@
to_copy = end - read_cursor;
}
- memmove(write_cursor, read_cursor, to_copy);
+ std::memmove(write_cursor, read_cursor, to_copy);
write_cursor += to_copy;
read_cursor += to_copy;
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn
index 192a9bc..5226243 100644
--- a/pw_rpc/BUILD.gn
+++ b/pw_rpc/BUILD.gn
@@ -62,7 +62,10 @@
}
pw_test("server_test") {
- deps = [ ":pw_rpc" ]
+ deps = [
+ ":protos_pwpb",
+ ":pw_rpc",
+ ]
sources = [ "server_test.cc" ]
}
diff --git a/pw_rpc/packet.cc b/pw_rpc/packet.cc
index 57e6a47..38a9174 100644
--- a/pw_rpc/packet.cc
+++ b/pw_rpc/packet.cc
@@ -70,11 +70,13 @@
pw::protobuf::NestedEncoder encoder(buffer);
RpcPacket::Encoder rpc_packet(&encoder);
+ // The payload is encoded first, as it may share the encode buffer.
+ rpc_packet.WritePayload(payload_);
+
rpc_packet.WriteType(type_);
rpc_packet.WriteChannelId(channel_id_);
rpc_packet.WriteServiceId(service_id_);
rpc_packet.WriteMethodId(method_id_);
- rpc_packet.WritePayload(payload_);
rpc_packet.WriteStatus(status_);
span<const std::byte> proto;
diff --git a/pw_rpc/packet_test.cc b/pw_rpc/packet_test.cc
index 83d4ee8..dd413dd 100644
--- a/pw_rpc/packet_test.cc
+++ b/pw_rpc/packet_test.cc
@@ -24,8 +24,7 @@
TEST(Packet, EncodeDecode) {
constexpr byte payload[]{byte(0x00), byte(0x01), byte(0x02), byte(0x03)};
- Packet packet = Packet::Empty();
- packet.set_type(PacketType::RPC);
+ Packet packet = Packet::Empty(PacketType::RPC);
packet.set_channel_id(12);
packet.set_service_id(0xdeadbeef);
packet.set_method_id(0x03a82921);
diff --git a/pw_rpc/public/pw_rpc/channel.h b/pw_rpc/public/pw_rpc/channel.h
index 1ef6880..68fceca 100644
--- a/pw_rpc/public/pw_rpc/channel.h
+++ b/pw_rpc/public/pw_rpc/channel.h
@@ -15,6 +15,7 @@
#include <cstdint>
+#include "pw_assert/assert.h"
#include "pw_span/span.h"
#include "pw_status/status.h"
@@ -40,16 +41,22 @@
class Channel {
public:
+ static constexpr uint32_t kUnassignedChannelId = 0;
+
// Creates a dynamically assignable channel without a set ID or output.
constexpr Channel() : id_(kUnassignedChannelId), output_(nullptr) {}
// Creates a channel with a static ID. The channel's output can also be
// static, or it can set to null to allow dynamically opening connections
// through the channel.
- constexpr Channel(uint32_t id, ChannelOutput* output)
- : id_(id), output_(output) {}
+ template <uint32_t id>
+ static Channel Create(ChannelOutput* output) {
+ static_assert(id != kUnassignedChannelId, "Channel ID cannot be 0");
+ return Channel(id, output);
+ }
constexpr uint32_t id() const { return id_; }
+ constexpr bool assigned() const { return id_ != kUnassignedChannelId; }
span<std::byte> AcquireBuffer() const { return output_->AcquireBuffer(); }
void SendAndReleaseBuffer(size_t size) const {
@@ -57,7 +64,13 @@
}
private:
- static constexpr uint32_t kUnassignedChannelId = 0;
+ friend class Server;
+
+ constexpr Channel(uint32_t id, ChannelOutput* output)
+ : id_(id), output_(output) {
+ PW_CHECK_UINT_NE(id, kUnassignedChannelId);
+ }
+
uint32_t id_;
ChannelOutput* output_;
};
diff --git a/pw_rpc/public/pw_rpc/internal/packet.h b/pw_rpc/public/pw_rpc/internal/packet.h
index 0ed74b9..2f3a164 100644
--- a/pw_rpc/public/pw_rpc/internal/packet.h
+++ b/pw_rpc/public/pw_rpc/internal/packet.h
@@ -29,8 +29,8 @@
static Packet FromBuffer(span<const std::byte> data);
// Returns an empty packet with default values set.
- static constexpr Packet Empty() {
- return Packet(PacketType::RPC, 0, 0, 0, {}, Status::OK);
+ static constexpr Packet Empty(PacketType type) {
+ return Packet(type, 0, 0, 0, {}, Status::OK);
}
// Encodes the packet into its wire format. Returns the encoded size.
diff --git a/pw_rpc/public/pw_rpc/internal/service.h b/pw_rpc/public/pw_rpc/internal/service.h
index e457285..377b6c1 100644
--- a/pw_rpc/public/pw_rpc/internal/service.h
+++ b/pw_rpc/public/pw_rpc/internal/service.h
@@ -42,7 +42,8 @@
// Handles an incoming packet and populates a response. Errors that occur
// should be set within the response packet.
void ProcessPacket(const internal::Packet& request,
- internal::Packet& response);
+ internal::Packet& response,
+ span<std::byte> payload_buffer);
private:
friend class internal::ServiceRegistry;
diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h
index 2faf1b8..b2727cd 100644
--- a/pw_rpc/public/pw_rpc/server.h
+++ b/pw_rpc/public/pw_rpc/server.h
@@ -42,7 +42,18 @@
using Service = internal::Service;
using ServiceRegistry = internal::ServiceRegistry;
- Channel* FindChannel(uint32_t id);
+ void SendResponse(const Channel& channel,
+ const internal::Packet& response,
+ span<std::byte> response_buffer) const;
+
+ // Determines the space required to encode the packet proto fields for a
+ // response, and splits the buffer into reserved space and available space for
+ // the payload. Returns a subspan of the payload space.
+ span<std::byte> ResponsePayloadUsableSpace(const internal::Packet& request,
+ span<std::byte> buffer) const;
+
+ Channel* FindChannel(uint32_t id) const;
+ Channel* AssignChannel(uint32_t id, ChannelOutput& interface);
span<Channel> channels_;
ServiceRegistry services_;
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc
index fb44de7..badf8a9 100644
--- a/pw_rpc/server.cc
+++ b/pw_rpc/server.cc
@@ -20,6 +20,7 @@
namespace pw::rpc {
using internal::Packet;
+using internal::PacketType;
void Server::ProcessPacket(span<const std::byte> data,
ChannelOutput& interface) {
@@ -29,44 +30,51 @@
return;
}
- if (packet.service_id() == 0 || packet.method_id() == 0) {
+ if (packet.channel_id() == Channel::kUnassignedChannelId ||
+ packet.service_id() == 0 || packet.method_id() == 0) {
// Malformed packet; don't even try to process it.
PW_LOG_ERROR("Received incomplete RPC packet on interface %u",
unsigned(interface.id()));
return;
}
+ Packet response = Packet::Empty(PacketType::RPC);
+
Channel* channel = FindChannel(packet.channel_id());
if (channel == nullptr) {
- // TODO(frolv): Dynamically assign channel.
- return;
+ // If the requested channel doesn't exist, try to dynamically assign one.
+ channel = AssignChannel(packet.channel_id(), interface);
+ if (channel == nullptr) {
+ // If a channel can't be assigned, send back a response indicating that
+ // the server cannot process the request. The channel_id in the response
+ // is not set, to allow clients to detect this error case.
+ Channel temp_channel(packet.channel_id(), &interface);
+ response.set_status(Status::RESOURCE_EXHAUSTED);
+ SendResponse(temp_channel, response, temp_channel.AcquireBuffer());
+ return;
+ }
}
span<std::byte> response_buffer = channel->AcquireBuffer();
+ span<std::byte> payload_buffer =
+ ResponsePayloadUsableSpace(packet, response_buffer);
+
+ response.set_channel_id(channel->id());
Service* service = services_.Find(packet.service_id());
if (service == nullptr) {
- // TODO(frolv): Send back a NOT_FOUND response.
- channel->SendAndReleaseBuffer(0);
+ // Couldn't find the requested service. Reply with a NOT_FOUND response
+ // without the server_id field set.
+ response.set_status(Status::NOT_FOUND);
+ SendResponse(*channel, response, response_buffer);
return;
}
- Packet response = Packet::Empty();
- response.set_channel_id(channel->id());
-
- service->ProcessPacket(packet, response);
-
- StatusWithSize sws = response.Encode(response_buffer);
- if (!sws.ok()) {
- // TODO(frolv): What should be done here?
- channel->SendAndReleaseBuffer(0);
- return;
- }
-
- channel->SendAndReleaseBuffer(sws.size());
+ service->ProcessPacket(packet, response, payload_buffer);
+ SendResponse(*channel, response, response_buffer);
}
-Channel* Server::FindChannel(uint32_t id) {
+Channel* Server::FindChannel(uint32_t id) const {
for (Channel& c : channels_) {
if (c.id() == id) {
return &c;
@@ -75,4 +83,48 @@
return nullptr;
}
+Channel* Server::AssignChannel(uint32_t id, ChannelOutput& interface) {
+ Channel* channel = FindChannel(Channel::kUnassignedChannelId);
+ if (channel == nullptr) {
+ return nullptr;
+ }
+
+ *channel = Channel(id, &interface);
+ return channel;
+}
+
+void Server::SendResponse(const Channel& channel,
+ const Packet& response,
+ span<std::byte> response_buffer) const {
+ StatusWithSize sws = response.Encode(response_buffer);
+ if (!sws.ok()) {
+ // TODO(frolv): What should be done here?
+ channel.SendAndReleaseBuffer(0);
+ PW_LOG_ERROR("Failed to encode response packet to channel buffer");
+ return;
+ }
+
+ channel.SendAndReleaseBuffer(sws.size());
+}
+
+span<std::byte> Server::ResponsePayloadUsableSpace(
+ const Packet& request, span<std::byte> buffer) const {
+ size_t reserved_size = 0;
+
+ reserved_size += 1; // channel_id key
+ reserved_size += varint::EncodedSize(request.channel_id());
+ reserved_size += 1; // service_id key
+ reserved_size += varint::EncodedSize(request.service_id());
+ reserved_size += 1; // method_id key
+ reserved_size += varint::EncodedSize(request.method_id());
+
+ // Packet type always takes two bytes to encode (varint key + varint enum).
+ reserved_size += 2;
+
+ // Status field always takes two bytes to encode (varint key + varint status).
+ reserved_size += 2;
+
+ return buffer.subspan(reserved_size);
+}
+
} // namespace pw::rpc
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc
index aae410e..85eecc0 100644
--- a/pw_rpc/server_test.cc
+++ b/pw_rpc/server_test.cc
@@ -15,10 +15,13 @@
#include "pw_rpc/server.h"
#include "gtest/gtest.h"
+#include "pw_rpc/internal/packet.h"
namespace pw::rpc {
namespace {
+using internal::Packet;
+using internal::PacketType;
using std::byte;
template <size_t buffer_size>
@@ -39,35 +42,110 @@
span<const byte> sent_packet_;
};
-TestOutput<512> output(1);
+Packet MakePacket(uint32_t channel_id,
+ uint32_t service_id,
+ uint32_t method_id,
+ span<const byte> payload) {
+ Packet packet = Packet::Empty(PacketType::RPC);
+ packet.set_channel_id(channel_id);
+ packet.set_service_id(service_id);
+ packet.set_method_id(method_id);
+ packet.set_payload(payload);
+ return packet;
+}
-// clang-format off
-constexpr uint8_t encoded_packet[] = {
- // type = PacketType::kRpc
- 0x08, 0x00,
- // channel_id = 1
- 0x10, 0x01,
- // service_id = 42
- 0x18, 0x2a,
- // method_id = 27
- 0x20, 0x1b,
- // payload
- 0x82, 0x02, 0xff, 0xff,
-};
-// clang-format on
-
-TEST(Server, DoesStuff) {
+TEST(Server, ProcessPacket_SendsResponse) {
+ TestOutput<128> output(1);
Channel channels[] = {
- Channel(1, &output),
- Channel(2, &output),
+ Channel::Create<1>(&output),
+ Channel::Create<2>(&output),
};
Server server(channels);
internal::Service service(42, {});
server.RegisterService(service);
- server.ProcessPacket(as_bytes(span(encoded_packet)), output);
- auto packet = output.sent_packet();
- EXPECT_GT(packet.size(), 0u);
+ byte encoded_packet[64];
+ constexpr byte payload[] = {byte(0x82), byte(0x02), byte(0xff), byte(0xff)};
+ Packet request = MakePacket(1, 42, 27, payload);
+ auto sws = request.Encode(encoded_packet);
+
+ server.ProcessPacket(span(encoded_packet, sws.size()), output);
+ Packet packet = Packet::FromBuffer(output.sent_packet());
+ EXPECT_EQ(packet.status(), Status::OK);
+ EXPECT_EQ(packet.channel_id(), 1u);
+ EXPECT_EQ(packet.service_id(), 42u);
+}
+
+TEST(Server, ProcessPacket_SendsNotFoundOnInvalidService) {
+ TestOutput<128> output(1);
+ Channel channels[] = {
+ Channel::Create<1>(&output),
+ Channel::Create<2>(&output),
+ };
+ Server server(channels);
+ internal::Service service(42, {});
+ server.RegisterService(service);
+
+ byte encoded_packet[64];
+ constexpr byte payload[] = {byte(0x82), byte(0x02), byte(0xff), byte(0xff)};
+ Packet request = MakePacket(1, 43, 27, payload);
+ auto sws = request.Encode(encoded_packet);
+
+ server.ProcessPacket(span(encoded_packet, sws.size()), output);
+ Packet packet = Packet::FromBuffer(output.sent_packet());
+ EXPECT_EQ(packet.status(), Status::NOT_FOUND);
+ EXPECT_EQ(packet.channel_id(), 1u);
+ EXPECT_EQ(packet.service_id(), 0u);
+}
+
+TEST(Server, ProcessPacket_AssignsAnUnassignedChannel) {
+ TestOutput<128> output(1);
+ Channel channels[] = {
+ Channel::Create<1>(&output),
+ Channel::Create<2>(&output),
+ Channel(),
+ };
+ Server server(channels);
+ internal::Service service(42, {});
+ server.RegisterService(service);
+
+ byte encoded_packet[64];
+ constexpr byte payload[] = {byte(0x82), byte(0x02), byte(0xff), byte(0xff)};
+ Packet request = MakePacket(/*channel_id=*/99, 42, 27, payload);
+ auto sws = request.Encode(encoded_packet);
+
+ TestOutput<128> unassigned_output(2);
+ server.ProcessPacket(span(encoded_packet, sws.size()), unassigned_output);
+ ASSERT_EQ(channels[2].id(), 99u);
+
+ Packet packet = Packet::FromBuffer(unassigned_output.sent_packet());
+ EXPECT_EQ(packet.status(), Status::OK);
+ EXPECT_EQ(packet.channel_id(), 99u);
+ EXPECT_EQ(packet.service_id(), 42u);
+}
+
+TEST(Server, ProcessPacket_SendsResourceExhaustedWhenChannelCantBeAssigned) {
+ TestOutput<128> output(1);
+ Channel channels[] = {
+ Channel::Create<1>(&output),
+ Channel::Create<2>(&output),
+ };
+ Server server(channels);
+ internal::Service service(42, {});
+ server.RegisterService(service);
+
+ byte encoded_packet[64];
+ constexpr byte payload[] = {byte(0x82), byte(0x02), byte(0xff), byte(0xff)};
+ Packet request = MakePacket(/*channel_id=*/99, 42, 27, payload);
+ auto sws = request.Encode(encoded_packet);
+
+ server.ProcessPacket(span(encoded_packet, sws.size()), output);
+
+ Packet packet = Packet::FromBuffer(output.sent_packet());
+ EXPECT_EQ(packet.status(), Status::RESOURCE_EXHAUSTED);
+ EXPECT_EQ(packet.channel_id(), 0u);
+ EXPECT_EQ(packet.service_id(), 0u);
+ EXPECT_EQ(packet.method_id(), 0u);
}
} // namespace
diff --git a/pw_rpc/service.cc b/pw_rpc/service.cc
index 996043c..819e6cd 100644
--- a/pw_rpc/service.cc
+++ b/pw_rpc/service.cc
@@ -18,8 +18,9 @@
namespace pw::rpc::internal {
-void Service::ProcessPacket(const Packet& request, Packet& response) {
- response.set_type(PacketType::RPC);
+void Service::ProcessPacket(const Packet& request,
+ Packet& response,
+ span<std::byte> payload_buffer) {
response.set_service_id(id_);
for (const Method& method : methods_) {
@@ -28,6 +29,8 @@
response.set_method_id(method.id);
}
}
+
+ (void)payload_buffer;
}
} // namespace pw::rpc::internal