encoding/textpb: add AllowPartial option to MarshalOptions and UnmarshalOptions

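Provide an AllowPartial option to accept messages with missing required
fields during marshaling and unmarshaling. When AllowPartial is false
(the default), Marshal and Unmarshal return an error if any required
fields are missing.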

Change-Id: Ia23783870a8125633f8ddc0b686984b4c5ca15ba
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/169500
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/encoding/textpb/decode.go b/encoding/textpb/decode.go
index bf0d790..b40b05e 100644
--- a/encoding/textpb/decode.go
+++ b/encoding/textpb/decode.go
@@ -18,15 +18,19 @@
 )
 
 // Unmarshal reads the given []byte into the given proto.Message.
-// TODO: may want to describe when Unmarshal returns error.
 func Unmarshal(m proto.Message, b []byte) error {
 	return UnmarshalOptions{}.Unmarshal(m, b)
 }
 
-// UnmarshalOptions is a configurable textproto format parser.
+// UnmarshalOptions is a configurable textproto format unmarshaler.
 type UnmarshalOptions struct {
 	pragma.NoUnkeyedLiterals
 
+	// AllowPartial accepts input for messages that will result in missing
+	// required fields. If AllowPartial is false (the default), Unmarshal will
+	// return error if there are any missing required fields.
+	AllowPartial bool
+
 	// Resolver is the registry used for type lookups when unmarshaling extensions
 	// and processing Any. If Resolver is not set, unmarshaling will default to
 	// using protoregistry.GlobalTypes.
@@ -161,19 +165,21 @@
 			if err := o.unmarshalSingular(tval, fd, knownFields); !nerr.Merge(err) {
 				return err
 			}
-			if cardinality == pref.Required {
+			if !o.AllowPartial && cardinality == pref.Required {
 				reqNums.Set(num)
 			}
 			seenNums.Set(num)
 		}
 	}
 
-	// Check for any missing required fields.
-	allReqNums := msgType.RequiredNumbers()
-	if reqNums.Len() != allReqNums.Len() {
-		for i := 0; i < allReqNums.Len(); i++ {
-			if num := allReqNums.Get(i); !reqNums.Has(uint64(num)) {
-				nerr.AppendRequiredNotSet(string(fieldDescs.ByNumber(num).FullName()))
+	if !o.AllowPartial {
+		// Check for any missing required fields.
+		allReqNums := msgType.RequiredNumbers()
+		if reqNums.Len() != allReqNums.Len() {
+			for i := 0; i < allReqNums.Len(); i++ {
+				if num := allReqNums.Get(i); !reqNums.Has(uint64(num)) {
+					nerr.AppendRequiredNotSet(string(fieldDescs.ByNumber(num).FullName()))
+				}
 			}
 		}
 	}
diff --git a/encoding/textpb/decode_test.go b/encoding/textpb/decode_test.go
index 44f4daf..2f4045c 100644
--- a/encoding/textpb/decode_test.go
+++ b/encoding/textpb/decode_test.go
@@ -952,18 +952,18 @@
 			},
 		},
 	}, {
-		desc:         "proto2 required fields not set",
+		desc:         "required fields not set",
 		inputMessage: &pb2.Requireds{},
 		wantErr:      true,
 	}, {
-		desc:         "proto2 required field set",
+		desc:         "required field set",
 		inputMessage: &pb2.PartialRequired{},
 		inputText:    "req_string: 'this is required'",
 		wantMessage: &pb2.PartialRequired{
 			ReqString: scalar.String("this is required"),
 		},
 	}, {
-		desc:         "proto2 required fields partially set",
+		desc:         "required fields partially set",
 		inputMessage: &pb2.Requireds{},
 		inputText: `
 req_bool: false
@@ -979,7 +979,23 @@
 		},
 		wantErr: true,
 	}, {
-		desc:         "proto2 required fields all set",
+		desc:         "required fields partially set with AllowPartial",
+		umo:          textpb.UnmarshalOptions{AllowPartial: true},
+		inputMessage: &pb2.Requireds{},
+		inputText: `
+req_bool: false
+req_sfixed64: 3203386110
+req_string: "hello"
+req_enum: ONE
+`,
+		wantMessage: &pb2.Requireds{
+			ReqBool:     scalar.Bool(false),
+			ReqSfixed64: scalar.Int64(0xbeefcafe),
+			ReqString:   scalar.String("hello"),
+			ReqEnum:     pb2.Enum_ONE.Enum(),
+		},
+	}, {
+		desc:         "required fields all set",
 		inputMessage: &pb2.Requireds{},
 		inputText: `
 req_bool: false
@@ -1006,6 +1022,14 @@
 		},
 		wantErr: true,
 	}, {
+		desc:         "indirect required field with AllowPartial",
+		umo:          textpb.UnmarshalOptions{AllowPartial: true},
+		inputMessage: &pb2.IndirectRequired{},
+		inputText:    "opt_nested: {}",
+		wantMessage: &pb2.IndirectRequired{
+			OptNested: &pb2.NestedWithRequired{},
+		},
+	}, {
 		desc:         "indirect required field in repeated",
 		inputMessage: &pb2.IndirectRequired{},
 		inputText: `
@@ -1013,9 +1037,6 @@
   req_string: "one"
 }
 rpt_nested: {}
-rpt_nested: {
-  req_string: "three"
-}
 `,
 		wantMessage: &pb2.IndirectRequired{
 			RptNested: []*pb2.NestedWithRequired{
@@ -1023,13 +1044,28 @@
 					ReqString: scalar.String("one"),
 				},
 				{},
-				{
-					ReqString: scalar.String("three"),
-				},
 			},
 		},
 		wantErr: true,
 	}, {
+		desc:         "indirect required field in repeated with AllowPartial",
+		umo:          textpb.UnmarshalOptions{AllowPartial: true},
+		inputMessage: &pb2.IndirectRequired{},
+		inputText: `
+rpt_nested: {
+  req_string: "one"
+}
+rpt_nested: {}
+`,
+		wantMessage: &pb2.IndirectRequired{
+			RptNested: []*pb2.NestedWithRequired{
+				{
+					ReqString: scalar.String("one"),
+				},
+				{},
+			},
+		},
+	}, {
 		desc:         "indirect required field in map",
 		inputMessage: &pb2.IndirectRequired{},
 		inputText: `
@@ -1053,6 +1089,29 @@
 		},
 		wantErr: true,
 	}, {
+		desc:         "indirect required field in map with AllowPartial",
+		umo:          textpb.UnmarshalOptions{AllowPartial: true},
+		inputMessage: &pb2.IndirectRequired{},
+		inputText: `
+str_to_nested: {
+  key: "missing"
+}
+str_to_nested: {
+  key: "contains"
+  value: {
+    req_string: "here"
+  }
+}
+`,
+		wantMessage: &pb2.IndirectRequired{
+			StrToNested: map[string]*pb2.NestedWithRequired{
+				"missing": &pb2.NestedWithRequired{},
+				"contains": &pb2.NestedWithRequired{
+					ReqString: scalar.String("here"),
+				},
+			},
+		},
+	}, {
 		desc:         "indirect required field in oneof",
 		inputMessage: &pb2.IndirectRequired{},
 		inputText: `oneof_nested: {}
@@ -1064,6 +1123,17 @@
 		},
 		wantErr: true,
 	}, {
+		desc:         "indirect required field in oneof with AllowPartial",
+		umo:          textpb.UnmarshalOptions{AllowPartial: true},
+		inputMessage: &pb2.IndirectRequired{},
+		inputText: `oneof_nested: {}
+`,
+		wantMessage: &pb2.IndirectRequired{
+			Union: &pb2.IndirectRequired_OneofNested{
+				OneofNested: &pb2.NestedWithRequired{},
+			},
+		},
+	}, {
 		desc:         "ignore reserved field",
 		inputMessage: &pb2.Nests{},
 		inputText:    "reserved_field: 'ignore this'",
diff --git a/encoding/textpb/encode.go b/encoding/textpb/encode.go
index 93eab31..d94f771 100644
--- a/encoding/textpb/encode.go
+++ b/encoding/textpb/encode.go
@@ -18,7 +18,6 @@
 )
 
 // Marshal writes the given proto.Message in textproto format using default options.
-// TODO: may want to describe when Marshal returns error.
 func Marshal(m proto.Message) ([]byte, error) {
 	return MarshalOptions{}.Marshal(m)
 }
@@ -27,6 +26,11 @@
 type MarshalOptions struct {
 	pragma.NoUnkeyedLiterals
 
+	// AllowPartial allows messages that have missing required fields to marshal
+	// without returning an error. If AllowPartial is false (the default),
+	// Marshal will return error if there are any missing required fields.
+	AllowPartial bool
+
 	// If Indent is a non-empty string, it causes entries for a Message to be
 	// preceded by the indent and trailed by a newline. Indent can only be
 	// composed of space or tab characters.
@@ -85,7 +89,7 @@
 		num := fd.Number()
 
 		if !knownFields.Has(num) {
-			if fd.Cardinality() == pref.Required {
+			if !o.AllowPartial && fd.Cardinality() == pref.Required {
 				// Treat unset required fields as a non-fatal error.
 				nerr.AppendRequiredNotSet(string(fd.FullName()))
 			}
diff --git a/encoding/textpb/encode_test.go b/encoding/textpb/encode_test.go
index 2654e05..d185fa6 100644
--- a/encoding/textpb/encode_test.go
+++ b/encoding/textpb/encode_test.go
@@ -683,12 +683,12 @@
 }
 `,
 	}, {
-		desc:    "proto2 required fields not set",
+		desc:    "required fields not set",
 		input:   &pb2.Requireds{},
 		want:    "\n",
 		wantErr: true,
 	}, {
-		desc: "proto2 required fields partially set",
+		desc: "required fields partially set",
 		input: &pb2.Requireds{
 			ReqBool:     scalar.Bool(false),
 			ReqSfixed64: scalar.Int64(0xbeefcafe),
@@ -704,7 +704,23 @@
 `,
 		wantErr: true,
 	}, {
-		desc: "proto2 required fields all set",
+		desc: "required fields not set with AllowPartial",
+		mo:   textpb.MarshalOptions{AllowPartial: true},
+		input: &pb2.Requireds{
+			ReqBool:     scalar.Bool(false),
+			ReqSfixed64: scalar.Int64(0xbeefcafe),
+			ReqDouble:   scalar.Float64(math.NaN()),
+			ReqString:   scalar.String("hello"),
+			ReqEnum:     pb2.Enum_ONE.Enum(),
+		},
+		want: `req_bool: false
+req_sfixed64: 3203386110
+req_double: nan
+req_string: "hello"
+req_enum: ONE
+`,
+	}, {
+		desc: "required fields all set",
 		input: &pb2.Requireds{
 			ReqBool:     scalar.Bool(false),
 			ReqSfixed64: scalar.Int64(0),
@@ -728,6 +744,13 @@
 		want:    "opt_nested: {}\n",
 		wantErr: true,
 	}, {
+		desc: "indirect required field with AllowPartial",
+		mo:   textpb.MarshalOptions{AllowPartial: true},
+		input: &pb2.IndirectRequired{
+			OptNested: &pb2.NestedWithRequired{},
+		},
+		want: "opt_nested: {}\n",
+	}, {
 		desc: "indirect required field in empty repeated",
 		input: &pb2.IndirectRequired{
 			RptNested: []*pb2.NestedWithRequired{},
@@ -743,6 +766,15 @@
 		want:    "rpt_nested: {}\n",
 		wantErr: true,
 	}, {
+		desc: "indirect required field in repeated with AllowPartial",
+		mo:   textpb.MarshalOptions{AllowPartial: true},
+		input: &pb2.IndirectRequired{
+			RptNested: []*pb2.NestedWithRequired{
+				&pb2.NestedWithRequired{},
+			},
+		},
+		want: "rpt_nested: {}\n",
+	}, {
 		desc: "indirect required field in empty map",
 		input: &pb2.IndirectRequired{
 			StrToNested: map[string]*pb2.NestedWithRequired{},
@@ -762,6 +794,19 @@
 `,
 		wantErr: true,
 	}, {
+		desc: "indirect required field in map with AllowPartial",
+		mo:   textpb.MarshalOptions{AllowPartial: true},
+		input: &pb2.IndirectRequired{
+			StrToNested: map[string]*pb2.NestedWithRequired{
+				"fail": &pb2.NestedWithRequired{},
+			},
+		},
+		want: `str_to_nested: {
+  key: "fail"
+  value: {}
+}
+`,
+	}, {
 		desc: "indirect required field in oneof",
 		input: &pb2.IndirectRequired{
 			Union: &pb2.IndirectRequired_OneofNested{
@@ -771,6 +816,15 @@
 		want:    "oneof_nested: {}\n",
 		wantErr: true,
 	}, {
+		desc: "indirect required field in oneof with AllowPartial",
+		mo:   textpb.MarshalOptions{AllowPartial: true},
+		input: &pb2.IndirectRequired{
+			Union: &pb2.IndirectRequired_OneofNested{
+				OneofNested: &pb2.NestedWithRequired{},
+			},
+		},
+		want: "oneof_nested: {}\n",
+	}, {
 		desc: "unknown varint and fixed types",
 		input: &pb2.Scalars{
 			OptString: scalar.String("this message contains unknown fields"),