internal/impl: refactor validation a bit

Have the validator return the number of bytes it consumed, permitting us
to avoid an extra parse when skipping over groups.

Return an UnmarshalOutput from the validator, since that type already
holds two of the validator's outputs: the number of bytes read and the
initialization status.

Remove initialization status from the ValidationStatus enum, since it is
now reported through the UnmarshalOutput.

Change-Id: I3e684c45d15aa1992d8dc3bde0f608880d34a94b
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/217763
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index eab8ec0..0c32026 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -33,16 +33,8 @@
 	// ValidationInvalid indicates that unmarshaling the message will fail.
 	ValidationInvalid
 
-	// ValidationValidInitialized indicates that unmarshaling the message will succeed
-	// and IsInitialized on the result will report success.
-	ValidationValidInitialized
-
-	// ValidationValidMaybeUninitalized indicates unmarshaling the message will succeed,
-	// but the output of IsInitialized on the result is unknown.
-	//
-	// This status may be returned for an initialized message when a message value
-	// is split across multiple fields.
-	ValidationValidMaybeUninitalized
+	// ValidationValid indicates that unmarshaling the message will succeed.
+	ValidationValid
 )
 
 func (v ValidationStatus) String() string {
@@ -51,10 +43,8 @@
 		return "ValidationUnknown"
 	case ValidationInvalid:
 		return "ValidationInvalid"
-	case ValidationValidInitialized:
-		return "ValidationValidInitialized"
-	case ValidationValidMaybeUninitalized:
-		return "ValidationValidMaybeUninitalized"
+	case ValidationValid:
+		return "ValidationValid"
 	default:
 		return fmt.Sprintf("ValidationStatus(%d)", int(v))
 	}
@@ -64,12 +54,14 @@
 // of the message type.
 //
 // This function is exposed for testing.
-func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
+func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) (out piface.UnmarshalOutput, _ ValidationStatus) {
 	mi, ok := mt.(*MessageInfo)
 	if !ok {
-		return ValidationUnknown
+		return out, ValidationUnknown
 	}
-	return mi.validate(b, 0, unmarshalOptions(opts))
+	o, st := mi.validate(b, 0, unmarshalOptions(opts))
+	out.Initialized = o.initialized
+	return out, st
 }
 
 type validationInfo struct {
@@ -219,7 +211,7 @@
 	return vi
 }
 
-func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
+func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
 	mi.init()
 	type validationState struct {
 		typ              validationType
@@ -241,12 +233,13 @@
 		states[0].endGroup = groupTag
 	}
 	initialized := true
+	start := len(b)
 State:
 	for len(states) > 0 {
 		st := &states[len(states)-1]
 		if st.mi != nil {
 			if flags.ProtoLegacy && st.mi.isMessageSet {
-				return ValidationUnknown
+				return out, ValidationUnknown
 			}
 		}
 		for len(b) > 0 {
@@ -262,13 +255,13 @@
 				var n int
 				tag, n = wire.ConsumeVarint(b)
 				if n < 0 {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				b = b[n:]
 			}
 			var num wire.Number
 			if n := tag >> 3; n < uint64(wire.MinValidNumber) || n > uint64(wire.MaxValidNumber) {
-				return ValidationInvalid
+				return out, ValidationInvalid
 			} else {
 				num = wire.Number(n)
 			}
@@ -278,7 +271,7 @@
 				if st.endGroup == num {
 					goto PopState
 				}
-				return ValidationInvalid
+				return out, ValidationInvalid
 			}
 			var vi validationInfo
 			switch st.typ {
@@ -317,7 +310,7 @@
 						case preg.NotFound:
 							vi.typ = validationTypeBytes
 						default:
-							return ValidationUnknown
+							return out, ValidationUnknown
 						}
 					}
 					break
@@ -332,7 +325,7 @@
 				// determine if the resolver is frozen.
 				xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
 				if err != nil && err != preg.NotFound {
-					return ValidationUnknown
+					return out, ValidationUnknown
 				}
 				if err == nil {
 					vi = getExtensionFieldInfo(xt).validation
@@ -383,7 +376,7 @@
 					case b[9] < 0x80 && b[9] < 2:
 						b = b[10:]
 					default:
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				} else {
 					switch {
@@ -408,7 +401,7 @@
 					case len(b) > 9 && b[9] < 2:
 						b = b[10:]
 					default:
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				}
 				continue State
@@ -424,19 +417,19 @@
 					var n int
 					size, n = wire.ConsumeVarint(b)
 					if n < 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 					b = b[n:]
 				}
 				if size > uint64(len(b)) {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				v := b[:size]
 				b = b[size:]
 				switch vi.typ {
 				case validationTypeMessage:
 					if vi.mi == nil {
-						return ValidationUnknown
+						return out, ValidationUnknown
 					}
 					vi.mi.init()
 					fallthrough
@@ -455,40 +448,40 @@
 					for len(v) > 0 {
 						_, n := wire.ConsumeVarint(v)
 						if n < 0 {
-							return ValidationInvalid
+							return out, ValidationInvalid
 						}
 						v = v[n:]
 					}
 				case validationTypeRepeatedFixed32:
 					// Packed field.
 					if len(v)%4 != 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				case validationTypeRepeatedFixed64:
 					// Packed field.
 					if len(v)%8 != 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				case validationTypeUTF8String:
 					if !utf8.Valid(v) {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 				}
 			case wire.Fixed32Type:
 				if len(b) < 4 {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				b = b[4:]
 			case wire.Fixed64Type:
 				if len(b) < 8 {
-					return ValidationInvalid
+					return out, ValidationInvalid
 				}
 				b = b[8:]
 			case wire.StartGroupType:
 				switch vi.typ {
 				case validationTypeGroup:
 					if vi.mi == nil {
-						return ValidationUnknown
+						return out, ValidationUnknown
 					}
 					vi.mi.init()
 					states = append(states, validationState{
@@ -500,19 +493,19 @@
 				default:
 					n := wire.ConsumeFieldValue(num, wtyp, b)
 					if n < 0 {
-						return ValidationInvalid
+						return out, ValidationInvalid
 					}
 					b = b[n:]
 				}
 			default:
-				return ValidationInvalid
+				return out, ValidationInvalid
 			}
 		}
 		if st.endGroup != 0 {
-			return ValidationInvalid
+			return out, ValidationInvalid
 		}
 		if len(b) != 0 {
-			return ValidationInvalid
+			return out, ValidationInvalid
 		}
 		b = st.tail
 	PopState:
@@ -535,8 +528,9 @@
 		}
 		states = states[:len(states)-1]
 	}
-	if !initialized {
-		return ValidationValidMaybeUninitalized
+	out.n = start - len(b)
+	if initialized {
+		out.initialized = true
 	}
-	return ValidationValidInitialized
+	return out, ValidationValid
 }