goprotobuf: Make several changes to ErrRequiredNotSet:
- Report the full or partial path to the first missing required field (where possible) instead of the message name.
- Make it ignorable: Unmarshal and Marshal now continue to decode/encode the full proto and return ErrRequiredNotSet afterwards, so callers can choose to ignore it.
R=r
CC=golang-dev
https://codereview.appspot.com/13248047
diff --git a/proto/encode.go b/proto/encode.go
index 9d592cd..d49ab84 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -37,6 +37,7 @@
import (
"errors"
+ "fmt"
"reflect"
"sort"
)
@@ -46,12 +47,16 @@
// all been initialized. It is also the error returned if Unmarshal is
// called with an encoded protocol buffer that does not include all the
// required fields.
+//
+// When printed, ErrRequiredNotSet reports the first unset required field in a
+// message. If the field cannot be precisely determined, it is reported as
+// "{Unknown}".
type ErrRequiredNotSet struct {
- t reflect.Type
+ field string
}
func (e *ErrRequiredNotSet) Error() string {
- return "proto: required fields not set in " + e.t.String()
+ return fmt.Sprintf("proto: required field %q not set", e.field)
}
var (
@@ -175,7 +180,8 @@
}
p := NewBuffer(nil)
err := p.Marshal(pb)
- if err != nil {
+ var state errorState
+ if err != nil && !state.shouldContinue(err, nil) {
return nil, err
}
return p.buf, err
@@ -274,6 +280,7 @@
// Encode a message struct.
func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error {
+ var state errorState
structp := structPointer_GetStructPointer(base, p.field)
if structPointer_IsNil(structp) {
return ErrNil
@@ -283,7 +290,7 @@
if p.isMarshaler {
m := structPointer_Interface(structp, p.stype).(Marshaler)
data, err := m.Marshal()
- if err != nil {
+ if err != nil && !state.shouldContinue(err, nil) {
return err
}
o.buf = append(o.buf, p.tagcode...)
@@ -300,18 +307,19 @@
nbuf := o.buf
o.buf = obuf
- if err != nil {
+ if err != nil && !state.shouldContinue(err, nil) {
o.buffree(nbuf)
return err
}
o.buf = append(o.buf, p.tagcode...)
o.EncodeRawBytes(nbuf)
o.buffree(nbuf)
- return nil
+ return state.err
}
// Encode a group struct.
func (o *Buffer) enc_struct_group(p *Properties, base structPointer) error {
+ var state errorState
b := structPointer_GetStructPointer(base, p.field)
if structPointer_IsNil(b) {
return ErrNil
@@ -319,11 +327,11 @@
o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup))
err := o.enc_struct(p.stype, p.sprop, b)
- if err != nil {
+ if err != nil && !state.shouldContinue(err, nil) {
return err
}
o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup))
- return nil
+ return state.err
}
// Encode a slice of bools ([]bool).
@@ -470,6 +478,7 @@
// Encode a slice of message structs ([]*struct).
func (o *Buffer) enc_slice_struct_message(p *Properties, base structPointer) error {
+ var state errorState
s := structPointer_StructPointerSlice(base, p.field)
l := s.Len()
@@ -483,7 +492,7 @@
if p.isMarshaler {
m := structPointer_Interface(structp, p.stype).(Marshaler)
data, err := m.Marshal()
- if err != nil {
+ if err != nil && !state.shouldContinue(err, nil) {
return err
}
o.buf = append(o.buf, p.tagcode...)
@@ -498,7 +507,7 @@
nbuf := o.buf
o.buf = obuf
- if err != nil {
+ if err != nil && !state.shouldContinue(err, nil) {
o.buffree(nbuf)
if err == ErrNil {
return ErrRepeatedHasNil
@@ -510,11 +519,12 @@
o.buffree(nbuf)
}
- return nil
+ return state.err
}
// Encode a slice of group structs ([]*struct).
func (o *Buffer) enc_slice_struct_group(p *Properties, base structPointer) error {
+ var state errorState
s := structPointer_StructPointerSlice(base, p.field)
l := s.Len()
@@ -528,7 +538,7 @@
err := o.enc_struct(p.stype, p.sprop, b)
- if err != nil {
+ if err != nil && !state.shouldContinue(err, nil) {
if err == ErrNil {
return ErrRepeatedHasNil
}
@@ -537,7 +547,7 @@
o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup))
}
- return nil
+ return state.err
}
// Encode an extension map.
@@ -569,7 +579,7 @@
// Encode a struct.
func (o *Buffer) enc_struct(t reflect.Type, prop *StructProperties, base structPointer) error {
- required := prop.reqCount
+ var state errorState
// Encode fields in tag order so that decoders may use optimizations
// that depend on the ordering.
// http://code.google.com/apis/protocolbuffers/docs/encoding.html#order
@@ -577,19 +587,15 @@
p := prop.Prop[i]
if p.enc != nil {
err := p.enc(o, p, base)
- if err != nil {
+ if err != nil && !state.shouldContinue(err, p) {
if err != ErrNil {
return err
+ } else if p.Required && state.err == nil {
+ state.err = &ErrRequiredNotSet{p.Name}
}
- } else if p.Required {
- required--
}
}
}
- // See if we encoded all required fields.
- if required > 0 {
- return &ErrRequiredNotSet{t}
- }
// Add unrecognized fields at the end.
if prop.unrecField.IsValid() {
@@ -599,5 +605,33 @@
}
}
- return nil
+ return state.err
+}
+
+// errorState maintains the first error that occurs and updates that error
+// with additional context.
+type errorState struct {
+ err error
+}
+
+// shouldContinue reports whether encoding should continue upon encountering the
+// given error. If the error is a *ErrRequiredNotSet, shouldContinue returns true
+// and, if this is the first appearance of that error, remembers it for future
+// reporting. For any other error it returns false and remembers nothing.
+//
+// If prop is not nil, the remembered error is rebuilt with prop.Name and a "."
+// prepended to its field, so that it reports the path to the unset field.
+func (s *errorState) shouldContinue(err error, prop *Properties) bool {
+ // Ignore unset required fields.
+ reqNotSet, ok := err.(*ErrRequiredNotSet)
+ if !ok {
+ return false
+ }
+ if s.err == nil {
+ if prop != nil {
+ err = &ErrRequiredNotSet{prop.Name + "." + reqNotSet.field}
+ }
+ s.err = err
+ }
+ return true
}