encoding/textpb: unmarshal extensions

Change-Id: I4c82c5089371fa675871529c1d373f7ef28ee1df
Reviewed-on: https://go-review.googlesource.com/c/154937
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/encoding/textpb/decode.go b/encoding/textpb/decode.go
index 7ee73ee..27d7bf1 100644
--- a/encoding/textpb/decode.go
+++ b/encoding/textpb/decode.go
@@ -13,6 +13,7 @@
 	"github.com/golang/protobuf/v2/internal/set"
 	"github.com/golang/protobuf/v2/proto"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/reflect/protoregistry"
 )
 
 // Unmarshal reads the given []byte into the given proto.Message.
@@ -24,6 +25,11 @@
 // UnmarshalOptions is a configurable textproto format parser.
 type UnmarshalOptions struct {
 	pragma.NoUnkeyedLiterals
+
+	// Resolver is the registry used to look up extension types by full name
+	// while unmarshaling (Any expansion is not handled yet). If Resolver is
+	// nil, unmarshaling defaults to using protoregistry.GlobalTypes.
+	Resolver *protoregistry.Types
 }
 
 // Unmarshal reads the given []byte and populates the given proto.Message using options in
@@ -44,6 +50,9 @@
 		return err
 	}
 
+	if o.Resolver == nil {
+		o.Resolver = protoregistry.GlobalTypes
+	}
 	err = o.unmarshalMessage(val.Message(), mr)
 	if !nerr.Merge(err) {
 		return err
@@ -65,7 +74,6 @@
 		unknownFields.Set(num, nil)
 		return true
 	})
