pw_protobuf: Reimplement decoder as field iterator

This changes the pw::protobuf::Decoder API to iterate over fields with
Next(), FieldNumber(), and the Read* functions instead of dispatching
through a virtual callback interface, making the core decoder simpler
and more flexible.

The virtual callback interface is kept but renamed to CallbackDecoder
and reimplemented in terms of the basic Decoder.

Change-Id: Idff321cd5e37184aa730251475c9e336136596d2
diff --git a/pw_protobuf/decoder.cc b/pw_protobuf/decoder.cc
index 54d4675..3183556 100644
--- a/pw_protobuf/decoder.cc
+++ b/pw_protobuf/decoder.cc
@@ -20,75 +20,178 @@
 
 namespace pw::protobuf {
 
-Status Decoder::Decode(span<const std::byte> proto) {
-  if (handler_ == nullptr || state_ != kReady) {
-    return Status::FAILED_PRECONDITION;
-  }
-
-  state_ = kDecodeInProgress;
-  proto_ = proto;
-
-  // Iterate over each field in the proto, calling the handler with the field
-  // key.
-  while (state_ == kDecodeInProgress && !proto_.empty()) {
-    const std::byte* original_cursor = proto_.data();
-
-    uint64_t key;
-    size_t bytes_read = varint::Decode(proto_, &key);
-    if (bytes_read == 0) {
-      state_ = kDecodeFailed;
-      return Status::DATA_LOSS;
-    }
-
-    uint32_t field_number = key >> kFieldNumberShift;
-    Status status = handler_->ProcessField(this, field_number);
-    if (!status.ok()) {
-      state_ = status == Status::CANCELLED ? kDecodeCancelled : kDecodeFailed;
+Status Decoder::Next() {
+  if (!previous_field_consumed_) {
+    if (Status status = SkipField(); !status.ok()) {
       return status;
     }
+  }
+  if (proto_.empty()) {
+    return Status::OUT_OF_RANGE;
+  }
+  previous_field_consumed_ = false;
+  return FieldSize() == 0 ? Status::DATA_LOSS : Status::OK;
+}
 
-    // The callback function can modify the decoder's state; check that
-    // everything is still okay.
-    if (state_ == kDecodeFailed) {
-      break;
-    }
-
-    // If the cursor has not moved, the user has not consumed the field in their
-    // callback. Skip ahead to the next field.
-    if (original_cursor == proto_.data()) {
-      SkipField();
-    }
+Status Decoder::SkipField() {
+  if (proto_.empty()) {
+    return Status::OUT_OF_RANGE;
   }
 
-  if (state_ != kDecodeInProgress) {
+  size_t bytes_to_skip = FieldSize();
+  if (bytes_to_skip == 0) {
     return Status::DATA_LOSS;
   }
 
-  state_ = kReady;
+  proto_ = proto_.subspan(bytes_to_skip);
+  return proto_.empty() ? Status::OUT_OF_RANGE : Status::OK;
+}
+
+uint32_t Decoder::FieldNumber() const {
+  uint64_t key;
+  varint::Decode(proto_, &key);
+  return key >> kFieldNumberShift;
+}
+
+Status Decoder::ReadUint32(uint32_t* out) {
+  uint64_t value = 0;
+  Status status = ReadUint64(&value);
+  if (!status.ok()) {
+    return status;
+  }
+  if (value > std::numeric_limits<uint32_t>::max()) {
+    return Status::OUT_OF_RANGE;
+  }
+  *out = value;
   return Status::OK;
 }
 
-Status Decoder::ReadVarint(uint32_t field_number, uint64_t* out) {
-  Status status = ConsumeKey(field_number, WireType::kVarint);
+Status Decoder::ReadSint32(int32_t* out) {
+  int64_t value = 0;
+  Status status = ReadSint64(&value);
   if (!status.ok()) {
     return status;
   }
+  if (value > std::numeric_limits<int32_t>::max()) {
+    return Status::OUT_OF_RANGE;
+  }
+  *out = value;
+  return Status::OK;
+}
+
+Status Decoder::ReadSint64(int64_t* out) {
+  uint64_t value = 0;
+  Status status = ReadUint64(&value);
+  if (!status.ok()) {
+    return status;
+  }
+  *out = varint::ZigZagDecode(value);
+  return Status::OK;
+}
+
+Status Decoder::ReadBool(bool* out) {
+  uint64_t value = 0;
+  Status status = ReadUint64(&value);
+  if (!status.ok()) {
+    return status;
+  }
+  *out = value;
+  return Status::OK;
+}
+
+Status Decoder::ReadString(std::string_view* out) {
+  span<const std::byte> bytes;
+  Status status = ReadDelimited(&bytes);
+  if (!status.ok()) {
+    return status;
+  }
+  *out = std::string_view(reinterpret_cast<const char*>(bytes.data()),
+                          bytes.size());
+  return Status::OK;
+}
+
+size_t Decoder::FieldSize() const {
+  uint64_t key;
+  size_t key_size = varint::Decode(proto_, &key);
+  if (key_size == 0) {
+    return 0;
+  }
+
+  span<const std::byte> remainder = proto_.subspan(key_size);
+  WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
+  uint64_t value = 0;
+  size_t expected_size = 0;
+
+  switch (wire_type) {
+    case WireType::kVarint:
+      expected_size = varint::Decode(remainder, &value);
+      if (expected_size == 0) {
+        return 0;
+      }
+      break;
+
+    case WireType::kDelimited:
+      // Varint at cursor indicates size of the field.
+      expected_size = varint::Decode(remainder, &value);
+      if (expected_size == 0) {
+        return 0;
+      }
+      expected_size += value;
+      break;
+
+    case WireType::kFixed32:
+      expected_size = sizeof(uint32_t);
+      break;
+
+    case WireType::kFixed64:
+      expected_size = sizeof(uint64_t);
+      break;
+  }
+
+  if (remainder.size() < expected_size) {
+    return 0;
+  }
+
+  return key_size + expected_size;
+}
+
+Status Decoder::ConsumeKey(WireType expected_type) {
+  uint64_t key;
+  size_t bytes_read = varint::Decode(proto_, &key);
+  if (bytes_read == 0) {
+    return Status::FAILED_PRECONDITION;
+  }
+
+  WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
+  if (wire_type != expected_type) {
+    return Status::FAILED_PRECONDITION;
+  }
+
+  // Advance past the key.
+  proto_ = proto_.subspan(bytes_read);
+  return Status::OK;
+}
+
+Status Decoder::ReadVarint(uint64_t* out) {
+  if (Status status = ConsumeKey(WireType::kVarint); !status.ok()) {
+    return status;
+  }
 
   size_t bytes_read = varint::Decode(proto_, out);
   if (bytes_read == 0) {
-    state_ = kDecodeFailed;
     return Status::DATA_LOSS;
   }
 
   // Advance to the next field.
   proto_ = proto_.subspan(bytes_read);
+  previous_field_consumed_ = true;
   return Status::OK;
 }
 
-Status Decoder::ReadFixed(uint32_t field_number, std::byte* out, size_t size) {
+Status Decoder::ReadFixed(std::byte* out, size_t size) {
   WireType expected_wire_type =
       size == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
-  Status status = ConsumeKey(field_number, expected_wire_type);
+  Status status = ConsumeKey(expected_wire_type);
   if (!status.ok()) {
     return status;
   }
@@ -99,13 +202,13 @@
 
   std::memcpy(out, proto_.data(), size);
   proto_ = proto_.subspan(size);
+  previous_field_consumed_ = true;
 
   return Status::OK;
 }
 
-Status Decoder::ReadDelimited(uint32_t field_number,
-                              span<const std::byte>* out) {
-  Status status = ConsumeKey(field_number, WireType::kDelimited);
+Status Decoder::ReadDelimited(span<const std::byte>* out) {
+  Status status = ConsumeKey(WireType::kDelimited);
   if (!status.ok()) {
     return status;
   }
@@ -113,80 +216,60 @@
   uint64_t length;
   size_t bytes_read = varint::Decode(proto_, &length);
   if (bytes_read == 0) {
-    state_ = kDecodeFailed;
     return Status::DATA_LOSS;
   }
 
   proto_ = proto_.subspan(bytes_read);
   if (proto_.size() < length) {
-    state_ = kDecodeFailed;
     return Status::DATA_LOSS;
   }
 
   *out = proto_.first(length);
   proto_ = proto_.subspan(length);
+  previous_field_consumed_ = true;
 
   return Status::OK;
 }
 
-Status Decoder::ConsumeKey(uint32_t field_number, WireType expected_type) {
+Status CallbackDecoder::Decode(span<const std::byte> proto) {
+  if (handler_ == nullptr || state_ != kReady) {
+    return Status::FAILED_PRECONDITION;
+  }
+
+  state_ = kDecodeInProgress;
+  decoder_.Reset(proto);
+
+  // Iterate the proto, calling the handler with each field number.
+  while (state_ == kDecodeInProgress) {
+    if (Status status = decoder_.Next(); !status.ok()) {
+      if (status == Status::OUT_OF_RANGE) {
+        // Reached the end of the proto.
+        break;
+      }
+
+      // Proto data is malformed.
+      return status;
+    }
+
+    Status status = handler_->ProcessField(*this, decoder_.FieldNumber());
+    if (!status.ok()) {
+      state_ = status == Status::CANCELLED ? kDecodeCancelled : kDecodeFailed;
+      return status;
+    }
+
+    // The callback function can modify the decoder's state; check that
+    // everything is still okay.
+    if (state_ == kDecodeFailed) {
+      break;
+    }
+  }
+
   if (state_ != kDecodeInProgress) {
-    return Status::FAILED_PRECONDITION;
+    return Status::DATA_LOSS;
   }
 
-  uint64_t key;
-  size_t bytes_read = varint::Decode(proto_, &key);
-  if (bytes_read == 0) {
-    state_ = kDecodeFailed;
-    return Status::FAILED_PRECONDITION;
-  }
-
-  uint32_t field = key >> kFieldNumberShift;
-  WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
-
-  if (field != field_number || wire_type != expected_type) {
-    state_ = kDecodeFailed;
-    return Status::FAILED_PRECONDITION;
-  }
-
-  // Advance past the key.
-  proto_ = proto_.subspan(bytes_read);
+  state_ = kReady;
   return Status::OK;
 }
 
-void Decoder::SkipField() {
-  uint64_t key;
-  proto_ = proto_.subspan(varint::Decode(proto_, &key));
-
-  WireType wire_type = static_cast<WireType>(key & kWireTypeMask);
-  size_t bytes_to_skip = 0;
-  uint64_t value = 0;
-
-  switch (wire_type) {
-    case WireType::kVarint:
-      bytes_to_skip = varint::Decode(proto_, &value);
-      break;
-
-    case WireType::kDelimited:
-      // Varint at cursor indicates size of the field.
-      bytes_to_skip += varint::Decode(proto_, &value);
-      bytes_to_skip += value;
-      break;
-
-    case WireType::kFixed32:
-      bytes_to_skip = sizeof(uint32_t);
-      break;
-
-    case WireType::kFixed64:
-      bytes_to_skip = sizeof(uint64_t);
-      break;
-  }
-
-  if (bytes_to_skip == 0) {
-    state_ = kDecodeFailed;
-  } else {
-    proto_ = proto_.subspan(bytes_to_skip);
-  }
-}
-
 }  // namespace pw::protobuf