proto: check for required fields in encoding/decoding

Change-Id: I0555a92e0399782f075b1dcd248e880dd48c7d6d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170579
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/decode.go b/proto/decode.go
index 4928ace..2b871c4 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -5,9 +5,8 @@
 package proto
 
 import (
-	"errors"
-
 	"github.com/golang/protobuf/v2/internal/encoding/wire"
+	"github.com/golang/protobuf/v2/internal/errors"
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
 	"github.com/golang/protobuf/v2/runtime/protoiface"
@@ -18,6 +17,11 @@
 // Example usage:
 //   err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
 type UnmarshalOptions struct {
+	// AllowPartial accepts input for messages that will result in missing
+	// required fields. If AllowPartial is false (the default), Unmarshal will
+	// return an error if there are any missing required fields.
+	AllowPartial bool
+
 	// If DiscardUnknown is set, unknown fields are ignored.
 	DiscardUnknown bool
 
@@ -60,6 +64,7 @@
 	fieldTypes := messageType.Fields()
 	knownFields := m.KnownFields()
 	unknownFields := m.UnknownFields()
+	var nerr errors.NonFatal
 	for len(b) > 0 {
 		// Parse the tag (field number and wire type).
 		num, wtyp, tagLen := wire.ConsumeTag(b)
@@ -90,13 +95,15 @@
 				return wire.ParseError(valLen)
 			}
 			unknownFields.Set(num, append(unknownFields.Get(num), b[:tagLen+valLen]...))
-		} else if err != nil {
+		} else if !nerr.Merge(err) {
 			return err
 		}
 		b = b[tagLen+valLen:]
 	}
-	// TODO: required field checks
-	return nil
+	if !o.AllowPartial {
+		checkRequiredFields(m, &nerr)
+	}
+	return nerr.E
 }
 
 func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
@@ -118,14 +125,13 @@
 			m = knownFields.NewMessage(num)
 			knownFields.Set(num, protoreflect.ValueOf(m))
 		}
-		if err := o.unmarshalMessage(v.Bytes(), m); err != nil {
-			return 0, err
-		}
+		// Pass up errors (fatal and otherwise).
+		err = o.unmarshalMessage(v.Bytes(), m)
 	default:
 		// Non-message scalars replace the previous value.
 		knownFields.Set(num, v)
 	}
-	return n, nil
+	return n, err
 }
 
 func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
@@ -150,6 +156,7 @@
 	}
 	// Map entries are represented as a two-element message with fields
 	// containing the key and value.
+	var nerr errors.NonFatal
 	for len(b) > 0 {
 		num, wtyp, n := wire.ConsumeTag(b)
 		if n < 0 {
@@ -172,7 +179,7 @@
 			}
 			switch valField.Kind() {
 			case protoreflect.GroupKind, protoreflect.MessageKind:
-				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
+				if err := o.unmarshalMessage(v.Bytes(), val.Message()); !nerr.Merge(err) {
 					return 0, err
 				}
 			default:
@@ -185,7 +192,7 @@
 			if n < 0 {
 				return 0, wire.ParseError(n)
 			}
-		} else if err != nil {
+		} else if !nerr.Merge(err) {
 			return 0, err
 		}
 		b = b[n:]
@@ -197,19 +204,18 @@
 	if !haveVal {
 		switch valField.Kind() {
 		case protoreflect.GroupKind, protoreflect.MessageKind:
-			// Trigger required field checks by unmarshaling an empty message.
-			if err := o.unmarshalMessage(nil, val.Message()); err != nil {
-				return 0, err
+			if !o.AllowPartial {
+				checkRequiredFields(val.Message(), &nerr)
 			}
 		default:
 			val = valField.Default()
 		}
 	}
 	mapv.Set(key.MapKey(), val)
-	return n, nil
+	return n, nerr.E
 }
 
 // errUnknown is used internally to indicate fields which should be added
 // to the unknown field set of a message. It is never returned from an exported
 // function.
-var errUnknown = errors.New("unknown")
+var errUnknown = errors.New("BUG: internal error (unknown)")