proto, internal/impl: store unknown MessageSet items in non-mset format

In the v1 implementation, unknown MessageSet items are stored in a
message's unknown fields section in non-MessageSet format. For example,
consider a MessageSet containing an item with type_id T and value V.
If the type_id is not resolvable, the item will be placed in the unknown
fields as a bytes-valued field with number T and contents V. This
conversion is then reversed when marshaling a MessageSet containing
unknown fields.

Preserve this behavior in v2.

One consequence of this change is that genuine unknown fields in a
MessageSet (any field number other than 1) are now discarded. This
matches the v1 behavior.

Change-Id: I3d913613f84e0ae82481078dbc91cb25628651cc
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/205697
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/proto/messageset.go b/proto/messageset.go
index 0d88097..e27e0b7 100644
--- a/proto/messageset.go
+++ b/proto/messageset.go
@@ -20,7 +20,7 @@
 		size += wire.SizeBytes(sizeMessage(v.Message()))
 		return true
 	})
-	size += len(m.GetUnknown())
+	size += messageset.SizeUnknown(m.GetUnknown())
 	return size
 }
 
@@ -36,8 +36,7 @@
 	if err != nil {
 		return b, err
 	}
-	b = append(b, m.GetUnknown()...)
-	return b, nil
+	return messageset.AppendUnknown(b, m.GetUnknown())
 }
 
 func marshalMessageSetField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value, o MarshalOptions) ([]byte, error) {
@@ -56,48 +55,34 @@
 	if !flags.ProtoLegacy {
 		return errors.New("no support for message_set_wire_format")
 	}
-	md := m.Descriptor()
-	for len(b) > 0 {
-		err := func() error {
-			num, v, n, err := messageset.ConsumeField(b)
-			if err != nil {
-				// Not a message set field.
-				//
-				// Return errUnknown to try to add this to the unknown fields.
-				// If the field is completely unparsable, we'll catch it
-				// when trying to skip the field.
-				return errUnknown
-			}
-			if !md.ExtensionRanges().Has(num) {
-				return errUnknown
-			}
-			xt, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
-			if err == protoregistry.NotFound {
-				return errUnknown
-			}
-			if err != nil {
-				return err
-			}
-			xd := xt.TypeDescriptor()
-			if err := o.unmarshalMessage(v, m.Mutable(xd).Message()); err != nil {
-				// Contents cannot be unmarshaled.
-				return err
-			}
-			b = b[n:]
-			return nil
-		}()
+	return messageset.Unmarshal(b, false, func(num wire.Number, v []byte) error {
+		err := unmarshalMessageSetField(m, num, v, o)
 		if err == errUnknown {
-			_, _, n := wire.ConsumeField(b)
-			if n < 0 {
-				return wire.ParseError(n)
-			}
-			m.SetUnknown(append(m.GetUnknown(), b[:n]...))
-			b = b[n:]
-			continue
+			unknown := m.GetUnknown()
+			unknown = wire.AppendTag(unknown, num, wire.BytesType)
+			unknown = wire.AppendBytes(unknown, v)
+			m.SetUnknown(unknown)
+			return nil
 		}
-		if err != nil {
-			return err
-		}
+		return err
+	})
+}
+
+func unmarshalMessageSetField(m protoreflect.Message, num wire.Number, v []byte, o UnmarshalOptions) error {
+	md := m.Descriptor()
+	if !md.ExtensionRanges().Has(num) {
+		return errUnknown
+	}
+	xt, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
+	if err == protoregistry.NotFound {
+		return errUnknown
+	}
+	if err != nil {
+		return err
+	}
+	xd := xt.TypeDescriptor()
+	if err := o.unmarshalMessage(v, m.Mutable(xd).Message()); err != nil {
+		return err
 	}
 	return nil
 }
diff --git a/proto/messageset_test.go b/proto/messageset_test.go
index c1ef6c9..b7c4c72 100644
--- a/proto/messageset_test.go
+++ b/proto/messageset_test.go
@@ -22,48 +22,51 @@
 var messageSetTestProtos = []testProto{
 	{
 		desc: "MessageSet type_id before message content",
-		decodeTo: []proto.Message{build(
-			&messagesetpb.MessageSet{},
-			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+		decodeTo: []proto.Message{func() proto.Message {
+			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
+			proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
 				Ext1Field1: proto.Int32(10),
-			}),
-		)},
+			})
+			return m
+		}()},
 		wire: pack.Message{
-			pack.Tag{1, pack.StartGroupType},
-			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
-			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
-				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.StartGroupType},
+				pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{1, pack.VarintType}, pack.Varint(10),
+				}),
+				pack.Tag{1, pack.EndGroupType},
 			}),
-			pack.Tag{1, pack.EndGroupType},
 		}.Marshal(),
 	},
 	{
 		desc: "MessageSet type_id after message content",
-		decodeTo: []proto.Message{build(
-			&messagesetpb.MessageSet{},
-			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+		decodeTo: []proto.Message{func() proto.Message {
+			m := &messagesetpb.MessageSetContainer{MessageSet: &messagesetpb.MessageSet{}}
+			proto.SetExtension(m.MessageSet, msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
 				Ext1Field1: proto.Int32(10),
-			}),
-		)},
+			})
+			return m
+		}()},
 		wire: pack.Message{
-			pack.Tag{1, pack.StartGroupType},
-			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
-				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.StartGroupType},
+				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{1, pack.VarintType}, pack.Varint(10),
+				}),
+				pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+				pack.Tag{1, pack.EndGroupType},
 			}),
-			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
-			pack.Tag{1, pack.EndGroupType},
 		}.Marshal(),
 	},
 	{
-		desc: "MessageSet preserves unknown field",
+		desc: "MessageSet does not preserve unknown field",
 		decodeTo: []proto.Message{build(
 			&messagesetpb.MessageSet{},
 			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
 				Ext1Field1: proto.Int32(10),
 			}),
-			unknown(pack.Message{
-				pack.Tag{4, pack.VarintType}, pack.Varint(30),
-			}.Marshal()),
 		)},
 		wire: pack.Message{
 			pack.Tag{1, pack.StartGroupType},
@@ -81,12 +84,9 @@
 		decodeTo: []proto.Message{build(
 			&messagesetpb.MessageSet{},
 			unknown(pack.Message{
-				pack.Tag{1, pack.StartGroupType},
-				pack.Tag{2, pack.VarintType}, pack.Varint(1002),
-				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1002, pack.BytesType}, pack.LengthPrefix(pack.Message{
 					pack.Tag{1, pack.VarintType}, pack.Varint(10),
 				}),
-				pack.Tag{1, pack.EndGroupType},
 			}.Marshal()),
 		)},
 		wire: pack.Message{
@@ -159,13 +159,6 @@
 		desc: "MessageSet with missing type_id",
 		decodeTo: []proto.Message{build(
 			&messagesetpb.MessageSet{},
-			unknown(pack.Message{
-				pack.Tag{1, pack.StartGroupType},
-				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
-					pack.Tag{1, pack.VarintType}, pack.Varint(10),
-				}),
-				pack.Tag{1, pack.EndGroupType},
-			}.Marshal()),
 		)},
 		wire: pack.Message{
 			pack.Tag{1, pack.StartGroupType},