-
 	extTypes := knownFields.ExtensionTypes()
 	extTypes.Range(func(xt pref.ExtensionType) bool {
 		extTypes.Remove(xt)
@@ -81,6 +89,7 @@
 	fieldDescs := msgType.Fields()
 	reservedNames := msgType.ReservedNames()
 	knownFields := m.KnownFields()
+	xtTypes := knownFields.ExtensionTypes()
 	var reqNums set.Ints
 	var seenNums set.Ints
 
@@ -89,10 +98,32 @@
 		tval := tfield[1]
 
 		var fd pref.FieldDescriptor
-		name, ok := tkey.Name()
-		if ok {
+		var name pref.Name
+		switch tkey.Type() {
+		case text.Name:
+			name, _ = tkey.Name()
 			fd = fieldDescs.ByName(name)
+		case text.String:
+			// TODO: Handle Any expansions here as well.
+
+			// Handle extensions. Extensions have to be registered first in the message's
+			// ExtensionTypes before setting a value to it.
+			xtName := pref.FullName(tkey.String())
+			// Check first if it is already registered. This is the case for repeated fields.
+			xt := xtTypes.ByName(xtName)
+			if xt == nil {
+				var err error
+				xt, err = o.Resolver.FindExtensionByName(xtName)
+				if err != nil && err != protoregistry.NotFound {
+					return err
+				}
+				if xt != nil {
+					xtTypes.Register(xt)
+				}
+			}
+			fd = xt
 		}
+
 		if fd == nil {
 			// Ignore reserved names.
 			if reservedNames.Has(name) {
diff --git a/encoding/textpb/decode_test.go b/encoding/textpb/decode_test.go
index 8b6025a..00d8215 100644
--- a/encoding/textpb/decode_test.go
+++ b/encoding/textpb/decode_test.go
@@ -9,9 +9,12 @@
 	"testing"
 
 	protoV1 "github.com/golang/protobuf/proto"
+	"github.com/golang/protobuf/protoapi"
 	"github.com/golang/protobuf/v2/encoding/textpb"
+	"github.com/golang/protobuf/v2/internal/legacy"
 	"github.com/golang/protobuf/v2/internal/scalar"
 	"github.com/golang/protobuf/v2/proto"
+	preg "github.com/golang/protobuf/v2/reflect/protoregistry"
 
 	// The legacy package must be imported prior to use of any legacy messages.
 	// TODO: Remove this when protoV1 registers these hooks for you.
@@ -21,6 +24,28 @@
 	"github.com/golang/protobuf/v2/encoding/textpb/testprotos/pb3"
 )
 
+// init registers the pb2 test extensions with preg.GlobalTypes (via registerExtension).
+func init() {
+	for _, desc := range []*protoapi.ExtensionDesc{
+		pb2.E_OptExtBool, pb2.E_OptExtString, pb2.E_OptExtEnum, pb2.E_OptExtNested,
+		pb2.E_RptExtFixed32, pb2.E_RptExtEnum, pb2.E_RptExtNested,
+		pb2.E_ExtensionsContainer_OptExtBool,
+		pb2.E_ExtensionsContainer_OptExtString,
+		pb2.E_ExtensionsContainer_OptExtEnum,
+		pb2.E_ExtensionsContainer_OptExtNested,
+		pb2.E_ExtensionsContainer_RptExtString,
+		pb2.E_ExtensionsContainer_RptExtEnum,
+		pb2.E_ExtensionsContainer_RptExtNested,
+	} {
+		registerExtension(desc)
+	}
+}
+
+func registerExtension(desc *protoapi.ExtensionDesc) {
+	// Convert the legacy v1 descriptor to a v2 extension type, then add it to the global registry.
+	preg.GlobalTypes.Register(legacy.Export{}.ExtensionTypeFromDesc(desc))
+}
+
 func TestUnmarshal(t *testing.T) {
 	tests := []struct {
 		desc         string
@@ -968,6 +993,135 @@
 		inputMessage: &pb2.Nests{},
 		inputText:    "reserved_field: 'ignore this'",
 		wantMessage:  &pb2.Nests{},
+	}, {
+		desc:         "extensions of non-repeated fields",
+		inputMessage: &pb2.Extensions{},
+		inputText: `opt_string: "non-extension field"
+[pb2.opt_ext_bool]: true
+opt_bool: true
+[pb2.opt_ext_nested]: {
+  opt_string: "nested in an extension"
+  opt_nested: {
+    opt_string: "another nested in an extension"
+  }
+}
+[pb2.opt_ext_string]: "extension field"
+opt_int32: 42
+[pb2.opt_ext_enum]: TENTH
+`,
+		wantMessage: func() proto.Message {
+			m := &pb2.Extensions{
+				OptString: scalar.String("non-extension field"),
+				OptBool:   scalar.Bool(true),
+				OptInt32:  scalar.Int32(42),
+			}
+			setExtension(m, pb2.E_OptExtBool, true)
+			setExtension(m, pb2.E_OptExtString, "extension field")
+			setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TENTH)
+			setExtension(m, pb2.E_OptExtNested, &pb2.Nested{
+				OptString: scalar.String("nested in an extension"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("another nested in an extension"),
+				},
+			})
+			return m
+		}(),
+	}, {
+		desc:         "extensions of repeated fields",
+		inputMessage: &pb2.Extensions{},
+		inputText: `[pb2.rpt_ext_enum]: TENTH
+[pb2.rpt_ext_enum]: 101
+[pb2.rpt_ext_fixed32]: 42
+[pb2.rpt_ext_enum]: FIRST
+[pb2.rpt_ext_nested]: {
+  opt_string: "one"
+}
+[pb2.rpt_ext_nested]: {
+  opt_string: "two"
+}
+[pb2.rpt_ext_fixed32]: 47
+[pb2.rpt_ext_nested]: {
+  opt_string: "three"
+}
+`,
+		wantMessage: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TENTH, 101, pb2.Enum_FIRST})
+			setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
+			setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+				&pb2.Nested{OptString: scalar.String("one")},
+				&pb2.Nested{OptString: scalar.String("two")},
+				&pb2.Nested{OptString: scalar.String("three")},
+			})
+			return m
+		}(),
+	}, {
+		desc:         "extensions of non-repeated fields in another message",
+		inputMessage: &pb2.Extensions{},
+		inputText: `[pb2.ExtensionsContainer.opt_ext_bool]: true
+[pb2.ExtensionsContainer.opt_ext_enum]: TENTH
+[pb2.ExtensionsContainer.opt_ext_nested]: {
+  opt_string: "nested in an extension"
+  opt_nested: {
+    opt_string: "another nested in an extension"
+  }
+}
+[pb2.ExtensionsContainer.opt_ext_string]: "extension field"
+`,
+		wantMessage: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TENTH)
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
+				OptString: scalar.String("nested in an extension"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("another nested in an extension"),
+				},
+			})
+			return m
+		}(),
+	}, {
+		desc:         "extensions of repeated fields in another message",
+		inputMessage: &pb2.Extensions{},
+		inputText: `opt_string: "non-extension field"
+opt_bool: true
+opt_int32: 42
+[pb2.ExtensionsContainer.rpt_ext_nested]: {
+  opt_string: "one"
+}
+[pb2.ExtensionsContainer.rpt_ext_enum]: TENTH
+[pb2.ExtensionsContainer.rpt_ext_nested]: {
+  opt_string: "two"
+}
+[pb2.ExtensionsContainer.rpt_ext_enum]: 101
+[pb2.ExtensionsContainer.rpt_ext_string]: "hello"
+[pb2.ExtensionsContainer.rpt_ext_enum]: FIRST
+[pb2.ExtensionsContainer.rpt_ext_nested]: {
+  opt_string: "three"
+}
+[pb2.ExtensionsContainer.rpt_ext_string]: "world"
+`,
+		wantMessage: func() proto.Message {
+			m := &pb2.Extensions{
+				OptString: scalar.String("non-extension field"),
+				OptBool:   scalar.Bool(true),
+				OptInt32:  scalar.Int32(42),
+			}
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TENTH, 101, pb2.Enum_FIRST})
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+				&pb2.Nested{OptString: scalar.String("one")},
+				&pb2.Nested{OptString: scalar.String("two")},
+				&pb2.Nested{OptString: scalar.String("three")},
+			})
+			return m
+		}(),
+	}, {
+		desc:         "invalid extension field name",
+		inputMessage: &pb2.Extensions{},
+		inputText:    "[pb2.invalid_message_field]: true",
+		wantErr:      true,
 	}}
 
 	for _, tt := range tests {