pw_protobuf: Reimplement decoder as field iterator

This changes the pw::protobuf::Decoder API to expose functions to
iterate over fields instead of using a virtual callback interface,
making the core decoder simpler and more flexible.

The virtual callback interface is kept but renamed to a CallbackDecoder,
and reimplemented in terms of basic Decoder.

Change-Id: Idff321cd5e37184aa730251475c9e336136596d2
diff --git a/pw_protobuf/decoder.cc b/pw_protobuf/decoder.cc
index 54d4675..3183556 100644
--- a/pw_protobuf/decoder.cc
+++ b/pw_protobuf/decoder.cc
@@ -20,75 +20,178 @@
 
 namespace pw::protobuf {
 
-Status Decoder::Decode(span<const std::byte> proto) {
-  if (handler_ == nullptr || state_ != kReady) {
-    return Status::FAILED_PRECONDITION;
-  }
-
-  state_ = kDecodeInProgress;
-  proto_ = proto;
-
-  // Iterate over each field in the proto, calling the handler with the field
-  // key.
-  while (state_ == kDecodeInProgress && !proto_.empty()) {
-    const std::byte* original_cursor = proto_.data();
-
-    uint64_t key;
-    size_t bytes_read = varint::Decode(proto_, &key);
-    if (bytes_read == 0) {
-      state_ = kDecodeFailed;
-      return Status::DATA_LOSS;
-    }
-
-    uint32_t field_number = key >> kFieldNumberShift;
-    Status status = handler_->ProcessField(this, field_number);
-    if (!status.ok()) {
-      state_ = status == Status::CANCELLED ? kDecodeCancelled : kDecodeFailed;
+Status Decoder::Next() {
+  if (!previous_field_consumed_) {
+    if (Status status = SkipField(); !status.ok()) {
       return status;
     }
+  }
+  if (proto_.empty()) {
+    return Status::OUT_OF_RANGE;
+  }
+  previous_field_consumed_ = false;
+  return FieldSize() == 0 ? Status::DATA_LOSS : Status::OK;
+}
 
-    // The callback function can modify the decoder's state; check that
-    // everything is still okay.
-    if (state_ == kDecodeFailed) {
-      break;
-    }
-
-    // If the cursor has not moved, the user has not consumed the field in their
-    // callback. Skip ahead to the next field.
-    if (original_cursor == proto_.data()) {
-      SkipField();
-    }
+Status Decoder::SkipField() {
+  if (proto_.empty()) {
+    return Status::OUT_OF_RANGE;
   }
 
-  if (state_ != kDecodeInProgress) {
+  size_t bytes_to_skip = FieldSize();
+  if (bytes_to_skip == 0) {
     return Status::DATA_LOSS;
   }
 
-  state_ = kReady;
+  proto_ = proto_.subspan(bytes_to_skip);
+  return proto_.empty() ? Status::OUT_OF_RANGE : Status::OK;
+}
+
+uint32_t Decoder::FieldNumber() const {
+  uint64_t key;
+  varint::Decode(proto_, &key);
+  return key >> kFieldNumberShift;
+}
+
+Status Decoder::ReadUint32(uint32_t* out) {
+  uint64_t value = 0;
+  Status status = ReadUint64(&value);
+  if (!status.ok()) {
+    return status;
+  }
+  if (value > std::numeric_limits<uint32_t>::max()) {
+    return Status::OUT_OF_RANGE;
+  }
+  *out = value;
   return Status::OK;
 }
 
-Status Decoder::ReadVarint(uint32_t field_number, uint64_t* out) {
-  Status status = ConsumeKey(field_number, WireType::kVarint);
+Status Decoder::ReadSint32(int32_t* out) {
+  int64_t value = 0;
+  Status status = ReadSint64(&value);
   if (!status.ok()) {
     return status;
   }
+  if (value > std::numeric_limits<int32_t>::max()) {
+    return Status::OUT_OF_RANGE;
+  }
+  *out = value;
+  return Status::OK;
+}
+
+Status Decoder::ReadSint64(int64_t* out) {
+  uint64_t value = 0;
+  Status status = ReadUint64(&value);
+  if (!status.ok()) {
+    return status;
+  }
+  *out = varint::ZigZagDecode(value);
+  return Status::OK;
+}
+
+Status Decoder::ReadBool(bool* out) {
+  uint64_t value = 0;
+  Status status = ReadUint64(&value);
+  if (!status.ok()) {
+    return status;
+  }
+  *out = value;
+  return Status::OK;
+}
+
+Status Decoder::ReadString(std::string_view* out) {
+  span<const std::byte> bytes;
+  Status status = ReadDelimited(&bytes);
+  if (!status.ok()) {
+    return status;
+  }
+  *out = std::string_view(reinterpret_cast<const char*>(bytes.data()),
+                          bytes.size());
+  return Status::OK;
+}
+
+size_t Decoder::FieldSize() const {
+  uint64_t key;
+  size_t key_size = varint::Decode(proto_, &key);
+  if (key_size == 0) {
+    return 0;
+  }
+
+  span<const std::byte> remainder = proto_.subspan(key_size);
+  WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
+  uint64_t value = 0;
+  size_t expected_size = 0;
+
+  switch (wire_type) {
+    case WireType::kVarint:
+      expected_size = varint::Decode(remainder, &value);
+      if (expected_size == 0) {
+        return 0;
+      }
+      break;
+
+    case WireType::kDelimited:
+      // Varint at cursor indicates size of the field.
+      expected_size = varint::Decode(remainder, &value);
+      if (expected_size == 0) {
+        return 0;
+      }
+      expected_size += value;
+      break;
+
+    case WireType::kFixed32:
+      expected_size = sizeof(uint32_t);
+      break;
+
+    case WireType::kFixed64:
+      expected_size = sizeof(uint64_t);
+      break;
+  }
+
+  if (remainder.size() < expected_size) {
+    return 0;
+  }
+
+  return key_size + expected_size;
+}
+
+Status Decoder::ConsumeKey(WireType expected_type) {
+  uint64_t key;
+  size_t bytes_read = varint::Decode(proto_, &key);
+  if (bytes_read == 0) {
+    return Status::FAILED_PRECONDITION;
+  }
+
+  WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
+  if (wire_type != expected_type) {
+    return Status::FAILED_PRECONDITION;
+  }
+
+  // Advance past the key.
+  proto_ = proto_.subspan(bytes_read);
+  return Status::OK;
+}
+
+Status Decoder::ReadVarint(uint64_t* out) {
+  if (Status status = ConsumeKey(WireType::kVarint); !status.ok()) {
+    return status;
+  }
 
   size_t bytes_read = varint::Decode(proto_, out);
   if (bytes_read == 0) {
-    state_ = kDecodeFailed;
     return Status::DATA_LOSS;
   }
 
   // Advance to the next field.
   proto_ = proto_.subspan(bytes_read);
+  previous_field_consumed_ = true;
   return Status::OK;
 }
 
-Status Decoder::ReadFixed(uint32_t field_number, std::byte* out, size_t size) {
+Status Decoder::ReadFixed(std::byte* out, size_t size) {
   WireType expected_wire_type =
       size == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
-  Status status = ConsumeKey(field_number, expected_wire_type);
+  Status status = ConsumeKey(expected_wire_type);
   if (!status.ok()) {
     return status;
   }
@@ -99,13 +202,13 @@
 
   std::memcpy(out, proto_.data(), size);
   proto_ = proto_.subspan(size);
+  previous_field_consumed_ = true;
 
   return Status::OK;
 }
 
-Status Decoder::ReadDelimited(uint32_t field_number,
-                              span<const std::byte>* out) {
-  Status status = ConsumeKey(field_number, WireType::kDelimited);
+Status Decoder::ReadDelimited(span<const std::byte>* out) {
+  Status status = ConsumeKey(WireType::kDelimited);
   if (!status.ok()) {
     return status;
   }
@@ -113,80 +216,60 @@
   uint64_t length;
   size_t bytes_read = varint::Decode(proto_, &length);
   if (bytes_read == 0) {
-    state_ = kDecodeFailed;
     return Status::DATA_LOSS;
   }
 
   proto_ = proto_.subspan(bytes_read);
   if (proto_.size() < length) {
-    state_ = kDecodeFailed;
     return Status::DATA_LOSS;
   }
 
   *out = proto_.first(length);
   proto_ = proto_.subspan(length);
+  previous_field_consumed_ = true;
 
   return Status::OK;
 }
 
-Status Decoder::ConsumeKey(uint32_t field_number, WireType expected_type) {
+Status CallbackDecoder::Decode(span<const std::byte> proto) {
+  if (handler_ == nullptr || state_ != kReady) {
+    return Status::FAILED_PRECONDITION;
+  }
+
+  state_ = kDecodeInProgress;
+  decoder_.Reset(proto);
+
+  // Iterate the proto, calling the handler with each field number.
+  while (state_ == kDecodeInProgress) {
+    if (Status status = decoder_.Next(); !status.ok()) {
+      if (status == Status::OUT_OF_RANGE) {
+        // Reached the end of the proto.
+        break;
+      }
+
+      // Proto data is malformed.
+      return status;
+    }
+
+    Status status = handler_->ProcessField(*this, decoder_.FieldNumber());
+    if (!status.ok()) {
+      state_ = status == Status::CANCELLED ? kDecodeCancelled : kDecodeFailed;
+      return status;
+    }
+
+    // The callback function can modify the decoder's state; check that
+    // everything is still okay.
+    if (state_ == kDecodeFailed) {
+      break;
+    }
+  }
+
   if (state_ != kDecodeInProgress) {
-    return Status::FAILED_PRECONDITION;
+    return Status::DATA_LOSS;
   }
 
-  uint64_t key;
-  size_t bytes_read = varint::Decode(proto_, &key);
-  if (bytes_read == 0) {
-    state_ = kDecodeFailed;
-    return Status::FAILED_PRECONDITION;
-  }
-
-  uint32_t field = key >> kFieldNumberShift;
-  WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
-
-  if (field != field_number || wire_type != expected_type) {
-    state_ = kDecodeFailed;
-    return Status::FAILED_PRECONDITION;
-  }
-
-  // Advance past the key.
-  proto_ = proto_.subspan(bytes_read);
+  state_ = kReady;
   return Status::OK;
 }
 
-void Decoder::SkipField() {
-  uint64_t key;
-  proto_ = proto_.subspan(varint::Decode(proto_, &key));
-
-  WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
-  size_t bytes_to_skip = 0;
-  uint64_t value = 0;
-
-  switch (wire_type) {
-    case WireType::kVarint:
-      bytes_to_skip = varint::Decode(proto_, &value);
-      break;
-
-    case WireType::kDelimited:
-      // Varint at cursor indicates size of the field.
-      bytes_to_skip += varint::Decode(proto_, &value);
-      bytes_to_skip += value;
-      break;
-
-    case WireType::kFixed32:
-      bytes_to_skip = sizeof(uint32_t);
-      break;
-
-    case WireType::kFixed64:
-      bytes_to_skip = sizeof(uint64_t);
-      break;
-  }
-
-  if (bytes_to_skip == 0) {
-    state_ = kDecodeFailed;
-  } else {
-    proto_ = proto_.subspan(bytes_to_skip);
-  }
-}
-
 }  // namespace pw::protobuf
diff --git a/pw_protobuf/decoder_test.cc b/pw_protobuf/decoder_test.cc
index c655b09..abc9d54 100644
--- a/pw_protobuf/decoder_test.cc
+++ b/pw_protobuf/decoder_test.cc
@@ -22,27 +22,28 @@
 
 class TestDecodeHandler : public DecodeHandler {
  public:
-  Status ProcessField(Decoder* decoder, uint32_t field_number) override {
+  Status ProcessField(CallbackDecoder& decoder,
+                      uint32_t field_number) override {
     std::string_view str;
 
     switch (field_number) {
       case 1:
-        decoder->ReadInt32(field_number, &test_int32);
+        decoder.ReadInt32(&test_int32);
         break;
       case 2:
-        decoder->ReadSint32(field_number, &test_sint32);
+        decoder.ReadSint32(&test_sint32);
         break;
       case 3:
-        decoder->ReadBool(field_number, &test_bool);
+        decoder.ReadBool(&test_bool);
         break;
       case 4:
-        decoder->ReadDouble(field_number, &test_double);
+        decoder.ReadDouble(&test_double);
         break;
       case 5:
-        decoder->ReadFixed32(field_number, &test_fixed32);
+        decoder.ReadFixed32(&test_fixed32);
         break;
       case 6:
-        decoder->ReadString(field_number, &str);
+        decoder.ReadString(&str);
         std::memcpy(test_string, str.data(), str.size());
         test_string[str.size()] = '\0';
         break;
@@ -62,7 +63,101 @@
 };
 
 TEST(Decoder, Decode) {
-  Decoder decoder;
+  // clang-format off
+  uint8_t encoded_proto[] = {
+    // type=int32, k=1, v=42
+    0x08, 0x2a,
+    // type=sint32, k=2, v=-13
+    0x10, 0x19,
+    // type=bool, k=3, v=false
+    0x18, 0x00,
+    // type=double, k=4, v=3.14159
+    0x21, 0x6e, 0x86, 0x1b, 0xf0, 0xf9, 0x21, 0x09, 0x40,
+    // type=fixed32, k=5, v=0xdeadbeef
+    0x2d, 0xef, 0xbe, 0xad, 0xde,
+    // type=string, k=6, v="Hello world"
+    0x32, 0x0b, 'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd',
+  };
+  // clang-format on
+
+  Decoder decoder(as_bytes(span(encoded_proto)));
+
+  int32_t v1 = 0;
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  ASSERT_EQ(decoder.FieldNumber(), 1u);
+  EXPECT_EQ(decoder.ReadInt32(&v1), Status::OK);
+  EXPECT_EQ(v1, 42);
+
+  int32_t v2 = 0;
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  ASSERT_EQ(decoder.FieldNumber(), 2u);
+  EXPECT_EQ(decoder.ReadSint32(&v2), Status::OK);
+  EXPECT_EQ(v2, -13);
+
+  bool v3 = true;
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  ASSERT_EQ(decoder.FieldNumber(), 3u);
+  EXPECT_EQ(decoder.ReadBool(&v3), Status::OK);
+  EXPECT_FALSE(v3);
+
+  double v4 = 0;
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  ASSERT_EQ(decoder.FieldNumber(), 4u);
+  EXPECT_EQ(decoder.ReadDouble(&v4), Status::OK);
+  EXPECT_EQ(v4, 3.14159);
+
+  uint32_t v5 = 0;
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  ASSERT_EQ(decoder.FieldNumber(), 5u);
+  EXPECT_EQ(decoder.ReadFixed32(&v5), Status::OK);
+  EXPECT_EQ(v5, 0xdeadbeef);
+
+  std::string_view v6;
+  char buffer[16];
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  ASSERT_EQ(decoder.FieldNumber(), 6u);
+  EXPECT_EQ(decoder.ReadString(&v6), Status::OK);
+  std::memcpy(buffer, v6.data(), v6.size());
+  buffer[v6.size()] = '\0';
+  EXPECT_STREQ(buffer, "Hello world");
+
+  EXPECT_EQ(decoder.Next(), Status::OUT_OF_RANGE);
+}
+
+TEST(Decoder, Decode_SkipsUnusedFields) {
+  // clang-format off
+  uint8_t encoded_proto[] = {
+    // type=int32, k=1, v=42
+    0x08, 0x2a,
+    // type=sint32, k=2, v=-13
+    0x10, 0x19,
+    // type=bool, k=3, v=false
+    0x18, 0x00,
+    // type=double, k=4, v=3.14159
+    0x21, 0x6e, 0x86, 0x1b, 0xf0, 0xf9, 0x21, 0x09, 0x40,
+    // type=fixed32, k=5, v=0xdeadbeef
+    0x2d, 0xef, 0xbe, 0xad, 0xde,
+    // type=string, k=6, v="Hello world"
+    0x32, 0x0b, 'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd',
+  };
+  // clang-format on
+
+  Decoder decoder(as_bytes(span(encoded_proto)));
+
+  // Don't process any fields except for the fourth. Next should still iterate
+  // correctly despite field values not being consumed.
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  ASSERT_EQ(decoder.FieldNumber(), 4u);
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  EXPECT_EQ(decoder.Next(), Status::OK);
+  EXPECT_EQ(decoder.Next(), Status::OUT_OF_RANGE);
+}
+
+TEST(CallbackDecoder, Decode) {
+  CallbackDecoder decoder;
   TestDecodeHandler handler;
 
   // clang-format off
@@ -93,8 +188,8 @@
   EXPECT_STREQ(handler.test_string, "Hello world");
 }
 
-TEST(Decoder, Decode_OverridesDuplicateFields) {
-  Decoder decoder;
+TEST(CallbackDecoder, Decode_OverridesDuplicateFields) {
+  CallbackDecoder decoder;
   TestDecodeHandler handler;
 
   // clang-format off
@@ -114,8 +209,8 @@
   EXPECT_EQ(handler.test_int32, 44);
 }
 
-TEST(Decoder, Decode_Empty) {
-  Decoder decoder;
+TEST(CallbackDecoder, Decode_Empty) {
+  CallbackDecoder decoder;
   TestDecodeHandler handler;
 
   decoder.set_handler(&handler);
@@ -125,8 +220,8 @@
   EXPECT_EQ(handler.test_sint32, 0);
 }
 
-TEST(Decoder, Decode_BadData) {
-  Decoder decoder;
+TEST(CallbackDecoder, Decode_BadData) {
+  CallbackDecoder decoder;
   TestDecodeHandler handler;
 
   // Field key without a value.
@@ -139,13 +234,14 @@
 // Only processes fields numbered 1 or 3.
 class OneThreeDecodeHandler : public DecodeHandler {
  public:
-  Status ProcessField(Decoder* decoder, uint32_t field_number) override {
+  Status ProcessField(CallbackDecoder& decoder,
+                      uint32_t field_number) override {
     switch (field_number) {
       case 1:
-        EXPECT_EQ(decoder->ReadInt32(field_number, &field_one), Status::OK);
+        EXPECT_EQ(decoder.ReadInt32(&field_one), Status::OK);
         break;
       case 3:
-        EXPECT_EQ(decoder->ReadInt32(field_number, &field_three), Status::OK);
+        EXPECT_EQ(decoder.ReadInt32(&field_three), Status::OK);
         break;
       default:
         // Do nothing.
@@ -161,8 +257,8 @@
   int32_t field_three = 0;
 };
 
-TEST(Decoder, Decode_SkipsUnprocessedFields) {
-  Decoder decoder;
+TEST(CallbackDecoder, Decode_SkipsUnprocessedFields) {
+  CallbackDecoder decoder;
   OneThreeDecodeHandler handler;
 
   // clang-format off
@@ -192,16 +288,17 @@
   EXPECT_EQ(handler.field_three, 99);
 }
 
-// Only processes fields numbered 1 or 3.
+// Only processes fields numbered 1 or 3, and stops the decode after hitting 1.
 class ExitOnOneDecoder : public DecodeHandler {
  public:
-  Status ProcessField(Decoder* decoder, uint32_t field_number) override {
+  Status ProcessField(CallbackDecoder& decoder,
+                      uint32_t field_number) override {
     switch (field_number) {
       case 1:
-        EXPECT_EQ(decoder->ReadInt32(field_number, &field_one), Status::OK);
+        EXPECT_EQ(decoder.ReadInt32(&field_one), Status::OK);
         return Status::CANCELLED;
       case 3:
-        EXPECT_EQ(decoder->ReadInt32(field_number, &field_three), Status::OK);
+        EXPECT_EQ(decoder.ReadInt32(&field_three), Status::OK);
         break;
       default:
         // Do nothing.
@@ -215,8 +312,8 @@
   int32_t field_three = 1111;
 };
 
-TEST(Decoder, Decode_StopsOnNonOkStatus) {
-  Decoder decoder;
+TEST(CallbackDecoder, Decode_StopsOnNonOkStatus) {
+  CallbackDecoder decoder;
   ExitOnOneDecoder handler;
 
   // clang-format off
diff --git a/pw_protobuf/find.cc b/pw_protobuf/find.cc
index b98c4b9..70daabb 100644
--- a/pw_protobuf/find.cc
+++ b/pw_protobuf/find.cc
@@ -16,7 +16,7 @@
 
 namespace pw::protobuf {
 
-Status FindDecodeHandler::ProcessField(Decoder* decoder,
+Status FindDecodeHandler::ProcessField(CallbackDecoder& decoder,
                                        uint32_t field_number) {
   if (field_number != field_number_) {
     // Continue to the next field.
@@ -29,12 +29,11 @@
   }
 
   span<const std::byte> submessage;
-  if (Status status = decoder->ReadBytes(field_number, &submessage);
-      !status.ok()) {
+  if (Status status = decoder.ReadBytes(&submessage); !status.ok()) {
     return status;
   }
 
-  Decoder subdecoder;
+  CallbackDecoder subdecoder;
   subdecoder.set_handler(nested_handler_);
   return subdecoder.Decode(submessage);
 }
diff --git a/pw_protobuf/find_test.cc b/pw_protobuf/find_test.cc
index 860e338..2a7edd6 100644
--- a/pw_protobuf/find_test.cc
+++ b/pw_protobuf/find_test.cc
@@ -41,7 +41,7 @@
 };
 
 TEST(FindDecodeHandler, SingleLevel_FindsExistingField) {
-  Decoder decoder;
+  CallbackDecoder decoder;
   FindDecodeHandler finder(3);
 
   decoder.set_handler(&finder);
@@ -52,7 +52,7 @@
 }
 
 TEST(FindDecodeHandler, SingleLevel_DoesntFindNonExistingField) {
-  Decoder decoder;
+  CallbackDecoder decoder;
   FindDecodeHandler finder(8);
 
   decoder.set_handler(&finder);
@@ -63,7 +63,7 @@
 }
 
 TEST(FindDecodeHandler, MultiLevel_FindsExistingNestedField) {
-  Decoder decoder;
+  CallbackDecoder decoder;
   FindDecodeHandler nested_finder(1);
   FindDecodeHandler finder(7, &nested_finder);
 
@@ -76,7 +76,7 @@
 }
 
 TEST(FindDecodeHandler, MultiLevel_DoesntFindNonExistingNestedField) {
-  Decoder decoder;
+  CallbackDecoder decoder;
   FindDecodeHandler nested_finder(3);
   FindDecodeHandler finder(7, &nested_finder);
 
diff --git a/pw_protobuf/public/pw_protobuf/decoder.h b/pw_protobuf/public/pw_protobuf/decoder.h
index a3aca75..184c38f 100644
--- a/pw_protobuf/public/pw_protobuf/decoder.h
+++ b/pw_protobuf/public/pw_protobuf/decoder.h
@@ -21,10 +21,8 @@
 #include "pw_varint/varint.h"
 
 // This file defines a low-level event-based protobuf wire format decoder.
-// The decoder processes an encoded message by iterating over its fields and
-// notifying a handler for each field it encounters. The handler receives a
-// reference to the decoder object and can extract the field's value from the
-// message.
+// The decoder processes an encoded message by iterating over its fields. The
+// caller can extract the values of any fields it cares about.
 //
 // The decoder does not provide any in-memory data structures to represent a
 // protobuf message's data. More sophisticated APIs can be built on top of the
@@ -32,17 +30,161 @@
 //
 // Example usage:
 //
+//   Decoder decoder(proto);
+//   while (decoder.Next().ok()) {
+//     switch (decoder.FieldNumber()) {
+//       case 1:
+//         decoder.ReadUint32(&my_uint32);
+//         break;
+//       // ... and other fields.
+//     }
+//   }
+//
+namespace pw::protobuf {
+
+class Decoder {
+ public:
+  constexpr Decoder(span<const std::byte> proto)
+      : proto_(proto), previous_field_consumed_(true) {}
+
+  Decoder(const Decoder& other) = delete;
+  Decoder& operator=(const Decoder& other) = delete;
+
+  // Advances to the next field in the proto.
+  //
+  // If Next() returns OK, there is guaranteed to be a valid protobuf field at
+  // the current cursor position.
+  //
+  // Return values:
+  //
+  //             OK: Advanced to a valid proto field.
+  //   OUT_OF_RANGE: Reached the end of the proto message.
+  //      DATA_LOSS: Invalid protobuf data.
+  //
+  Status Next();
+
+  // Returns the field number of the field at the current cursor position.
+  uint32_t FieldNumber() const;
+
+  // Reads a proto int32 value from the current cursor.
+  Status ReadInt32(int32_t* out) {
+    return ReadUint32(reinterpret_cast<uint32_t*>(out));
+  }
+
+  // Reads a proto uint32 value from the current cursor.
+  Status ReadUint32(uint32_t* out);
+
+  // Reads a proto int64 value from the current cursor.
+  Status ReadInt64(int64_t* out) {
+    return ReadVarint(reinterpret_cast<uint64_t*>(out));
+  }
+
+  // Reads a proto uint64 value from the current cursor.
+  Status ReadUint64(uint64_t* out) { return ReadVarint(out); }
+
+  // Reads a proto sint32 value from the current cursor.
+  Status ReadSint32(int32_t* out);
+
+  // Reads a proto sint64 value from the current cursor.
+  Status ReadSint64(int64_t* out);
+
+  // Reads a proto bool value from the current cursor.
+  Status ReadBool(bool* out);
+
+  // Reads a proto fixed32 value from the current cursor.
+  Status ReadFixed32(uint32_t* out) { return ReadFixed(out); }
+
+  // Reads a proto fixed64 value from the current cursor.
+  Status ReadFixed64(uint64_t* out) { return ReadFixed(out); }
+
+  // Reads a proto sfixed32 value from the current cursor.
+  Status ReadSfixed32(int32_t* out) {
+    return ReadFixed32(reinterpret_cast<uint32_t*>(out));
+  }
+
+  // Reads a proto sfixed64 value from the current cursor.
+  Status ReadSfixed64(int64_t* out) {
+    return ReadFixed64(reinterpret_cast<uint64_t*>(out));
+  }
+
+  // Reads a proto float value from the current cursor.
+  Status ReadFloat(float* out) {
+    static_assert(sizeof(float) == sizeof(uint32_t),
+                  "Float and uint32_t must be the same size for protobufs");
+    return ReadFixed(out);
+  }
+
+  // Reads a proto double value from the current cursor.
+  Status ReadDouble(double* out) {
+    static_assert(sizeof(double) == sizeof(uint64_t),
+                  "Double and uint64_t must be the same size for protobufs");
+    return ReadFixed(out);
+  }
+
+  // Reads a proto string value from the current cursor and returns a view of it
+  // in `out`. The raw protobuf data must outlive `out`. If the string field is
+  // invalid, `out` is not modified.
+  Status ReadString(std::string_view* out);
+
+  // Reads a proto bytes value from the current cursor and returns a view of it
+  // in `out`. The raw protobuf data must outlive the `out` span. If the bytes
+  // field is invalid, `out` is not modified.
+  Status ReadBytes(span<const std::byte>* out) { return ReadDelimited(out); }
+
+  // Resets the decoder to start reading a new proto message.
+  void Reset(span<const std::byte> proto) {
+    proto_ = proto;
+    previous_field_consumed_ = true;
+  }
+
+ private:
+  // Advances the cursor to the next field in the proto.
+  Status SkipField();
+
+  // Returns the size of the current field, or 0 if the field is invalid.
+  size_t FieldSize() const;
+
+  Status ConsumeKey(WireType expected_type);
+
+  // Reads a varint key-value pair from the current cursor position.
+  Status ReadVarint(uint64_t* out);
+
+  // Reads a fixed-size key-value pair from the current cursor position.
+  Status ReadFixed(std::byte* out, size_t size);
+
+  template <typename T>
+  Status ReadFixed(T* out) {
+    static_assert(
+        sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t),
+        "Protobuf fixed-size fields must be 32- or 64-bit");
+    return ReadFixed(reinterpret_cast<std::byte*>(out), sizeof(T));
+  }
+
+  Status ReadDelimited(span<const std::byte>* out);
+
+  span<const std::byte> proto_;
+  bool previous_field_consumed_;
+};
+
+class DecodeHandler;
+
+// A protobuf decoder that iterates over an encoded protobuf, calling a handler
+// for each field it encounters.
+//
+// Example usage:
+//
 //   class FooProtoHandler : public DecodeHandler {
 //    public:
-//     Status ProcessField(Decoder* decoder, uint32_t field_number) override {
+//     Status ProcessField(CallbackDecoder& decoder,
+//                         uint32_t field_number) override {
 //       switch (field_number) {
 //         case FooFields::kBar:
-//           if (!decoder->ReadSint32(field_number, &bar).ok()) {
+//           if (!decoder.ReadSint32(&bar).ok()) {
 //             bar = 0;
 //           }
 //           break;
 //         case FooFields::kBaz:
-//           if (!decoder->ReadUint32(field_number, &baz).ok()) {
+//           if (!decoder.ReadUint32(&baz).ok()) {
 //             baz = 0;
 //           }
 //           break;
@@ -68,19 +210,13 @@
 //              handler.bar, handler.baz);
 //   }
 //
-
-namespace pw::protobuf {
-
-class DecodeHandler;
-
-// A protobuf decoder that iterates over an encoded protobuf, calling a handler
-// for each field it encounters.
-class Decoder {
+class CallbackDecoder {
  public:
-  constexpr Decoder() : handler_(nullptr), state_(kReady) {}
+  constexpr CallbackDecoder()
+      : decoder_({}), handler_(nullptr), state_(kReady) {}
 
-  Decoder(const Decoder& other) = delete;
-  Decoder& operator=(const Decoder& other) = delete;
+  CallbackDecoder(const CallbackDecoder& other) = delete;
+  CallbackDecoder& operator=(const CallbackDecoder& other) = delete;
 
   void set_handler(DecodeHandler* handler) { handler_ = handler; }
 
@@ -89,123 +225,54 @@
   Status Decode(span<const std::byte> proto);
 
   // Reads a proto int32 value from the current cursor.
-  Status ReadInt32(uint32_t field_number, int32_t* out) {
-    return ReadUint32(field_number, reinterpret_cast<uint32_t*>(out));
-  }
+  Status ReadInt32(int32_t* out) { return decoder_.ReadInt32(out); }
 
   // Reads a proto uint32 value from the current cursor.
-  Status ReadUint32(uint32_t field_number, uint32_t* out) {
-    uint64_t value = 0;
-    Status status = ReadUint64(field_number, &value);
-    if (!status.ok()) {
-      return status;
-    }
-    if (value > std::numeric_limits<uint32_t>::max()) {
-      return Status::OUT_OF_RANGE;
-    }
-    *out = value;
-    return Status::OK;
-  }
+  Status ReadUint32(uint32_t* out) { return decoder_.ReadUint32(out); }
 
   // Reads a proto int64 value from the current cursor.
-  Status ReadInt64(uint32_t field_number, int64_t* out) {
-    return ReadVarint(field_number, reinterpret_cast<uint64_t*>(out));
-  }
+  Status ReadInt64(int64_t* out) { return decoder_.ReadInt64(out); }
 
   // Reads a proto uint64 value from the current cursor.
-  Status ReadUint64(uint32_t field_number, uint64_t* out) {
-    return ReadVarint(field_number, out);
-  }
-
-  // Reads a proto sint32 value from the current cursor.
-  Status ReadSint32(uint32_t field_number, int32_t* out) {
-    int64_t value = 0;
-    Status status = ReadSint64(field_number, &value);
-    if (!status.ok()) {
-      return status;
-    }
-    if (value > std::numeric_limits<int32_t>::max()) {
-      return Status::OUT_OF_RANGE;
-    }
-    *out = value;
-    return Status::OK;
-  }
+  Status ReadUint64(uint64_t* out) { return decoder_.ReadUint64(out); }
 
   // Reads a proto sint64 value from the current cursor.
-  Status ReadSint64(uint32_t field_number, int64_t* out) {
-    uint64_t value = 0;
-    Status status = ReadUint64(field_number, &value);
-    if (!status.ok()) {
-      return status;
-    }
-    *out = varint::ZigZagDecode(value);
-    return Status::OK;
-  }
+  Status ReadSint32(int32_t* out) { return decoder_.ReadSint32(out); }
+
+  // Reads a proto sint64 value from the current cursor.
+  Status ReadSint64(int64_t* out) { return decoder_.ReadSint64(out); }
 
   // Reads a proto bool value from the current cursor.
-  Status ReadBool(uint32_t field_number, bool* out) {
-    uint64_t value = 0;
-    Status status = ReadUint64(field_number, &value);
-    if (!status.ok()) {
-      return status;
-    }
-    *out = value;
-    return Status::OK;
-  }
+  Status ReadBool(bool* out) { return decoder_.ReadBool(out); }
 
   // Reads a proto fixed32 value from the current cursor.
-  Status ReadFixed32(uint32_t field_number, uint32_t* out) {
-    return ReadFixed(field_number, out);
-  }
+  Status ReadFixed32(uint32_t* out) { return decoder_.ReadFixed32(out); }
 
   // Reads a proto fixed64 value from the current cursor.
-  Status ReadFixed64(uint32_t field_number, uint64_t* out) {
-    return ReadFixed(field_number, out);
-  }
+  Status ReadFixed64(uint64_t* out) { return decoder_.ReadFixed64(out); }
 
   // Reads a proto sfixed32 value from the current cursor.
-  Status ReadSfixed32(uint32_t field_number, int32_t* out) {
-    return ReadFixed32(field_number, reinterpret_cast<uint32_t*>(out));
-  }
+  Status ReadSfixed32(int32_t* out) { return decoder_.ReadSfixed32(out); }
 
   // Reads a proto sfixed64 value from the current cursor.
-  Status ReadSfixed64(uint32_t field_number, int64_t* out) {
-    return ReadFixed64(field_number, reinterpret_cast<uint64_t*>(out));
-  }
+  Status ReadSfixed64(int64_t* out) { return decoder_.ReadSfixed64(out); }
 
   // Reads a proto float value from the current cursor.
-  Status ReadFloat(uint32_t field_number, float* out) {
-    static_assert(sizeof(float) == sizeof(uint32_t),
-                  "Float and uint32_t must be the same size for protobufs");
-    return ReadFixed(field_number, out);
-  }
+  Status ReadFloat(float* out) { return decoder_.ReadFloat(out); }
 
   // Reads a proto double value from the current cursor.
-  Status ReadDouble(uint32_t field_number, double* out) {
-    static_assert(sizeof(double) == sizeof(uint64_t),
-                  "Double and uint64_t must be the same size for protobufs");
-    return ReadFixed(field_number, out);
-  }
+  Status ReadDouble(double* out) { return decoder_.ReadDouble(out); }
 
   // Reads a proto string value from the current cursor and returns a view of it
   // in `out`. The raw protobuf data must outlive `out`. If the string field is
   // invalid, `out` is not modified.
-  Status ReadString(uint32_t field_number, std::string_view* out) {
-    span<const std::byte> bytes;
-    Status status = ReadDelimited(field_number, &bytes);
-    if (!status.ok()) {
-      return status;
-    }
-    *out = std::string_view(reinterpret_cast<const char*>(bytes.data()),
-                            bytes.size());
-    return Status::OK;
-  }
+  Status ReadString(std::string_view* out) { return decoder_.ReadString(out); }
 
   // Reads a proto bytes value from the current cursor and returns a view of it
   // in `out`. The raw protobuf data must outlive the `out` span. If the bytes
   // field is invalid, `out` is not modified.
-  Status ReadBytes(uint32_t field_number, span<const std::byte>* out) {
-    return ReadDelimited(field_number, out);
+  Status ReadBytes(span<const std::byte>* out) {
+    return decoder_.ReadBytes(out);
   }
 
   bool cancelled() const { return state_ == kDecodeCancelled; };
@@ -218,43 +285,14 @@
     kDecodeFailed,
   };
 
-  // Reads a varint key-value pair from the current cursor position.
-  Status ReadVarint(uint32_t field_number, uint64_t* out);
-
-  // Reads a fixed-size key-value pair from the current cursor position.
-  Status ReadFixed(uint32_t field_number, std::byte* out, size_t size);
-
-  template <typename T>
-  Status ReadFixed(uint32_t field_number, T* out) {
-    static_assert(
-        sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t),
-        "Protobuf fixed-size fields must be 32- or 64-bit");
-    union {
-      T value;
-      std::byte bytes[sizeof(T)];
-    };
-    Status status = ReadFixed(field_number, bytes, sizeof(bytes));
-    if (!status.ok()) {
-      return status;
-    }
-    *out = value;
-    return Status::OK;
-  }
-
-  Status ReadDelimited(uint32_t field_number, span<const std::byte>* out);
-
-  Status ConsumeKey(uint32_t field_number, WireType expected_type);
-
-  // Advances the cursor to the next field in the proto.
-  void SkipField();
-
+  Decoder decoder_;
   DecodeHandler* handler_;
 
   State state_;
-  span<const std::byte> proto_;
 };
 
-// The event-handling interface implemented for a proto decoding operation.
+// The event-handling interface implemented for a proto callback decoding
+// operation.
 class DecodeHandler {
  public:
   virtual ~DecodeHandler() = default;
@@ -266,7 +304,8 @@
   // If the status returned is not Status::OK, the decode operation is exited
   // with the provided status. Returning Status::CANCELLED allows a convenient
   // way of stopping a decode early (for example, if a desired field is found).
-  virtual Status ProcessField(Decoder* decoder, uint32_t field_number) = 0;
+  virtual Status ProcessField(CallbackDecoder& decoder,
+                              uint32_t field_number) = 0;
 };
 
 }  // namespace pw::protobuf
diff --git a/pw_protobuf/public/pw_protobuf/find.h b/pw_protobuf/public/pw_protobuf/find.h
index 1e2f3e8..64ad94d 100644
--- a/pw_protobuf/public/pw_protobuf/find.h
+++ b/pw_protobuf/public/pw_protobuf/find.h
@@ -29,7 +29,7 @@
   constexpr FindDecodeHandler(uint32_t field_number, FindDecodeHandler* nested)
       : field_number_(field_number), found_(false), nested_handler_(nested) {}
 
-  Status ProcessField(Decoder* decoder, uint32_t field_number) override;
+  Status ProcessField(CallbackDecoder& decoder, uint32_t field_number) override;
 
   bool found() const { return found_; }
 
diff --git a/pw_protobuf/size_report/decoder_full.cc b/pw_protobuf/size_report/decoder_full.cc
index d879012..561335e 100644
--- a/pw_protobuf/size_report/decoder_full.cc
+++ b/pw_protobuf/size_report/decoder_full.cc
@@ -28,71 +28,38 @@
 // clang-format on
 }  // namespace
 
-class TestDecodeHandler : public pw::protobuf::DecodeHandler {
- public:
-  pw::Status ProcessField(pw::protobuf::Decoder* decoder,
-                          uint32_t field_number) override {
-    std::string_view str;
-
-    switch (field_number) {
-      case 1:
-        if (!decoder->ReadInt32(field_number, &test_int32).ok()) {
-          test_int32 = 0;
-        }
-        break;
-      case 2:
-        if (!decoder->ReadSint32(field_number, &test_sint32).ok()) {
-          test_sint32 = 0;
-        }
-        break;
-      case 3:
-        if (!decoder->ReadBool(field_number, &test_bool).ok()) {
-          test_bool = false;
-        }
-        break;
-      case 4:
-        if (!decoder->ReadDouble(field_number, &test_double).ok()) {
-          test_double = 0;
-        }
-        break;
-      case 5:
-        if (!decoder->ReadFixed32(field_number, &test_fixed32).ok()) {
-          test_fixed32 = 0;
-        }
-        break;
-      case 6:
-        if (decoder->ReadString(field_number, &str).ok()) {
-          // In real code:
-          // assert(str.size() < sizeof(test_string));
-          std::memcpy(test_string, str.data(), str.size());
-          test_string[str.size()] = '\0';
-        }
-        break;
-    }
-
-    return pw::Status::OK;
-  }
-
-  int32_t test_int32 = 0;
-  int32_t test_sint32 = 0;
-  bool test_bool = false;
-  double test_double = 0;
-  uint32_t test_fixed32 = 0;
-  char test_string[16];
-};
-
 int* volatile non_optimizable_pointer;
 
 int main() {
   pw::bloat::BloatThisBinary();
 
-  pw::protobuf::Decoder decoder;
-  TestDecodeHandler handler;
+  int32_t test_int32, test_sint32;
+  std::string_view str;
+  float f;
+  double d;
 
-  decoder.set_handler(&handler);
-  decoder.Decode(pw::as_bytes(pw::span(encoded_proto)));
+  pw::protobuf::Decoder decoder(pw::as_bytes(pw::span(encoded_proto)));
+  while (decoder.Next().ok()) {
+    switch (decoder.FieldNumber()) {
+      case 1:
+        decoder.ReadInt32(&test_int32);
+        break;
+      case 2:
+        decoder.ReadSint32(&test_sint32);
+        break;
+      case 3:
+        decoder.ReadString(&str);
+        break;
+      case 4:
+        decoder.ReadFloat(&f);
+        break;
+      case 5:
+        decoder.ReadDouble(&d);
+        break;
+    }
+  }
 
-  *non_optimizable_pointer = handler.test_int32 + handler.test_sint32;
+  *non_optimizable_pointer = test_int32 + test_sint32;
 
   return 0;
 }
diff --git a/pw_protobuf/size_report/decoder_incremental.cc b/pw_protobuf/size_report/decoder_incremental.cc
index ef6f2db..064f579 100644
--- a/pw_protobuf/size_report/decoder_incremental.cc
+++ b/pw_protobuf/size_report/decoder_incremental.cc
@@ -28,98 +28,50 @@
 // clang-format on
 }  // namespace
 
-class TestDecodeHandler : public pw::protobuf::DecodeHandler {
- public:
-  pw::Status ProcessField(pw::protobuf::Decoder* decoder,
-                          uint32_t field_number) override {
-    std::string_view str;
-
-    switch (field_number) {
-      case 1:
-        if (!decoder->ReadInt32(field_number, &test_int32).ok()) {
-          test_int32 = 0;
-        }
-        break;
-      case 2:
-        if (!decoder->ReadSint32(field_number, &test_sint32).ok()) {
-          test_sint32 = 0;
-        }
-        break;
-      case 3:
-        if (!decoder->ReadBool(field_number, &test_bool).ok()) {
-          test_bool = false;
-        }
-        break;
-      case 4:
-        if (!decoder->ReadDouble(field_number, &test_double).ok()) {
-          test_double = 0;
-        }
-        break;
-      case 5:
-        if (!decoder->ReadFixed32(field_number, &test_fixed32).ok()) {
-          test_fixed32 = 0;
-        }
-        break;
-      case 6:
-        if (decoder->ReadString(field_number, &str).ok()) {
-          // In real code:
-          // assert(str.size() < sizeof(test_string));
-          std::memcpy(test_string, str.data(), str.size());
-          test_string[str.size()] = '\0';
-        }
-        break;
-
-      // Extra fields.
-      case 21:
-        if (!decoder->ReadInt32(field_number, &test_int32).ok()) {
-          test_int32 = 0;
-        }
-        break;
-      case 22:
-        if (!decoder->ReadInt32(field_number, &test_int32).ok()) {
-          test_int32 = 0;
-        }
-        break;
-      case 23:
-        if (!decoder->ReadInt32(field_number, &test_int32).ok()) {
-          test_int32 = 0;
-        }
-        break;
-      case 24:
-        if (!decoder->ReadSint32(field_number, &test_sint32).ok()) {
-          test_sint32 = 0;
-        }
-        break;
-      case 25:
-        if (!decoder->ReadSint32(field_number, &test_sint32).ok()) {
-          test_sint32 = 0;
-        }
-        break;
-    }
-
-    return pw::Status::OK;
-  }
-
-  int32_t test_int32 = 0;
-  int32_t test_sint32 = 0;
-  bool test_bool = false;
-  double test_double = 0;
-  uint32_t test_fixed32 = 0;
-  char test_string[16];
-};
-
 int* volatile non_optimizable_pointer;
 
 int main() {
   pw::bloat::BloatThisBinary();
 
-  pw::protobuf::Decoder decoder;
-  TestDecodeHandler handler;
+  int32_t test_int32, test_sint32;
+  std::string_view str;
+  float f;
+  double d;
+  uint32_t uint;
 
-  decoder.set_handler(&handler);
-  decoder.Decode(pw::as_bytes(pw::span(encoded_proto)));
+  pw::protobuf::Decoder decoder(pw::as_bytes(pw::span(encoded_proto)));
+  while (decoder.Next().ok()) {
+    switch (decoder.FieldNumber()) {
+      case 1:
+        decoder.ReadInt32(&test_int32);
+        break;
+      case 2:
+        decoder.ReadSint32(&test_sint32);
+        break;
+      case 3:
+        decoder.ReadString(&str);
+        break;
+      case 4:
+        decoder.ReadFloat(&f);
+        break;
+      case 5:
+        decoder.ReadDouble(&d);
+        break;
 
-  *non_optimizable_pointer = handler.test_int32 + handler.test_sint32;
+      // Extra fields over decoder_full.
+      case 21:
+        decoder.ReadInt32(&test_int32);
+        break;
+      case 22:
+        decoder.ReadUint32(&uint);
+        break;
+      case 23:
+        decoder.ReadSint32(&test_sint32);
+        break;
+    }
+  }
+
+  *non_optimizable_pointer = test_int32 + test_sint32;
 
   return 0;
 }