proto: support message_set_wire_format

MessageSets are a deprecated proto1 feature, long since superseded by
extensions. Add disabled-by-default support behind flags.Proto1Legacy.

Change-Id: I7d3ace07f3b0efd59673034f3dc633b908345a88
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185538
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/proto/decode.go b/proto/decode.go
index 0b98366..f147e68 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -5,6 +5,7 @@
 package proto
 
 import (
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/internal/pragma"
@@ -68,8 +69,11 @@
 }
 
 func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
-	messageDesc := m.Descriptor()
-	fieldDescs := messageDesc.Fields()
+	md := m.Descriptor()
+	if messageset.IsMessageSet(md) {
+		return unmarshalMessageSet(b, m, o)
+	}
+	fields := md.Fields()
 	for len(b) > 0 {
 		// Parse the tag (field number and wire type).
 		num, wtyp, tagLen := wire.ConsumeTag(b)
@@ -78,9 +82,9 @@
 		}
 
 		// Parse the field value.
-		fd := fieldDescs.ByNumber(num)
-		if fd == nil && messageDesc.ExtensionRanges().Has(num) {
-			extType, err := o.Resolver.FindExtensionByNumber(messageDesc.FullName(), num)
+		fd := fields.ByNumber(num)
+		if fd == nil && md.ExtensionRanges().Has(num) {
+			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
 			if err != nil && err != protoregistry.NotFound {
 				return err
 			}
diff --git a/proto/encode.go b/proto/encode.go
index bc7da0c..5511d16 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -7,6 +7,7 @@
 import (
 	"sort"
 
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/mapsort"
 	"google.golang.org/protobuf/internal/pragma"
@@ -111,6 +112,9 @@
 }
 
 func (o MarshalOptions) marshalMessageSlow(b []byte, m protoreflect.Message) ([]byte, error) {
+	if messageset.IsMessageSet(m.Descriptor()) {
+		return marshalMessageSet(b, m, o)
+	}
 	// There are many choices for what order we visit fields in. The default one here
 	// is chosen for reasonable efficiency and simplicity given the protoreflect API.
 	// It is not deterministic, since Message.Range does not return fields in any
diff --git a/proto/messageset.go b/proto/messageset.go
new file mode 100644
index 0000000..1c6ac29
--- /dev/null
+++ b/proto/messageset.go
@@ -0,0 +1,130 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+	"google.golang.org/protobuf/internal/encoding/messageset"
+	"google.golang.org/protobuf/internal/encoding/wire"
+	"google.golang.org/protobuf/internal/errors"
+	"google.golang.org/protobuf/internal/flags"
+	"google.golang.org/protobuf/reflect/protoreflect"
+	"google.golang.org/protobuf/reflect/protoregistry"
+)
+
+// sizeMessageSet returns the size of m encoded in MessageSet wire format.
+// Each field visited by m.Range contributes messageset.SizeField for its
+// number plus a length-prefixed messageset.FieldMessage payload; m's raw
+// unknown fields are added at the end.
+//
+// Unlike marshalMessageSet, it does not check flags.Proto1Legacy.
+func sizeMessageSet(m protoreflect.Message) (size int) {
+	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+		size += messageset.SizeField(fd.Number())
+		size += wire.SizeTag(messageset.FieldMessage)
+		size += wire.SizeBytes(sizeMessage(v.Message()))
+		return true
+	})
+	size += len(m.GetUnknown())
+	return size
+}
+
+// marshalMessageSet appends m to b in MessageSet wire format.
+//
+// It returns an error unless flags.Proto1Legacy is set. Each field visited
+// by o.rangeFields is written as one item by marshalMessageSetField, and
+// m's unknown fields are appended verbatim at the end.
+func marshalMessageSet(b []byte, m protoreflect.Message, o MarshalOptions) ([]byte, error) {
+	if !flags.Proto1Legacy {
+		return b, errors.New("no support for message_set_wire_format")
+	}
+	var err error
+	o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+		b, err = marshalMessageSetField(b, fd, v, o)
+		return err == nil // stop at the first error
+	})
+	if err != nil {
+		return b, err
+	}
+	b = append(b, m.GetUnknown()...)
+	return b, nil
+}
+
+// marshalMessageSetField appends a single MessageSet item to b: the item
+// start carrying fd's field number, the message held in value as a
+// length-prefixed messageset.FieldMessage field, and the item end.
+func marshalMessageSetField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value, o MarshalOptions) ([]byte, error) {
+	b = messageset.AppendFieldStart(b, fd.Number())
+	b = wire.AppendTag(b, messageset.FieldMessage, wire.BytesType)
+	// The length prefix precedes the payload, so the message is sized first.
+	b = wire.AppendVarint(b, uint64(o.Size(value.Message().Interface())))
+	b, err := o.marshalMessage(b, value.Message())
+	if err != nil {
+		return b, err
+	}
+	b = messageset.AppendFieldEnd(b)
+	return b, nil
+}
+
+// unmarshalMessageSet parses b as a sequence of MessageSet items and adds
+// their contents to m.
+//
+// It returns an error unless flags.Proto1Legacy is set. An item whose
+// type_id lies in m's extension ranges and is found by o.Resolver is
+// unmarshaled into that extension's value obtained from m.Mutable, so
+// repeated items for the same extension are merged. Everything else
+// (fields that do not parse as items, type_ids outside the extension
+// ranges, and extensions the resolver does not know) is appended to m's
+// unknown fields. Input that cannot be skipped as a wire field, resolver
+// errors other than NotFound, and payload unmarshal errors are returned.
+func unmarshalMessageSet(b []byte, m protoreflect.Message, o UnmarshalOptions) error {
+	if !flags.Proto1Legacy {
+		return errors.New("no support for message_set_wire_format")
+	}
+	md := m.Descriptor()
+	for len(b) > 0 {
+		// The closure consumes one item, advancing b only on success.
+		// errUnknown diverts the current field to the unknown-field path below.
+		err := func() error {
+			num, v, n, err := messageset.ConsumeField(b)
+			if err != nil {
+				// Not a message set field.
+				//
+				// Return errUnknown to try to add this to the unknown fields.
+				// If the field is completely unparsable, we'll catch it
+				// when trying to skip the field.
+				return errUnknown
+			}
+			if !md.ExtensionRanges().Has(num) {
+				return errUnknown
+			}
+			fd, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
+			if err == protoregistry.NotFound {
+				return errUnknown
+			}
+			if err != nil {
+				return err
+			}
+			if err := o.unmarshalMessage(v, m.Mutable(fd).Message()); err != nil {
+				// Contents cannot be unmarshaled.
+				return err
+			}
+			b = b[n:]
+			return nil
+		}()
+		if err == errUnknown {
+			_, _, n := wire.ConsumeField(b)
+			if n < 0 {
+				return wire.ParseError(n)
+			}
+			m.SetUnknown(append(m.GetUnknown(), b[:n]...))
+			b = b[n:]
+			continue
+		}
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/proto/messageset_test.go b/proto/messageset_test.go
new file mode 100644
index 0000000..07b2d70
--- /dev/null
+++ b/proto/messageset_test.go
@@ -0,0 +1,196 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto_test
+
+import (
+	"google.golang.org/protobuf/internal/encoding/pack"
+	"google.golang.org/protobuf/internal/flags"
+	"google.golang.org/protobuf/proto"
+
+	messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb"
+	msetextpb "google.golang.org/protobuf/internal/testprotos/messageset/msetextpb"
+)
+
+// The MessageSet cases are only appended to testProtos when
+// flags.Proto1Legacy is set, since MessageSet support is gated on that flag.
+func init() {
+	if flags.Proto1Legacy {
+		testProtos = append(testProtos, messageSetTestProtos...)
+	}
+}
+
+// messageSetTestProtos pairs MessageSet wire data with the messages it is
+// expected to decode to. The cases cover type_id before and after the
+// payload, a trailing unknown field, unknown and missing type_ids, a missing
+// payload, and merging of repeated payloads and repeated items.
+var messageSetTestProtos = []testProto{
+	{
+		desc: "MessageSet type_id before message content",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+				Ext1Field1: proto.Int32(10),
+			}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet type_id after message content",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+				Ext1Field1: proto.Int32(10),
+			}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet preserves unknown field",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+				Ext1Field1: proto.Int32(10),
+			}),
+			unknown(pack.Message{
+				pack.Tag{4, pack.VarintType}, pack.Varint(30),
+			}.Marshal()),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+			// Unknown field
+			pack.Tag{4, pack.VarintType}, pack.Varint(30),
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet with unknown type_id",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			unknown(pack.Message{
+				pack.Tag{1, pack.StartGroupType},
+				pack.Tag{2, pack.VarintType}, pack.Varint(1002),
+				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{1, pack.VarintType}, pack.Varint(10),
+				}),
+				pack.Tag{1, pack.EndGroupType},
+			}.Marshal()),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1002),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet merges repeated message fields in item",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+				Ext1Field1: proto.Int32(10),
+				Ext1Field2: proto.Int32(20),
+			}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{2, pack.VarintType}, pack.Varint(20),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet merges message fields in repeated items",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+				Ext1Field1: proto.Int32(10),
+				Ext1Field2: proto.Int32(20),
+			}),
+			extend(msetextpb.E_Ext2_MessageSetExtension, &msetextpb.Ext2{
+				Ext2Field1: proto.Int32(30),
+			}),
+		)},
+		wire: pack.Message{
+			// Ext1, field1
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+			// Ext2, field1
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1001),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(30),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+			// Ext1, field2
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{2, pack.VarintType}, pack.Varint(20),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet with missing type_id",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			unknown(pack.Message{
+				pack.Tag{1, pack.StartGroupType},
+				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{1, pack.VarintType}, pack.Varint(10),
+				}),
+				pack.Tag{1, pack.EndGroupType},
+			}.Marshal()),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet with missing message",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+}
diff --git a/proto/size.go b/proto/size.go
index 1266143..9947580 100644
--- a/proto/size.go
+++ b/proto/size.go
@@ -5,6 +5,7 @@
 package proto
 
 import (
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/runtime/protoiface"
@@ -28,6 +29,9 @@
 }
 
 func sizeMessageSlow(m protoreflect.Message) (size int) {
+	if messageset.IsMessageSet(m.Descriptor()) {
+		return sizeMessageSet(m)
+	}
 	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
 		size += sizeField(fd, v)
 		return true