proto: check for required fields in encoding/decoding

Change-Id: I0555a92e0399782f075b1dcd248e880dd48c7d6d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170579
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/encode.go b/proto/encode.go
index 973e53a..adf3de4 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -9,6 +9,7 @@
 	"sort"
 
 	"github.com/golang/protobuf/v2/internal/encoding/wire"
+	"github.com/golang/protobuf/v2/internal/errors"
 	"github.com/golang/protobuf/v2/internal/mapsort"
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
@@ -20,6 +21,11 @@
 // Example usage:
 //   b, err := MarshalOptions{Deterministic: true}.Marshal(m)
 type MarshalOptions struct {
+	// AllowPartial allows messages that have missing required fields to marshal
+	// without returning an error. If AllowPartial is false (the default),
+	// Marshal will return an error if there are any missing required fields.
+	AllowPartial bool
+
 	// Deterministic controls whether the same message will always be
 	// serialized to the same bytes within the same binary.
 	//
@@ -100,6 +106,7 @@
 	fields := m.Type().Fields()
 	knownFields := m.KnownFields()
 	var err error
+	var nerr errors.NonFatal
 	o.rangeKnown(knownFields, func(num protoreflect.FieldNumber, value protoreflect.Value) bool {
 		field := fields.ByNumber(num)
 		if field == nil {
@@ -109,17 +116,23 @@
 			}
 		}
 		b, err = o.marshalField(b, field, value)
-		return err == nil
+		if nerr.Merge(err) {
+			err = nil
+			return true
+		}
+		return false
 	})
 	if err != nil {
-		return nil, err
+		return b, err
 	}
 	m.UnknownFields().Range(func(_ protoreflect.FieldNumber, raw protoreflect.RawFields) bool {
 		b = append(b, raw...)
 		return true
 	})
-	// TODO: required field checks
-	return b, nil
+	if !o.AllowPartial {
+		checkRequiredFields(m, &nerr)
+	}
+	return b, nerr.E
 }
 
 // rangeKnown visits known fields in field number order when deterministic
@@ -163,6 +176,7 @@
 func (o MarshalOptions) marshalMap(b []byte, num wire.Number, kind protoreflect.Kind, mdesc protoreflect.MessageDescriptor, mapv protoreflect.Map) ([]byte, error) {
 	keyf := mdesc.Fields().ByNumber(1)
 	valf := mdesc.Fields().ByNumber(2)
+	var nerr errors.NonFatal
 	var err error
 	o.rangeMap(mapv, keyf.Kind(), func(key protoreflect.MapKey, value protoreflect.Value) bool {
 		b = wire.AppendTag(b, num, wire.BytesType)
@@ -170,21 +184,22 @@
 		b, pos = appendSpeculativeLength(b)
 
 		b, err = o.marshalField(b, keyf, key.Value())
-		if err != nil {
+		if !nerr.Merge(err) {
 			return false
 		}
 		b, err = o.marshalField(b, valf, value)
-		if err != nil {
+		if !nerr.Merge(err) {
 			return false
 		}
+		err = nil
 
 		b = finishSpeculativeLength(b, pos)
 		return true
 	})
 	if err != nil {
-		return nil, err
+		return b, err
 	}
-	return b, nil
+	return b, nerr.E
 }
 
 func (o MarshalOptions) rangeMap(mapv protoreflect.Map, kind protoreflect.Kind, f func(protoreflect.MapKey, protoreflect.Value) bool) {
@@ -198,27 +213,29 @@
 func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) {
 	b = wire.AppendTag(b, num, wire.BytesType)
 	b, pos := appendSpeculativeLength(b)
+	var nerr errors.NonFatal
 	for i, llen := 0, list.Len(); i < llen; i++ {
 		var err error
 		b, err = o.marshalSingular(b, num, kind, list.Get(i))
-		if err != nil {
-			return nil, err
+		if !nerr.Merge(err) {
+			return b, err
 		}
 	}
 	b = finishSpeculativeLength(b, pos)
-	return b, nil
+	return b, nerr.E
 }
 
 func (o MarshalOptions) marshalList(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) {
+	var nerr errors.NonFatal
 	for i, llen := 0, list.Len(); i < llen; i++ {
 		var err error
 		b = wire.AppendTag(b, num, wireTypes[kind])
 		b, err = o.marshalSingular(b, num, kind, list.Get(i))
-		if err != nil {
-			return nil, err
+		if !nerr.Merge(err) {
+			return b, err
 		}
 	}
-	return b, nil
+	return b, nerr.E
 }
 
 // When encoding length-prefixed fields, we speculatively set aside some number of bytes