encoding: unify MessageSet extension handling logic

This CL unifies the MessageSet logic shared by prototext and protojson
into the internal/encoding/messageset package. While we are at it, also
restrict MessageSet support to builds with the proto1_legacy build flag
enabled; without the flag, prototext reports an error for any MessageSet.

Change-Id: I1a7d475e8bb1dad61ecd286df45e4239e5bef072
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185898
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/encoding/prototext/decode.go b/encoding/prototext/decode.go
index 7874410..4f384a0 100644
--- a/encoding/prototext/decode.go
+++ b/encoding/prototext/decode.go
@@ -9,9 +9,11 @@
 	"strings"
 	"unicode/utf8"
 
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/text"
 	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/internal/fieldnum"
+	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/internal/pragma"
 	"google.golang.org/protobuf/internal/set"
 	"google.golang.org/protobuf/proto"
@@ -74,17 +76,18 @@
 // unmarshalMessage unmarshals a [][2]text.Value message into the given protoreflect.Message.
 func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message) error {
 	messageDesc := m.Descriptor()
+	if !flags.Proto1Legacy && messageset.IsMessageSet(messageDesc) {
+		return errors.New("no support for proto1 MessageSets")
+	}
 
 	// Handle expanded Any message.
 	if messageDesc.FullName() == "google.protobuf.Any" && isExpandedAny(tmsg) {
 		return o.unmarshalAny(tmsg[0], m)
 	}
 
-	fieldDescs := messageDesc.Fields()
-	reservedNames := messageDesc.ReservedNames()
 	var seenNums set.Ints
 	var seenOneofs set.Ints
-
+	fieldDescs := messageDesc.Fields()
 	for _, tfield := range tmsg {
 		tkey := tfield[0]
 		tval := tfield[1]
@@ -128,7 +131,7 @@
 
 		if fd == nil {
 			// Ignore reserved names.
-			if reservedNames.Has(name) {
+			if messageDesc.ReservedNames().Has(name) {
 				continue
 			}
 			// TODO: Can provide option to ignore unknown message fields.
@@ -193,13 +196,7 @@
 	if err == nil {
 		return xt, nil
 	}
-
-	// Check if this is a MessageSet extension field.
-	xt, err = o.Resolver.FindExtensionByName(xtName + ".message_set_extension")
-	if err == nil && isMessageSetExtension(xt) {
-		return xt, nil
-	}
-	return nil, protoregistry.NotFound
+	return messageset.FindMessageSetExtension(o.Resolver, xtName)
 }
 
 // unmarshalSingular unmarshals given text.Value into the non-repeated field.
diff --git a/encoding/prototext/decode_test.go b/encoding/prototext/decode_test.go
index 31de4d7..20ce133 100644
--- a/encoding/prototext/decode_test.go
+++ b/encoding/prototext/decode_test.go
@@ -1310,6 +1310,7 @@
 			})
 			return m
 		}(),
+		skip: !flags.Proto1Legacy,
 	}, {
 		desc:         "not real MessageSet 1",
 		inputMessage: &pb2.FakeMessageSet{},
@@ -1325,6 +1326,7 @@
 			})
 			return m
 		}(),
