encoding/textpb: add AllowPartial option to MarshalOptions and UnmarshalOptions
Provide an AllowPartial option that accepts messages with missing required
fields during marshaling and unmarshaling.
Change-Id: Ia23783870a8125633f8ddc0b686984b4c5ca15ba
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/169500
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/encoding/textpb/decode.go b/encoding/textpb/decode.go
index bf0d790..b40b05e 100644
--- a/encoding/textpb/decode.go
+++ b/encoding/textpb/decode.go
@@ -18,15 +18,19 @@
)
// Unmarshal reads the given []byte into the given proto.Message.
-// TODO: may want to describe when Unmarshal returns error.
func Unmarshal(m proto.Message, b []byte) error {
return UnmarshalOptions{}.Unmarshal(m, b)
}
-// UnmarshalOptions is a configurable textproto format parser.
+// UnmarshalOptions is a configurable textproto format unmarshaler.
type UnmarshalOptions struct {
pragma.NoUnkeyedLiterals
+ // AllowPartial accepts input for messages that will result in missing
+ // required fields. If AllowPartial is false (the default), Unmarshal will
+ // return an error if there are any missing required fields.
+ AllowPartial bool
+
// Resolver is the registry used for type lookups when unmarshaling extensions
// and processing Any. If Resolver is not set, unmarshaling will default to
// using protoregistry.GlobalTypes.
@@ -161,19 +165,21 @@
if err := o.unmarshalSingular(tval, fd, knownFields); !nerr.Merge(err) {
return err
}
- if cardinality == pref.Required {
+ if !o.AllowPartial && cardinality == pref.Required {
reqNums.Set(num)
}
seenNums.Set(num)
}
}
- // Check for any missing required fields.
- allReqNums := msgType.RequiredNumbers()
- if reqNums.Len() != allReqNums.Len() {
- for i := 0; i < allReqNums.Len(); i++ {
- if num := allReqNums.Get(i); !reqNums.Has(uint64(num)) {
- nerr.AppendRequiredNotSet(string(fieldDescs.ByNumber(num).FullName()))
+ if !o.AllowPartial {
+ // Check for any missing required fields.
+ allReqNums := msgType.RequiredNumbers()
+ if reqNums.Len() != allReqNums.Len() {
+ for i := 0; i < allReqNums.Len(); i++ {
+ if num := allReqNums.Get(i); !reqNums.Has(uint64(num)) {
+ nerr.AppendRequiredNotSet(string(fieldDescs.ByNumber(num).FullName()))
+ }
}
}
}
diff --git a/encoding/textpb/decode_test.go b/encoding/textpb/decode_test.go
index 44f4daf..2f4045c 100644
--- a/encoding/textpb/decode_test.go
+++ b/encoding/textpb/decode_test.go
@@ -952,18 +952,18 @@
},
},
}, {
- desc: "proto2 required fields not set",
+ desc: "required fields not set",
inputMessage: &pb2.Requireds{},
wantErr: true,
}, {
- desc: "proto2 required field set",
+ desc: "required field set",
inputMessage: &pb2.PartialRequired{},
inputText: "req_string: 'this is required'",
wantMessage: &pb2.PartialRequired{
ReqString: scalar.String("this is required"),
},
}, {
- desc: "proto2 required fields partially set",
+ desc: "required fields partially set",
inputMessage: &pb2.Requireds{},
inputText: `
req_bool: false
@@ -979,7 +979,23 @@
},
wantErr: true,
}, {
- desc: "proto2 required fields all set",
+ desc: "required fields partially set with AllowPartial",
+ umo: textpb.UnmarshalOptions{AllowPartial: true},
+ inputMessage: &pb2.Requireds{},
+ inputText: `
+req_bool: false
+req_sfixed64: 3203386110
+req_string: "hello"
+req_enum: ONE
+`,
+ wantMessage: &pb2.Requireds{
+ ReqBool: scalar.Bool(false),
+ ReqSfixed64: scalar.Int64(0xbeefcafe),
+ ReqString: scalar.String("hello"),
+ ReqEnum: pb2.Enum_ONE.Enum(),
+ },
+ }, {
+ desc: "required fields all set",
inputMessage: &pb2.Requireds{},
inputText: `
req_bool: false
@@ -1006,6 +1022,14 @@
},
wantErr: true,
}, {
+ desc: "indirect required field with AllowPartial",
+ umo: textpb.UnmarshalOptions{AllowPartial: true},
+ inputMessage: &pb2.IndirectRequired{},
+ inputText: "opt_nested: {}",
+ wantMessage: &pb2.IndirectRequired{
+ OptNested: &pb2.NestedWithRequired{},
+ },
+ }, {
desc: "indirect required field in repeated",
inputMessage: &pb2.IndirectRequired{},
inputText: `
@@ -1013,9 +1037,6 @@
req_string: "one"
}
rpt_nested: {}
-rpt_nested: {
- req_string: "three"
-}
`,
wantMessage: &pb2.IndirectRequired{
RptNested: []*pb2.NestedWithRequired{
@@ -1023,13 +1044,28 @@
ReqString: scalar.String("one"),
},
{},
- {
- ReqString: scalar.String("three"),
- },
},
},
wantErr: true,
}, {
+ desc: "indirect required field in repeated with AllowPartial",
+ umo: textpb.UnmarshalOptions{AllowPartial: true},
+ inputMessage: &pb2.IndirectRequired{},
+ inputText: `
+rpt_nested: {
+ req_string: "one"
+}
+rpt_nested: {}
+`,
+ wantMessage: &pb2.IndirectRequired{
+ RptNested: []*pb2.NestedWithRequired{
+ {
+ ReqString: scalar.String("one"),
+ },
+ {},
+ },
+ },
+ }, {
desc: "indirect required field in map",
inputMessage: &pb2.IndirectRequired{},
inputText: `
@@ -1053,6 +1089,29 @@
},
wantErr: true,
}, {
+ desc: "indirect required field in map with AllowPartial",
+ umo: textpb.UnmarshalOptions{AllowPartial: true},
+ inputMessage: &pb2.IndirectRequired{},
+ inputText: `
+str_to_nested: {
+ key: "missing"
+}
+str_to_nested: {
+ key: "contains"
+ value: {
+ req_string: "here"
+ }
+}
+`,
+ wantMessage: &pb2.IndirectRequired{
+ StrToNested: map[string]*pb2.NestedWithRequired{
+ "missing": &pb2.NestedWithRequired{},
+ "contains": &pb2.NestedWithRequired{
+ ReqString: scalar.String("here"),
+ },
+ },
+ },
+ }, {
desc: "indirect required field in oneof",
inputMessage: &pb2.IndirectRequired{},
inputText: `oneof_nested: {}
@@ -1064,6 +1123,17 @@
},
wantErr: true,
}, {
+ desc: "indirect required field in oneof with AllowPartial",
+ umo: textpb.UnmarshalOptions{AllowPartial: true},
+ inputMessage: &pb2.IndirectRequired{},
+ inputText: `oneof_nested: {}
+`,
+ wantMessage: &pb2.IndirectRequired{
+ Union: &pb2.IndirectRequired_OneofNested{
+ OneofNested: &pb2.NestedWithRequired{},
+ },
+ },
+ }, {
desc: "ignore reserved field",
inputMessage: &pb2.Nests{},
inputText: "reserved_field: 'ignore this'",
diff --git a/encoding/textpb/encode.go b/encoding/textpb/encode.go
index 93eab31..d94f771 100644
--- a/encoding/textpb/encode.go
+++ b/encoding/textpb/encode.go
@@ -18,7 +18,6 @@
)
// Marshal writes the given proto.Message in textproto format using default options.
-// TODO: may want to describe when Marshal returns error.
func Marshal(m proto.Message) ([]byte, error) {
return MarshalOptions{}.Marshal(m)
}
@@ -27,6 +26,11 @@
type MarshalOptions struct {
pragma.NoUnkeyedLiterals
+ // AllowPartial allows messages that have missing required fields to marshal
+ // without returning an error. If AllowPartial is false (the default),
+ // Marshal will return an error if there are any missing required fields.
+ AllowPartial bool
+
// If Indent is a non-empty string, it causes entries for a Message to be
// preceded by the indent and trailed by a newline. Indent can only be
// composed of space or tab characters.
@@ -85,7 +89,7 @@
num := fd.Number()
if !knownFields.Has(num) {
- if fd.Cardinality() == pref.Required {
+ if !o.AllowPartial && fd.Cardinality() == pref.Required {
// Treat unset required fields as a non-fatal error.
nerr.AppendRequiredNotSet(string(fd.FullName()))
}
diff --git a/encoding/textpb/encode_test.go b/encoding/textpb/encode_test.go
index 2654e05..d185fa6 100644
--- a/encoding/textpb/encode_test.go
+++ b/encoding/textpb/encode_test.go
@@ -683,12 +683,12 @@
}
`,
}, {
- desc: "proto2 required fields not set",
+ desc: "required fields not set",
input: &pb2.Requireds{},
want: "\n",
wantErr: true,
}, {
- desc: "proto2 required fields partially set",
+ desc: "required fields partially set",
input: &pb2.Requireds{
ReqBool: scalar.Bool(false),
ReqSfixed64: scalar.Int64(0xbeefcafe),
@@ -704,7 +704,23 @@
`,
wantErr: true,
}, {
- desc: "proto2 required fields all set",
+ desc: "required fields not set with AllowPartial",
+ mo: textpb.MarshalOptions{AllowPartial: true},
+ input: &pb2.Requireds{
+ ReqBool: scalar.Bool(false),
+ ReqSfixed64: scalar.Int64(0xbeefcafe),
+ ReqDouble: scalar.Float64(math.NaN()),
+ ReqString: scalar.String("hello"),
+ ReqEnum: pb2.Enum_ONE.Enum(),
+ },
+ want: `req_bool: false
+req_sfixed64: 3203386110
+req_double: nan
+req_string: "hello"
+req_enum: ONE
+`,
+ }, {
+ desc: "required fields all set",
input: &pb2.Requireds{
ReqBool: scalar.Bool(false),
ReqSfixed64: scalar.Int64(0),
@@ -728,6 +744,13 @@
want: "opt_nested: {}\n",
wantErr: true,
}, {
+ desc: "indirect required field with AllowPartial",
+ mo: textpb.MarshalOptions{AllowPartial: true},
+ input: &pb2.IndirectRequired{
+ OptNested: &pb2.NestedWithRequired{},
+ },
+ want: "opt_nested: {}\n",
+ }, {
desc: "indirect required field in empty repeated",
input: &pb2.IndirectRequired{
RptNested: []*pb2.NestedWithRequired{},
@@ -743,6 +766,15 @@
want: "rpt_nested: {}\n",
wantErr: true,
}, {
+ desc: "indirect required field in repeated with AllowPartial",
+ mo: textpb.MarshalOptions{AllowPartial: true},
+ input: &pb2.IndirectRequired{
+ RptNested: []*pb2.NestedWithRequired{
+ &pb2.NestedWithRequired{},
+ },
+ },
+ want: "rpt_nested: {}\n",
+ }, {
desc: "indirect required field in empty map",
input: &pb2.IndirectRequired{
StrToNested: map[string]*pb2.NestedWithRequired{},
@@ -762,6 +794,19 @@
`,
wantErr: true,
}, {
+ desc: "indirect required field in map with AllowPartial",
+ mo: textpb.MarshalOptions{AllowPartial: true},
+ input: &pb2.IndirectRequired{
+ StrToNested: map[string]*pb2.NestedWithRequired{
+ "fail": &pb2.NestedWithRequired{},
+ },
+ },
+ want: `str_to_nested: {
+ key: "fail"
+ value: {}
+}
+`,
+ }, {
desc: "indirect required field in oneof",
input: &pb2.IndirectRequired{
Union: &pb2.IndirectRequired_OneofNested{
@@ -771,6 +816,15 @@
want: "oneof_nested: {}\n",
wantErr: true,
}, {
+ desc: "indirect required field in oneof with AllowPartial",
+ mo: textpb.MarshalOptions{AllowPartial: true},
+ input: &pb2.IndirectRequired{
+ Union: &pb2.IndirectRequired_OneofNested{
+ OneofNested: &pb2.NestedWithRequired{},
+ },
+ },
+ want: "oneof_nested: {}\n",
+ }, {
desc: "unknown varint and fixed types",
input: &pb2.Scalars{
OptString: scalar.String("this message contains unknown fields"),