proto: check for required fields in encoding/decoding
Change-Id: I0555a92e0399782f075b1dcd248e880dd48c7d6d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170579
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/decode.go b/proto/decode.go
index 4928ace..2b871c4 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -5,9 +5,8 @@
package proto
import (
- "errors"
-
"github.com/golang/protobuf/v2/internal/encoding/wire"
+ "github.com/golang/protobuf/v2/internal/errors"
"github.com/golang/protobuf/v2/internal/pragma"
"github.com/golang/protobuf/v2/reflect/protoreflect"
"github.com/golang/protobuf/v2/runtime/protoiface"
@@ -18,6 +17,11 @@
// Example usage:
// err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
+ // AllowPartial accepts input for messages that will result in missing
+ // required fields. If AllowPartial is false (the default), Unmarshal will
+ // return an error if there are any missing required fields.
+ AllowPartial bool
+
// If DiscardUnknown is set, unknown fields are ignored.
DiscardUnknown bool
@@ -60,6 +64,7 @@
fieldTypes := messageType.Fields()
knownFields := m.KnownFields()
unknownFields := m.UnknownFields()
+ var nerr errors.NonFatal
for len(b) > 0 {
// Parse the tag (field number and wire type).
num, wtyp, tagLen := wire.ConsumeTag(b)
@@ -90,13 +95,15 @@
return wire.ParseError(valLen)
}
unknownFields.Set(num, append(unknownFields.Get(num), b[:tagLen+valLen]...))
- } else if err != nil {
+ } else if !nerr.Merge(err) {
return err
}
b = b[tagLen+valLen:]
}
- // TODO: required field checks
- return nil
+ if !o.AllowPartial {
+ checkRequiredFields(m, &nerr)
+ }
+ return nerr.E
}
func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
@@ -118,14 +125,13 @@
m = knownFields.NewMessage(num)
knownFields.Set(num, protoreflect.ValueOf(m))
}
- if err := o.unmarshalMessage(v.Bytes(), m); err != nil {
- return 0, err
- }
+ // Pass up errors (fatal and otherwise).
+ err = o.unmarshalMessage(v.Bytes(), m)
default:
// Non-message scalars replace the previous value.
knownFields.Set(num, v)
}
- return n, nil
+ return n, err
}
func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
@@ -150,6 +156,7 @@
}
// Map entries are represented as a two-element message with fields
// containing the key and value.
+ var nerr errors.NonFatal
for len(b) > 0 {
num, wtyp, n := wire.ConsumeTag(b)
if n < 0 {
@@ -172,7 +179,7 @@
}
switch valField.Kind() {
case protoreflect.GroupKind, protoreflect.MessageKind:
- if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
+ if err := o.unmarshalMessage(v.Bytes(), val.Message()); !nerr.Merge(err) {
return 0, err
}
default:
@@ -185,7 +192,7 @@
if n < 0 {
return 0, wire.ParseError(n)
}
- } else if err != nil {
+ } else if !nerr.Merge(err) {
return 0, err
}
b = b[n:]
@@ -197,19 +204,18 @@
if !haveVal {
switch valField.Kind() {
case protoreflect.GroupKind, protoreflect.MessageKind:
- // Trigger required field checks by unmarshaling an empty message.
- if err := o.unmarshalMessage(nil, val.Message()); err != nil {
- return 0, err
+ if !o.AllowPartial {
+ checkRequiredFields(val.Message(), &nerr)
}
default:
val = valField.Default()
}
}
mapv.Set(key.MapKey(), val)
- return n, nil
+ return n, nerr.E
}
// errUnknown is used internally to indicate fields which should be added
// to the unknown field set of a message. It is never returned from an exported
// function.
-var errUnknown = errors.New("unknown")
+var errUnknown = errors.New("BUG: internal error (unknown)")