+		skip: !flags.Proto1Legacy,
 	}, {
 		desc:         "not real MessageSet 2",
 		inputMessage: &pb2.FakeMessageSet{},
@@ -1334,6 +1336,7 @@
 }
 `,
 		wantErr: true,
+		skip:    !flags.Proto1Legacy,
 	}, {
 		desc:         "not real MessageSet 3",
 		inputMessage: &pb2.MessageSet{},
@@ -1348,6 +1351,7 @@
 			})
 			return m
 		}(),
+		skip: !flags.Proto1Legacy,
 	}, {
 		desc:         "Any not expanded",
 		inputMessage: &anypb.Any{},
diff --git a/encoding/prototext/encode.go b/encoding/prototext/encode.go
index 7d244a9..5d8cd46 100644
--- a/encoding/prototext/encode.go
+++ b/encoding/prototext/encode.go
@@ -9,10 +9,12 @@
 	"sort"
 	"unicode/utf8"
 
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/text"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/internal/fieldnum"
+	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/internal/mapsort"
 	"google.golang.org/protobuf/internal/pragma"
 	"google.golang.org/protobuf/proto"
@@ -72,8 +74,10 @@
 
 // marshalMessage converts a protoreflect.Message to a text.Value.
 func (o MarshalOptions) marshalMessage(m pref.Message) (text.Value, error) {
-	var msgFields [][2]text.Value
 	messageDesc := m.Descriptor()
+	if !flags.Proto1Legacy && messageset.IsMessageSet(messageDesc) {
+		return text.Value{}, errors.New("no support for proto1 MessageSets")
+	}
 
 	// Handle Any expansion.
 	if messageDesc.FullName() == "google.protobuf.Any" {
@@ -85,6 +89,7 @@
 	}
 
 	// Handle known fields.
+	var msgFields [][2]text.Value
 	fieldDescs := messageDesc.Fields()
 	size := fieldDescs.Len()
 	for i := 0; i < size; i++ {
@@ -253,10 +258,10 @@
 			return true
 		}
 
-		// If extended type is a MessageSet, set field name to be the message type name.
+		// For MessageSet extensions, the name used is the parent message.
 		name := fd.FullName()
-		if isMessageSetExtension(fd) {
-			name = fd.Message().FullName()
+		if messageset.IsMessageSetExtension(fd) {
+			name = name.Parent()
 		}
 
 		// Use string type to produce [name] format.
@@ -279,22 +284,6 @@
 	return append(msgFields, entries...), nil
 }
 
-// isMessageSetExtension reports whether extension extends a message set.
-func isMessageSetExtension(fd pref.FieldDescriptor) bool {
-	if fd.Name() != "message_set_extension" {
-		return false
-	}
-	md := fd.Message()
-	if md == nil {
-		return false
-	}
-	if fd.FullName().Parent() != md.FullName() {
-		return false
-	}
-	xmd, ok := fd.ContainingMessage().(interface{ IsMessageSet() bool })
-	return ok && xmd.IsMessageSet()
-}
-
 // appendUnknown parses the given []byte and appends field(s) into the given fields slice.
 // This function assumes proper encoding in the given []byte.
 func appendUnknown(fields [][2]text.Value, b []byte) [][2]text.Value {
diff --git a/encoding/prototext/encode_test.go b/encoding/prototext/encode_test.go
index 003d490..c29169b 100644
--- a/encoding/prototext/encode_test.go
+++ b/encoding/prototext/encode_test.go
@@ -12,6 +12,7 @@
 	"google.golang.org/protobuf/encoding/prototext"
 	"google.golang.org/protobuf/internal/detrand"
 	"google.golang.org/protobuf/internal/encoding/pack"
+	"google.golang.org/protobuf/internal/flags"
 	pimpl "google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/proto"
 	preg "google.golang.org/protobuf/reflect/protoregistry"
@@ -39,6 +40,7 @@
 		input   proto.Message
 		want    string
 		wantErr bool // TODO: Verify error message content.
+		skip    bool
 	}{{
 		desc:  "proto2 optional scalars not set",
 		input: &pb2.Scalars{},
@@ -1082,6 +1084,7 @@
   opt_string: "not a messageset extension"
 }
 `,
+		skip: !flags.Proto1Legacy,
 	}, {
 		desc: "not real MessageSet 1",
 		input: func() proto.Message {
@@ -1095,6 +1098,7 @@
   opt_string: "not a messageset extension"
 }
 `,
+		skip: !flags.Proto1Legacy,
 	}, {
 		desc: "not real MessageSet 2",
 		input: func() proto.Message {
@@ -1108,6 +1112,7 @@
   opt_string: "another not a messageset extension"
 }
 `,
+		skip: !flags.Proto1Legacy,
 	}, {
 		desc: "Any not expanded",
 		mo: prototext.MarshalOptions{
@@ -1201,6 +1206,9 @@
 
 	for _, tt := range tests {
 		tt := tt
+		if tt.skip {
+			continue
+		}
 		t.Run(tt.desc, func(t *testing.T) {
 			// Use 2-space indentation on all MarshalOptions.
 			tt.mo.Indent = "  "