internal/impl: lazy extension decoding

Historically, extensions have been placed in the unknown fields section
of the unmarshaled message and decoded lazily on demand. The current
unmarshal implementation decodes extensions eagerly at unmarshal time,
which lets errors be reported immediately and unset required fields in
extension values be detected correctly.

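Add support for validated lazy extension decoding, where the extension
value is fully validated at initial unmarshal time but the unmarshaled
message is only created lazily. Lazy decoding applies only when
unmarshaling with default options and only to message or group
extensions whose encoded value validates as fully initialized; all
other extensions are still decoded eagerly.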

Make this behavior conditional on the protolegacy flag for now.

Change-Id: I9d742496a4bd4dafea83fca8619cd6e8d7e65bc3
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/216764
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index cbc21b3..48f7ca5 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -29,6 +29,12 @@
 
 func (o unmarshalOptions) DiscardUnknown() bool { return o.Flags&piface.UnmarshalDiscardUnknown != 0 }
 
+// IsDefault reports whether o's flags consist of exactly the UnmarshalDefaultResolver bit.
+func (o unmarshalOptions) IsDefault() bool {
+	// Any additional bit (for example UnmarshalDiscardUnknown) makes the options non-default.
+	return o.Flags == piface.UnmarshalDefaultResolver
+}
+
 type unmarshalOutput struct {
 	n           int // number of bytes consumed
 	initialized bool
@@ -185,6 +191,17 @@
 	if xi.funcs.unmarshal == nil {
 		return out, errUnknown
 	}
+	if flags.LazyUnmarshalExtensions {
+		if opts.IsDefault() && x.canLazy(xt) {
+			if n, ok := skipExtension(b, xi, num, wtyp, opts); ok {
+				x.appendLazyBytes(xt, xi, num, wtyp, b[:n])
+				exts[int32(num)] = x
+				out.n = n
+				out.initialized = true
+				return out, nil
+			}
+		}
+	}
 	ival := x.Value()
 	if !ival.IsValid() && xi.unmarshalNeedsValue {
 		// Create a new message, list, or map value to fill in.
@@ -200,3 +217,36 @@
 	exts[int32(num)] = x
 	return out, nil
 }
+
+// skipExtension checks whether the extension value at the start of b, which
+// has field number num and wire type wtyp, can be kept as raw bytes for lazy
+// decoding. It succeeds only for message and group extensions that have a
+// validator and whose encoded value validates as fully initialized; on success
+// n is the number of leading bytes of b that the encoded value occupies.
+func skipExtension(b []byte, xi *extensionFieldInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (n int, ok bool) {
+	mi := xi.validation.mi
+	if mi == nil {
+		// Without a validator the value cannot be checked up front.
+		return 0, false
+	}
+	mi.init()
+	// Extract the encoded message body. The wire type must match the
+	// extension's kind; anything else is rejected here.
+	var body []byte
+	switch {
+	case xi.validation.typ == validationTypeMessage && wtyp == wire.BytesType:
+		body, n = wire.ConsumeBytes(b)
+	case xi.validation.typ == validationTypeGroup && wtyp == wire.StartGroupType:
+		body, n = wire.ConsumeGroup(num, b)
+	default:
+		return 0, false
+	}
+	if n < 0 {
+		return 0, false
+	}
+
+	// Only values that are both valid and fully initialized qualify.
+	if mi.validate(body, 0, opts) != ValidationValidInitialized {
+		return 0, false
+	}
+	return n, true
+}