internal/impl: refactor validation a bit

Have the validator return the number of bytes it consumed, so that
skipExtension can skip over groups without a separate wire.ConsumeGroup
parse.

Return an unmarshalOutput from the validator, since that type already
combines two of the validator's outputs: bytes read and initialization
status.

Remove initialization status from the ValidationStatus enum, since it is
now reported by unmarshalOutput.

Change-Id: I3e684c45d15aa1992d8dc3bde0f608880d34a94b
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/217763
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 290fc41..3155bc5 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -196,11 +196,9 @@
 	}
 	if flags.LazyUnmarshalExtensions {
 		if opts.IsDefault() && x.canLazy(xt) {
-			if n, ok := skipExtension(b, xi, num, wtyp, opts); ok {
-				x.appendLazyBytes(xt, xi, num, wtyp, b[:n])
+			if out, ok := skipExtension(b, xi, num, wtyp, opts); ok && out.initialized {
+				x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
 				exts[int32(num)] = x
-				out.n = n
-				out.initialized = true
 				return out, nil
 			}
 		}
@@ -224,35 +222,40 @@
 	return out, nil
 }
 
-func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (n int, ok bool) {
+// skipExtension validates the extension field at the start of b against
+// xi's validation info without unmarshaling it. ok is true only when xi
+// has a validator, wtyp matches the expected message or group encoding,
+// and validation succeeds; out then holds the number of bytes of b the
+// field occupies and whether the message is fully initialized.
+func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, ok bool) {
 	if xi.validation.mi == nil {
-		return 0, false
+		return out, false
 	}
 	xi.validation.mi.init()
-	var v []byte
 	switch xi.validation.typ {
 	case validationTypeMessage:
 		if wtyp != wire.BytesType {
-			return 0, false
+			return out, false
 		}
-		v, n = wire.ConsumeBytes(b)
+		v, n := wire.ConsumeBytes(b)
 		if n < 0 {
-			return 0, false
+			return out, false
 		}
+		out, st := xi.validation.mi.validate(v, 0, opts)
+		// validate only saw the payload v; report the full length
+		// consumed by ConsumeBytes, including the length prefix.
+		out.n = n
+		return out, st == ValidationValid
 	case validationTypeGroup:
 		if wtyp != wire.StartGroupType {
-			return 0, false
+			return out, false
 		}
-		v, n = wire.ConsumeGroup(num, b)
-		if n < 0 {
-			return 0, false
-		}
+		// Groups are not length-delimited, so validate b itself, passing the
+		// group's field number; the validator reports the bytes consumed in
+		// out.n, so no separate wire.ConsumeGroup pass is needed.
+		out, st := xi.validation.mi.validate(b, num, opts)
+		return out, st == ValidationValid
 	default:
-		return 0, false
+		return out, false
 	}
-	if xi.validation.mi.validate(v, 0, opts) != ValidationValidInitialized {
-		return 0, false
-	}
-	return n, true
-
 }