all: do best-effort initialization check on fast path unmarshal

Add a fast check for required fields to the fast-path unmarshal.
This is best-effort and will fail to detect that some messages are
initialized: messages with more than 64 required fields, messages
split across multiple tags, and possibly other cases.

In the cases where it works (which is most of them in practice),
this permits us to skip the IsInitialized check.

Change-Id: I6b70953a333033a5e64fb7ca37a59786cb0f75a0
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215878
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 5427317..fc93525 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -5,6 +5,8 @@
 package impl
 
 import (
+	"math/bits"
+
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/internal/flags"
@@ -58,7 +60,8 @@
 func (o unmarshalOptions) Resolver() preg.ExtensionTypeResolver { return o.resolver }
 
 type unmarshalOutput struct {
-	n int // number of bytes consumed
+	n           int  // number of bytes consumed
+	initialized bool // true only if the fast path verified the message is initialized; false means not verified
 }
 
 // unmarshal is protoreflect.Methods.Unmarshal.
@@ -69,8 +72,10 @@
 	} else {
 		p = m.(*messageReflectWrapper).pointer()
 	}
-	_, err := mi.unmarshalPointer(in.Buf, p, 0, newUnmarshalOptions(opts))
-	return piface.UnmarshalOutput{}, err
+	out, err := mi.unmarshalPointer(in.Buf, p, 0, newUnmarshalOptions(opts))
+	return piface.UnmarshalOutput{
+		Initialized: out.initialized,
+	}, err
 }
 
 // errUnknown is returned during unmarshaling to indicate a parse error that
@@ -86,6 +91,8 @@
 	if flags.ProtoLegacy && mi.isMessageSet {
 		return unmarshalMessageSet(mi, b, p, opts)
 	}
+	initialized := true
+	var requiredMask uint64
 	var exts *map[int32]ExtensionField
 	start := len(b)
 	for len(b) > 0 {
@@ -104,8 +111,8 @@
 			if num != groupTag {
 				return out, errors.New("mismatching end group marker")
 			}
-			out.n = start - len(b)
-			return out, nil
+			groupTag = 0
+			break
 		}
 
 		var f *coderFieldInfo
@@ -123,6 +130,12 @@
 			var o unmarshalOutput
 			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, opts)
 			n = o.n
+			if reqi := f.validation.requiredIndex; reqi > 0 && err == nil {
+				requiredMask |= 1 << (reqi - 1)
+			}
+			if f.funcs.isInit != nil && !o.initialized {
+				initialized = false
+			}
 		default:
 			// Possible extension.
 			if exts == nil && mi.extensionOffset.IsValid() {
@@ -137,6 +150,9 @@
 			var o unmarshalOutput
 			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
 			n = o.n
+			if !o.initialized {
+				initialized = false
+			}
 		}
 		if err != nil {
 			if err != errUnknown {
@@ -157,7 +173,13 @@
 	if groupTag != 0 {
 		return out, errors.New("missing end group marker")
 	}
-	out.n = start
+	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
+		initialized = false
+	}
+	if initialized {
+		out.initialized = true
+	}
+	out.n = start - len(b)
 	return out, nil
 }