proto: check for required fields in encoding/decoding
Change-Id: I0555a92e0399782f075b1dcd248e880dd48c7d6d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170579
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/encode.go b/proto/encode.go
index 973e53a..adf3de4 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -9,6 +9,7 @@
"sort"
"github.com/golang/protobuf/v2/internal/encoding/wire"
+ "github.com/golang/protobuf/v2/internal/errors"
"github.com/golang/protobuf/v2/internal/mapsort"
"github.com/golang/protobuf/v2/internal/pragma"
"github.com/golang/protobuf/v2/reflect/protoreflect"
@@ -20,6 +21,11 @@
// Example usage:
// b, err := MarshalOptions{Deterministic: true}.Marshal(m)
type MarshalOptions struct {
+ // AllowPartial allows messages that have missing required fields to marshal
+ // without returning an error. If AllowPartial is false (the default),
+ // Marshal will return an error if there are any missing required fields.
+ AllowPartial bool
+
// Deterministic controls whether the same message will always be
// serialized to the same bytes within the same binary.
//
@@ -100,6 +106,7 @@
fields := m.Type().Fields()
knownFields := m.KnownFields()
var err error
+ var nerr errors.NonFatal
o.rangeKnown(knownFields, func(num protoreflect.FieldNumber, value protoreflect.Value) bool {
field := fields.ByNumber(num)
if field == nil {
@@ -109,17 +116,23 @@
}
}
b, err = o.marshalField(b, field, value)
- return err == nil
+ if nerr.Merge(err) {
+ err = nil
+ return true
+ }
+ return false
})
if err != nil {
- return nil, err
+ return b, err
}
m.UnknownFields().Range(func(_ protoreflect.FieldNumber, raw protoreflect.RawFields) bool {
b = append(b, raw...)
return true
})
- // TODO: required field checks
- return b, nil
+ if !o.AllowPartial {
+ checkRequiredFields(m, &nerr)
+ }
+ return b, nerr.E
}
// rangeKnown visits known fields in field number order when deterministic
@@ -163,6 +176,7 @@
func (o MarshalOptions) marshalMap(b []byte, num wire.Number, kind protoreflect.Kind, mdesc protoreflect.MessageDescriptor, mapv protoreflect.Map) ([]byte, error) {
keyf := mdesc.Fields().ByNumber(1)
valf := mdesc.Fields().ByNumber(2)
+ var nerr errors.NonFatal
var err error
o.rangeMap(mapv, keyf.Kind(), func(key protoreflect.MapKey, value protoreflect.Value) bool {
b = wire.AppendTag(b, num, wire.BytesType)
@@ -170,21 +184,22 @@
b, pos = appendSpeculativeLength(b)
b, err = o.marshalField(b, keyf, key.Value())
- if err != nil {
+ if !nerr.Merge(err) {
return false
}
b, err = o.marshalField(b, valf, value)
- if err != nil {
+ if !nerr.Merge(err) {
return false
}
+ err = nil
b = finishSpeculativeLength(b, pos)
return true
})
if err != nil {
- return nil, err
+ return b, err
}
- return b, nil
+ return b, nerr.E
}
func (o MarshalOptions) rangeMap(mapv protoreflect.Map, kind protoreflect.Kind, f func(protoreflect.MapKey, protoreflect.Value) bool) {
@@ -198,27 +213,29 @@
func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) {
b = wire.AppendTag(b, num, wire.BytesType)
b, pos := appendSpeculativeLength(b)
+ var nerr errors.NonFatal
for i, llen := 0, list.Len(); i < llen; i++ {
var err error
b, err = o.marshalSingular(b, num, kind, list.Get(i))
- if err != nil {
- return nil, err
+ if !nerr.Merge(err) {
+ return b, err
}
}
b = finishSpeculativeLength(b, pos)
- return b, nil
+ return b, nerr.E
}
func (o MarshalOptions) marshalList(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) {
+ var nerr errors.NonFatal
for i, llen := 0, list.Len(); i < llen; i++ {
var err error
b = wire.AppendTag(b, num, wireTypes[kind])
b, err = o.marshalSingular(b, num, kind, list.Get(i))
- if err != nil {
- return nil, err
+ if !nerr.Merge(err) {
+ return b, err
}
}
- return b, nil
+ return b, nerr.E
}
// When encoding length-prefixed fields, we speculatively set aside some number of bytes