proto: support message_set_wire_format

MessageSets are a deprecated proto1 feature, long since superseded by
extensions. Add disabled-by-default support behind flags.Proto1Legacy.

Change-Id: I7d3ace07f3b0efd59673034f3dc633b908345a88
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185538
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go
new file mode 100644
index 0000000..a8b6c0d
--- /dev/null
+++ b/internal/encoding/messageset/messageset.go
@@ -0,0 +1,138 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+// Package messageset encodes and decodes the obsolete MessageSet wire format.
+package messageset
+
+import (
+	"google.golang.org/protobuf/internal/encoding/wire"
+	"google.golang.org/protobuf/internal/errors"
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+// The MessageSet wire format is equivalent to a message defiend as follows,
+// where each Item defines an extension field with a field number of 'type_id'
+// and content of 'message'. MessageSet extensions must be non-repeated message
+// fields.
+//
+//	message MessageSet {
+//		repeated group Item = 1 {
+//			required int32 type_id = 2;
+//			required string message = 3;
+//		}
+//	}
+const (
+	FieldItem    = wire.Number(1)
+	FieldTypeID  = wire.Number(2)
+	FieldMessage = wire.Number(3)
+)
+
+// IsMessageSet returns whether the message uses the MessageSet wire format.
+func IsMessageSet(md pref.MessageDescriptor) bool {
+	xmd, ok := md.(interface{ IsMessageSet() bool })
+	return ok && xmd.IsMessageSet()
+}
+
+// SizeField returns the size of a MessageSet item field containing an extension
+// with the given field number, not counting the contents of the message subfield.
+func SizeField(num wire.Number) int {
+	return 2*wire.SizeTag(FieldItem) + wire.SizeTag(FieldTypeID) + wire.SizeVarint(uint64(num))
+}
+
+// ConsumeField parses a MessageSet item field and returns the contents of the
+// type_id and message subfields and the total item length.
+func ConsumeField(b []byte) (typeid wire.Number, message []byte, n int, err error) {
+	num, wtyp, n := wire.ConsumeTag(b)
+	if n < 0 {
+		return 0, nil, 0, wire.ParseError(n)
+	}
+	if num != FieldItem || wtyp != wire.StartGroupType {
+		return 0, nil, 0, errors.New("invalid MessageSet field number")
+	}
+	typeid, message, fieldLen, err := ConsumeFieldValue(b[n:], false)
+	if err != nil {
+		return 0, nil, 0, err
+	}
+	return typeid, message, n + fieldLen, nil
+}
+
+// ConsumeFieldValue parses b as a MessageSet item field value until and including
+// the trailing end group marker. It assumes the start group tag has already been parsed.
+// It returns the contents of the type_id and message subfields and the total
+// item length.
+//
+// If wantLen is true, the returned message value includes the length prefix.
+// This is ugly, but simplifies the fast-path decoder in internal/impl.
+func ConsumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) {
+	ilen := len(b)
+	for {
+		num, wtyp, n := wire.ConsumeTag(b)
+		if n < 0 {
+			return 0, nil, 0, wire.ParseError(n)
+		}
+		b = b[n:]
+		switch {
+		case num == FieldItem && wtyp == wire.EndGroupType:
+			if wantLen && len(message) == 0 {
+				// The message field was missing, which should never happen.
+				// Be prepared for this case anyway.
+				message = wire.AppendVarint(message, 0)
+			}
+			return typeid, message, ilen - len(b), nil
+		case num == FieldTypeID && wtyp == wire.VarintType:
+			v, n := wire.ConsumeVarint(b)
+			if n < 0 {
+				return 0, nil, 0, wire.ParseError(n)
+			}
+			b = b[n:]
+			typeid = wire.Number(v)
+		case num == FieldMessage && wtyp == wire.BytesType:
+			m, n := wire.ConsumeBytes(b)
+			if n < 0 {
+				return 0, nil, 0, wire.ParseError(n)
+			}
+			if message == nil {
+				if wantLen {
+					message = b[:n]
+				} else {
+					message = m
+				}
+			} else {
+				// This case should never happen in practice, but handle it for
+				// correctness: The MessageSet item contains multiple message
+				// fields, which need to be merged.
+				//
+				// In the case where we're returning the length, this becomes
+				// quite inefficient since we need to strip the length off
+				// the existing data and reconstruct it with the combined length.
+				if wantLen {
+					_, nn := wire.ConsumeVarint(message)
+					m0 := message[nn:]
+					message = message[:0]
+					message = wire.AppendVarint(message, uint64(len(m0)+len(m)))
+					message = append(message, m0...)
+					message = append(message, m...)
+				} else {
+					message = append(message, m...)
+				}
+			}
+			b = b[n:]
+		}
+	}
+}
+
+// AppendFieldStart appends the start of a MessageSet item field containing
+// an extension with the given number. The caller must add the message
+// subfield (including the tag).
+func AppendFieldStart(b []byte, num wire.Number) []byte {
+	b = wire.AppendTag(b, FieldItem, wire.StartGroupType)
+	b = wire.AppendTag(b, FieldTypeID, wire.VarintType)
+	b = wire.AppendVarint(b, uint64(num))
+	return b
+}
+
+// AppendFieldEnd appends the trailing end group marker for a MessageSet item field.
+func AppendFieldEnd(b []byte) []byte {
+	return wire.AppendTag(b, FieldItem, wire.EndGroupType)
+}
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index 2b8b047..e43c812 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -5,9 +5,11 @@
 package impl
 
 import (
+	"fmt"
 	"reflect"
 	"sort"
 
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	piface "google.golang.org/protobuf/runtime/protoiface"
@@ -77,6 +79,21 @@
 		mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
 		mi.coderFields[cf.num] = cf
 	}
+	if messageset.IsMessageSet(mi.PBType.Descriptor()) {
+		if !mi.extensionOffset.IsValid() {
+			panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.PBType.FullName()))
+		}
+		cf := &coderFieldInfo{
+			num:       messageset.FieldItem,
+			offset:    si.extensionOffset,
+			isPointer: true,
+			funcs:     makeMessageSetFieldCoder(mi),
+		}
+		mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
+		mi.coderFields[cf.num] = cf
+		// Invalidate the extension offset, since the field codec handles extensions.
+		mi.extensionOffset = invalidOffset
+	}
 	sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
 		return mi.orderedCoderFields[i].num < mi.orderedCoderFields[j].num
 	})
diff --git a/internal/impl/codec_messageset.go b/internal/impl/codec_messageset.go
new file mode 100644
index 0000000..073c4dc
--- /dev/null
+++ b/internal/impl/codec_messageset.go
@@ -0,0 +1,119 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+	"sort"
+
+	"google.golang.org/protobuf/internal/encoding/messageset"
+	"google.golang.org/protobuf/internal/encoding/wire"
+	"google.golang.org/protobuf/internal/errors"
+	"google.golang.org/protobuf/internal/flags"
+)
+
+func makeMessageSetFieldCoder(mi *MessageInfo) pointerCoderFuncs {
+	return pointerCoderFuncs{
+		size: func(p pointer, tagsize int, opts marshalOptions) int {
+			return sizeMessageSet(mi, p, tagsize, opts)
+		},
+		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
+			return marshalMessageSet(mi, b, p, wiretag, opts)
+		},
+		unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+			return unmarshalMessageSet(mi, b, p, wtyp, opts)
+		},
+	}
+}
+
+func sizeMessageSet(mi *MessageInfo, p pointer, tagsize int, opts marshalOptions) (n int) {
+	ext := *p.Extensions()
+	if ext == nil {
+		return 0
+	}
+	for _, x := range ext {
+		xi := mi.extensionFieldInfo(x.GetType())
+		if xi.funcs.size == nil {
+			continue
+		}
+		num, _ := wire.DecodeTag(xi.wiretag)
+		n += messageset.SizeField(num)
+		n += xi.funcs.size(x.GetValue(), wire.SizeTag(messageset.FieldMessage), opts)
+	}
+	return n
+}
+
+func marshalMessageSet(mi *MessageInfo, b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
+	if !flags.Proto1Legacy {
+		return b, errors.New("no support for message_set_wire_format")
+	}
+	ext := *p.Extensions()
+	if ext == nil {
+		return b, nil
+	}
+	switch len(ext) {
+	case 0:
+		return b, nil
+	case 1:
+		// Fast-path for one extension: Don't bother sorting the keys.
+		for _, x := range ext {
+			var err error
+			b, err = marshalMessageSetField(mi, b, x, opts)
+			if err != nil {
+				return b, err
+			}
+		}
+		return b, nil
+	default:
+		// Sort the keys to provide a deterministic encoding.
+		// Not sure this is required, but the old code does it.
+		keys := make([]int, 0, len(ext))
+		for k := range ext {
+			keys = append(keys, int(k))
+		}
+		sort.Ints(keys)
+		for _, k := range keys {
+			var err error
+			b, err = marshalMessageSetField(mi, b, ext[int32(k)], opts)
+			if err != nil {
+				return b, err
+			}
+		}
+		return b, nil
+	}
+}
+
+func marshalMessageSetField(mi *MessageInfo, b []byte, x ExtensionField, opts marshalOptions) ([]byte, error) {
+	xi := mi.extensionFieldInfo(x.GetType())
+	num, _ := wire.DecodeTag(xi.wiretag)
+	b = messageset.AppendFieldStart(b, num)
+	b, err := xi.funcs.marshal(b, x.GetValue(), wire.EncodeTag(messageset.FieldMessage, wire.BytesType), opts)
+	if err != nil {
+		return b, err
+	}
+	b = messageset.AppendFieldEnd(b)
+	return b, nil
+}
+
+func unmarshalMessageSet(mi *MessageInfo, b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+	if !flags.Proto1Legacy {
+		return 0, errors.New("no support for message_set_wire_format")
+	}
+	if wtyp != wire.StartGroupType {
+		return 0, errUnknown
+	}
+	ep := p.Extensions()
+	if *ep == nil {
+		*ep = make(map[int32]ExtensionField)
+	}
+	ext := *ep
+	num, v, n, err := messageset.ConsumeFieldValue(b, true)
+	if err != nil {
+		return 0, err
+	}
+	if _, err := mi.unmarshalExtension(v, num, wire.BytesType, ext, opts); err != nil {
+		return 0, err
+	}
+	return n, nil
+}
diff --git a/internal/testprotos/messageset/messagesetpb/message_set.pb.go b/internal/testprotos/messageset/messagesetpb/message_set.pb.go
new file mode 100644
index 0000000..aaa14da
--- /dev/null
+++ b/internal/testprotos/messageset/messagesetpb/message_set.pb.go
@@ -0,0 +1,140 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: messageset/messagesetpb/message_set.proto
+
+package messagesetpb
+
+import (
+	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+	protoiface "google.golang.org/protobuf/runtime/protoiface"
+	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+	sync "sync"
+)
+
+const (
+	// Verify that runtime/protoimpl is sufficiently up-to-date.
+	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 0)
+	// Verify that this generated code is sufficiently up-to-date.
+	_ = protoimpl.EnforceVersion(0 - protoimpl.MinVersion)
+)
+
+type MessageSet struct {
+	state           protoimpl.MessageState
+	sizeCache       protoimpl.SizeCache
+	unknownFields   protoimpl.UnknownFields
+	extensionFields protoimpl.ExtensionFields
+}
+
+func (x *MessageSet) Reset() {
+	*x = MessageSet{}
+}
+
+func (x *MessageSet) String() string {
+	return protoimpl.X.MessageStringOf(x)
+}
+
+func (*MessageSet) ProtoMessage() {}
+
+func (x *MessageSet) ProtoReflect() protoreflect.Message {
+	mi := &file_messageset_messagesetpb_message_set_proto_msgTypes[0]
+	if protoimpl.UnsafeEnabled && x != nil {
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		if ms.LoadMessageInfo() == nil {
+			ms.StoreMessageInfo(mi)
+		}
+		return ms
+	}
+	return mi.MessageOf(x)
+}
+
+// Deprecated: Use MessageSet.ProtoReflect.Type instead.
+func (*MessageSet) Descriptor() ([]byte, []int) {
+	return file_messageset_messagesetpb_message_set_proto_rawDescGZIP(), []int{0}
+}
+
+var extRange_MessageSet = []protoiface.ExtensionRangeV1{
+	{Start: 4, End: 2147483646},
+}
+
+// Deprecated: Use MessageSet.ProtoReflect.Type.ExtensionRanges instead.
+func (*MessageSet) ExtensionRangeArray() []protoiface.ExtensionRangeV1 {
+	return extRange_MessageSet
+}
+
+var File_messageset_messagesetpb_message_set_proto protoreflect.FileDescriptor
+
+var file_messageset_messagesetpb_message_set_proto_rawDesc = []byte{
+	0x0a, 0x29, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f, 0x6d, 0x65, 0x73,
+	0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x70, 0x62, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
+	0x65, 0x5f, 0x73, 0x65, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x18, 0x67, 0x6f, 0x70,
+	0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61,
+	0x67, 0x65, 0x73, 0x65, 0x74, 0x22, 0x1a, 0x0a, 0x0a, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
+	0x53, 0x65, 0x74, 0x2a, 0x08, 0x08, 0x04, 0x10, 0xff, 0xff, 0xff, 0xff, 0x07, 0x3a, 0x02, 0x08,
+	0x01, 0x42, 0x48, 0x5a, 0x46, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61,
+	0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f,
+	0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f,
+	0x74, 0x6f, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f, 0x6d,
+	0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x70, 0x62,
+}
+
+var (
+	file_messageset_messagesetpb_message_set_proto_rawDescOnce sync.Once
+	file_messageset_messagesetpb_message_set_proto_rawDescData = file_messageset_messagesetpb_message_set_proto_rawDesc
+)
+
+func file_messageset_messagesetpb_message_set_proto_rawDescGZIP() []byte {
+	file_messageset_messagesetpb_message_set_proto_rawDescOnce.Do(func() {
+		file_messageset_messagesetpb_message_set_proto_rawDescData = protoimpl.X.CompressGZIP(file_messageset_messagesetpb_message_set_proto_rawDescData)
+	})
+	return file_messageset_messagesetpb_message_set_proto_rawDescData
+}
+
+var file_messageset_messagesetpb_message_set_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
+var file_messageset_messagesetpb_message_set_proto_goTypes = []interface{}{
+	(*MessageSet)(nil), // 0: goproto.proto.messageset.MessageSet
+}
+var file_messageset_messagesetpb_message_set_proto_depIdxs = []int32{
+	0, // starting offset of method output_type sub-list
+	0, // starting offset of method input_type sub-list
+	0, // starting offset of extension type_name sub-list
+	0, // starting offset of extension extendee sub-list
+	0, // starting offset of field type_name sub-list
+}
+
+func init() { file_messageset_messagesetpb_message_set_proto_init() }
+func file_messageset_messagesetpb_message_set_proto_init() {
+	if File_messageset_messagesetpb_message_set_proto != nil {
+		return
+	}
+	if !protoimpl.UnsafeEnabled {
+		file_messageset_messagesetpb_message_set_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+			switch v := v.(*MessageSet); i {
+			case 0:
+				return &v.state
+			case 1:
+				return &v.sizeCache
+			case 2:
+				return &v.unknownFields
+			case 3:
+				return &v.extensionFields
+			default:
+				return nil
+			}
+		}
+	}
+	out := protoimpl.TypeBuilder{
+		File: protoimpl.DescBuilder{
+			RawDescriptor: file_messageset_messagesetpb_message_set_proto_rawDesc,
+			NumEnums:      0,
+			NumMessages:   1,
+			NumExtensions: 0,
+			NumServices:   0,
+		},
+		GoTypes:           file_messageset_messagesetpb_message_set_proto_goTypes,
+		DependencyIndexes: file_messageset_messagesetpb_message_set_proto_depIdxs,
+		MessageInfos:      file_messageset_messagesetpb_message_set_proto_msgTypes,
+	}.Build()
+	File_messageset_messagesetpb_message_set_proto = out.File
+	file_messageset_messagesetpb_message_set_proto_rawDesc = nil
+	file_messageset_messagesetpb_message_set_proto_goTypes = nil
+	file_messageset_messagesetpb_message_set_proto_depIdxs = nil
+}
diff --git a/internal/testprotos/messageset/messagesetpb/message_set.proto b/internal/testprotos/messageset/messagesetpb/message_set.proto
new file mode 100644
index 0000000..08f7a4d
--- /dev/null
+++ b/internal/testprotos/messageset/messagesetpb/message_set.proto
@@ -0,0 +1,14 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+syntax = "proto2";
+
+package goproto.proto.messageset;
+
+option go_package = "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb";
+
+message MessageSet {
+  option message_set_wire_format = true;
+  extensions 4 to max;
+}
diff --git a/internal/testprotos/messageset/msetextpb/msetextpb.pb.go b/internal/testprotos/messageset/msetextpb/msetextpb.pb.go
new file mode 100644
index 0000000..9af5304
--- /dev/null
+++ b/internal/testprotos/messageset/msetextpb/msetextpb.pb.go
@@ -0,0 +1,253 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: messageset/msetextpb/msetextpb.proto
+
+package msetextpb
+
+import (
+	messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb"
+	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+	protoiface "google.golang.org/protobuf/runtime/protoiface"
+	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+	sync "sync"
+)
+
+const (
+	// Verify that runtime/protoimpl is sufficiently up-to-date.
+	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 0)
+	// Verify that this generated code is sufficiently up-to-date.
+	_ = protoimpl.EnforceVersion(0 - protoimpl.MinVersion)
+)
+
+type Ext1 struct {
+	state         protoimpl.MessageState
+	Ext1Field1    *int32 `protobuf:"varint,1,opt,name=ext1_field1,json=ext1Field1" json:"ext1_field1,omitempty"`
+	Ext1Field2    *int32 `protobuf:"varint,2,opt,name=ext1_field2,json=ext1Field2" json:"ext1_field2,omitempty"`
+	sizeCache     protoimpl.SizeCache
+	unknownFields protoimpl.UnknownFields
+}
+
+func (x *Ext1) Reset() {
+	*x = Ext1{}
+}
+
+func (x *Ext1) String() string {
+	return protoimpl.X.MessageStringOf(x)
+}
+
+func (*Ext1) ProtoMessage() {}
+
+func (x *Ext1) ProtoReflect() protoreflect.Message {
+	mi := &file_messageset_msetextpb_msetextpb_proto_msgTypes[0]
+	if protoimpl.UnsafeEnabled && x != nil {
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		if ms.LoadMessageInfo() == nil {
+			ms.StoreMessageInfo(mi)
+		}
+		return ms
+	}
+	return mi.MessageOf(x)
+}
+
+// Deprecated: Use Ext1.ProtoReflect.Type instead.
+func (*Ext1) Descriptor() ([]byte, []int) {
+	return file_messageset_msetextpb_msetextpb_proto_rawDescGZIP(), []int{0}
+}
+
+func (x *Ext1) GetExt1Field1() int32 {
+	if x != nil && x.Ext1Field1 != nil {
+		return *x.Ext1Field1
+	}
+	return 0
+}
+
+func (x *Ext1) GetExt1Field2() int32 {
+	if x != nil && x.Ext1Field2 != nil {
+		return *x.Ext1Field2
+	}
+	return 0
+}
+
+type Ext2 struct {
+	state         protoimpl.MessageState
+	Ext2Field1    *int32 `protobuf:"varint,1,opt,name=ext2_field1,json=ext2Field1" json:"ext2_field1,omitempty"`
+	sizeCache     protoimpl.SizeCache
+	unknownFields protoimpl.UnknownFields
+}
+
+func (x *Ext2) Reset() {
+	*x = Ext2{}
+}
+
+func (x *Ext2) String() string {
+	return protoimpl.X.MessageStringOf(x)
+}
+
+func (*Ext2) ProtoMessage() {}
+
+func (x *Ext2) ProtoReflect() protoreflect.Message {
+	mi := &file_messageset_msetextpb_msetextpb_proto_msgTypes[1]
+	if protoimpl.UnsafeEnabled && x != nil {
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		if ms.LoadMessageInfo() == nil {
+			ms.StoreMessageInfo(mi)
+		}
+		return ms
+	}
+	return mi.MessageOf(x)
+}
+
+// Deprecated: Use Ext2.ProtoReflect.Type instead.
+func (*Ext2) Descriptor() ([]byte, []int) {
+	return file_messageset_msetextpb_msetextpb_proto_rawDescGZIP(), []int{1}
+}
+
+func (x *Ext2) GetExt2Field1() int32 {
+	if x != nil && x.Ext2Field1 != nil {
+		return *x.Ext2Field1
+	}
+	return 0
+}
+
+var file_messageset_msetextpb_msetextpb_proto_extDescs = []protoiface.ExtensionDescV1{
+	{
+		ExtendedType:  (*messagesetpb.MessageSet)(nil),
+		ExtensionType: (*Ext1)(nil),
+		Field:         1000,
+		Name:          "goproto.proto.messageset.Ext1.message_set_extension",
+		Tag:           "bytes,1000,opt,name=message_set_extension",
+		Filename:      "messageset/msetextpb/msetextpb.proto",
+	},
+	{
+		ExtendedType:  (*messagesetpb.MessageSet)(nil),
+		ExtensionType: (*Ext2)(nil),
+		Field:         1001,
+		Name:          "goproto.proto.messageset.Ext2.message_set_extension",
+		Tag:           "bytes,1001,opt,name=message_set_extension",
+		Filename:      "messageset/msetextpb/msetextpb.proto",
+	},
+}
+var (
+	// extend goproto.proto.messageset.MessageSet { optional goproto.proto.messageset.Ext1 message_set_extension = 1000; }
+	E_Ext1_MessageSetExtension = &file_messageset_msetextpb_msetextpb_proto_extDescs[0]
+
+	// extend goproto.proto.messageset.MessageSet { optional goproto.proto.messageset.Ext2 message_set_extension = 1001; }
+	E_Ext2_MessageSetExtension = &file_messageset_msetextpb_msetextpb_proto_extDescs[1]
+)
+var File_messageset_msetextpb_msetextpb_proto protoreflect.FileDescriptor
+
+var file_messageset_msetextpb_msetextpb_proto_rawDesc = []byte{
+	0x0a, 0x24, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f, 0x6d, 0x73, 0x65,
+	0x74, 0x65, 0x78, 0x74, 0x70, 0x62, 0x2f, 0x6d, 0x73, 0x65, 0x74, 0x65, 0x78, 0x74, 0x70, 0x62,
+	0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x18, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
+	0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74,
+	0x1a, 0x29, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f, 0x6d, 0x65, 0x73,
+	0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x70, 0x62, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
+	0x65, 0x5f, 0x73, 0x65, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xc3, 0x01, 0x0a, 0x04,
+	0x45, 0x78, 0x74, 0x31, 0x12, 0x1f, 0x0a, 0x0b, 0x65, 0x78, 0x74, 0x31, 0x5f, 0x66, 0x69, 0x65,
+	0x6c, 0x64, 0x31, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x65, 0x78, 0x74, 0x31, 0x46,
+	0x69, 0x65, 0x6c, 0x64, 0x31, 0x12, 0x1f, 0x0a, 0x0b, 0x65, 0x78, 0x74, 0x31, 0x5f, 0x66, 0x69,
+	0x65, 0x6c, 0x64, 0x32, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x65, 0x78, 0x74, 0x31,
+	0x46, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x32, 0x79, 0x0a, 0x15, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
+	0x65, 0x5f, 0x73, 0x65, 0x74, 0x5f, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x12,
+	0x24, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
+	0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61,
+	0x67, 0x65, 0x53, 0x65, 0x74, 0x18, 0xe8, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x67,
+	0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73,
+	0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x45, 0x78, 0x74, 0x31, 0x52, 0x13, 0x6d, 0x65,
+	0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f,
+	0x6e, 0x22, 0xa2, 0x01, 0x0a, 0x04, 0x45, 0x78, 0x74, 0x32, 0x12, 0x1f, 0x0a, 0x0b, 0x65, 0x78,
+	0x74, 0x32, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52,
+	0x0a, 0x65, 0x78, 0x74, 0x32, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x32, 0x79, 0x0a, 0x15, 0x6d,
+	0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x74, 0x5f, 0x65, 0x78, 0x74, 0x65, 0x6e,
+	0x73, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70,
+	0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e,
+	0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x18, 0xe9, 0x07, 0x20, 0x01, 0x28,
+	0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74,
+	0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x45, 0x78, 0x74,
+	0x32, 0x52, 0x13, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x45, 0x78, 0x74,
+	0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x42, 0x45, 0x5a, 0x43, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
+	0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74,
+	0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65,
+	0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
+	0x73, 0x65, 0x74, 0x2f, 0x6d, 0x73, 0x65, 0x74, 0x65, 0x78, 0x74, 0x70, 0x62,
+}
+
+var (
+	file_messageset_msetextpb_msetextpb_proto_rawDescOnce sync.Once
+	file_messageset_msetextpb_msetextpb_proto_rawDescData = file_messageset_msetextpb_msetextpb_proto_rawDesc
+)
+
+func file_messageset_msetextpb_msetextpb_proto_rawDescGZIP() []byte {
+	file_messageset_msetextpb_msetextpb_proto_rawDescOnce.Do(func() {
+		file_messageset_msetextpb_msetextpb_proto_rawDescData = protoimpl.X.CompressGZIP(file_messageset_msetextpb_msetextpb_proto_rawDescData)
+	})
+	return file_messageset_msetextpb_msetextpb_proto_rawDescData
+}
+
+var file_messageset_msetextpb_msetextpb_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
+var file_messageset_msetextpb_msetextpb_proto_goTypes = []interface{}{
+	(*Ext1)(nil),                    // 0: goproto.proto.messageset.Ext1
+	(*Ext2)(nil),                    // 1: goproto.proto.messageset.Ext2
+	(*messagesetpb.MessageSet)(nil), // 2: goproto.proto.messageset.MessageSet
+}
+var file_messageset_msetextpb_msetextpb_proto_depIdxs = []int32{
+	2, // goproto.proto.messageset.Ext1.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
+	2, // goproto.proto.messageset.Ext2.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
+	0, // goproto.proto.messageset.Ext1.message_set_extension:type_name -> goproto.proto.messageset.Ext1
+	1, // goproto.proto.messageset.Ext2.message_set_extension:type_name -> goproto.proto.messageset.Ext2
+	4, // starting offset of method output_type sub-list
+	4, // starting offset of method input_type sub-list
+	2, // starting offset of extension type_name sub-list
+	0, // starting offset of extension extendee sub-list
+	0, // starting offset of field type_name sub-list
+}
+
+func init() { file_messageset_msetextpb_msetextpb_proto_init() }
+func file_messageset_msetextpb_msetextpb_proto_init() {
+	if File_messageset_msetextpb_msetextpb_proto != nil {
+		return
+	}
+	if !protoimpl.UnsafeEnabled {
+		file_messageset_msetextpb_msetextpb_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+			switch v := v.(*Ext1); i {
+			case 0:
+				return &v.state
+			case 3:
+				return &v.sizeCache
+			case 4:
+				return &v.unknownFields
+			default:
+				return nil
+			}
+		}
+		file_messageset_msetextpb_msetextpb_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
+			switch v := v.(*Ext2); i {
+			case 0:
+				return &v.state
+			case 2:
+				return &v.sizeCache
+			case 3:
+				return &v.unknownFields
+			default:
+				return nil
+			}
+		}
+	}
+	out := protoimpl.TypeBuilder{
+		File: protoimpl.DescBuilder{
+			RawDescriptor: file_messageset_msetextpb_msetextpb_proto_rawDesc,
+			NumEnums:      0,
+			NumMessages:   2,
+			NumExtensions: 2,
+			NumServices:   0,
+		},
+		GoTypes:           file_messageset_msetextpb_msetextpb_proto_goTypes,
+		DependencyIndexes: file_messageset_msetextpb_msetextpb_proto_depIdxs,
+		MessageInfos:      file_messageset_msetextpb_msetextpb_proto_msgTypes,
+		LegacyExtensions:  file_messageset_msetextpb_msetextpb_proto_extDescs,
+	}.Build()
+	File_messageset_msetextpb_msetextpb_proto = out.File
+	file_messageset_msetextpb_msetextpb_proto_rawDesc = nil
+	file_messageset_msetextpb_msetextpb_proto_goTypes = nil
+	file_messageset_msetextpb_msetextpb_proto_depIdxs = nil
+}
diff --git a/internal/testprotos/messageset/msetextpb/msetextpb.proto b/internal/testprotos/messageset/msetextpb/msetextpb.proto
new file mode 100644
index 0000000..5d1bf08
--- /dev/null
+++ b/internal/testprotos/messageset/msetextpb/msetextpb.proto
@@ -0,0 +1,26 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+syntax = "proto2";
+
+package goproto.proto.messageset;
+
+option go_package = "google.golang.org/protobuf/internal/testprotos/messageset/msetextpb";
+
+import "messageset/messagesetpb/message_set.proto";
+
+message Ext1 {
+  extend MessageSet {
+    optional Ext1 message_set_extension = 1000;
+  }
+  optional int32 ext1_field1 = 1;
+  optional int32 ext1_field2 = 2;
+}
+
+message Ext2 {
+  extend MessageSet {
+    optional Ext2 message_set_extension = 1001;
+  }
+  optional int32 ext2_field1 = 1;
+}
diff --git a/proto/decode.go b/proto/decode.go
index 0b98366..f147e68 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -5,6 +5,7 @@
 package proto
 
 import (
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/internal/pragma"
@@ -68,8 +69,11 @@
 }
 
 func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
-	messageDesc := m.Descriptor()
-	fieldDescs := messageDesc.Fields()
+	md := m.Descriptor()
+	if messageset.IsMessageSet(md) {
+		return unmarshalMessageSet(b, m, o)
+	}
+	fields := md.Fields()
 	for len(b) > 0 {
 		// Parse the tag (field number and wire type).
 		num, wtyp, tagLen := wire.ConsumeTag(b)
@@ -78,9 +82,9 @@
 		}
 
 		// Parse the field value.
-		fd := fieldDescs.ByNumber(num)
-		if fd == nil && messageDesc.ExtensionRanges().Has(num) {
-			extType, err := o.Resolver.FindExtensionByNumber(messageDesc.FullName(), num)
+		fd := fields.ByNumber(num)
+		if fd == nil && md.ExtensionRanges().Has(num) {
+			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
 			if err != nil && err != protoregistry.NotFound {
 				return err
 			}
diff --git a/proto/encode.go b/proto/encode.go
index bc7da0c..5511d16 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -7,6 +7,7 @@
 import (
 	"sort"
 
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/mapsort"
 	"google.golang.org/protobuf/internal/pragma"
@@ -111,6 +112,9 @@
 }
 
 func (o MarshalOptions) marshalMessageSlow(b []byte, m protoreflect.Message) ([]byte, error) {
+	if messageset.IsMessageSet(m.Descriptor()) {
+		return marshalMessageSet(b, m, o)
+	}
 	// There are many choices for what order we visit fields in. The default one here
 	// is chosen for reasonable efficiency and simplicity given the protoreflect API.
 	// It is not deterministic, since Message.Range does not return fields in any
diff --git a/proto/messageset.go b/proto/messageset.go
new file mode 100644
index 0000000..1c6ac29
--- /dev/null
+++ b/proto/messageset.go
@@ -0,0 +1,102 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+	"google.golang.org/protobuf/internal/encoding/messageset"
+	"google.golang.org/protobuf/internal/encoding/wire"
+	"google.golang.org/protobuf/internal/errors"
+	"google.golang.org/protobuf/internal/flags"
+	"google.golang.org/protobuf/reflect/protoreflect"
+	"google.golang.org/protobuf/reflect/protoregistry"
+)
+
+func sizeMessageSet(m protoreflect.Message) (size int) {
+	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+		size += messageset.SizeField(fd.Number())
+		size += wire.SizeTag(messageset.FieldMessage)
+		size += wire.SizeBytes(sizeMessage(v.Message()))
+		return true
+	})
+	size += len(m.GetUnknown())
+	return size
+}
+
+func marshalMessageSet(b []byte, m protoreflect.Message, o MarshalOptions) ([]byte, error) {
+	if !flags.Proto1Legacy {
+		return b, errors.New("no support for message_set_wire_format")
+	}
+	var err error
+	o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+		b, err = marshalMessageSetField(b, fd, v, o)
+		return err == nil
+	})
+	if err != nil {
+		return b, err
+	}
+	b = append(b, m.GetUnknown()...)
+	return b, nil
+}
+
+func marshalMessageSetField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value, o MarshalOptions) ([]byte, error) {
+	b = messageset.AppendFieldStart(b, fd.Number())
+	b = wire.AppendTag(b, messageset.FieldMessage, wire.BytesType)
+	b = wire.AppendVarint(b, uint64(o.Size(value.Message().Interface())))
+	b, err := o.marshalMessage(b, value.Message())
+	if err != nil {
+		return b, err
+	}
+	b = messageset.AppendFieldEnd(b)
+	return b, nil
+}
+
+func unmarshalMessageSet(b []byte, m protoreflect.Message, o UnmarshalOptions) error {
+	if !flags.Proto1Legacy {
+		return errors.New("no support for message_set_wire_format")
+	}
+	md := m.Descriptor()
+	for len(b) > 0 {
+		err := func() error {
+			num, v, n, err := messageset.ConsumeField(b)
+			if err != nil {
+				// Not a message set field.
+				//
+				// Return errUnknown to try to add this to the unknown fields.
+				// If the field is completely unparsable, we'll catch it
+				// when trying to skip the field.
+				return errUnknown
+			}
+			if !md.ExtensionRanges().Has(num) {
+				return errUnknown
+			}
+			fd, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
+			if err == protoregistry.NotFound {
+				return errUnknown
+			}
+			if err != nil {
+				return err
+			}
+			if err := o.unmarshalMessage(v, m.Mutable(fd).Message()); err != nil {
+				// Contents cannot be unmarshaled.
+				return err
+			}
+			b = b[n:]
+			return nil
+		}()
+		if err == errUnknown {
+			_, _, n := wire.ConsumeField(b)
+			if n < 0 {
+				return wire.ParseError(n)
+			}
+			m.SetUnknown(append(m.GetUnknown(), b[:n]...))
+			b = b[n:]
+			continue
+		}
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/proto/messageset_test.go b/proto/messageset_test.go
new file mode 100644
index 0000000..07b2d70
--- /dev/null
+++ b/proto/messageset_test.go
@@ -0,0 +1,190 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package proto_test
+
+import (
+	"google.golang.org/protobuf/internal/encoding/pack"
+	"google.golang.org/protobuf/internal/flags"
+	"google.golang.org/protobuf/proto"
+
+	messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb"
+	msetextpb "google.golang.org/protobuf/internal/testprotos/messageset/msetextpb"
+)
+
+func init() {
+	if flags.Proto1Legacy {
+		testProtos = append(testProtos, messageSetTestProtos...)
+	}
+}
+
+var messageSetTestProtos = []testProto{
+	{
+		desc: "MessageSet type_id before message content",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+				Ext1Field1: proto.Int32(10),
+			}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet type_id after message content",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+				Ext1Field1: proto.Int32(10),
+			}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet preserves unknown field",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+				Ext1Field1: proto.Int32(10),
+			}),
+			unknown(pack.Message{
+				pack.Tag{4, pack.VarintType}, pack.Varint(30),
+			}.Marshal()),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+			// Unknown field
+			pack.Tag{4, pack.VarintType}, pack.Varint(30),
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet with unknown type_id",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			unknown(pack.Message{
+				pack.Tag{1, pack.StartGroupType},
+				pack.Tag{2, pack.VarintType}, pack.Varint(1002),
+				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{1, pack.VarintType}, pack.Varint(10),
+				}),
+				pack.Tag{1, pack.EndGroupType},
+			}.Marshal()),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1002),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet merges repeated message fields in item",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+				Ext1Field1: proto.Int32(10),
+				Ext1Field2: proto.Int32(20),
+			}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{2, pack.VarintType}, pack.Varint(20),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet merges message fields in repeated items",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+				Ext1Field1: proto.Int32(10),
+				Ext1Field2: proto.Int32(20),
+			}),
+			extend(msetextpb.E_Ext2_MessageSetExtension, &msetextpb.Ext2{
+				Ext2Field1: proto.Int32(30),
+			}),
+		)},
+		wire: pack.Message{
+			// Ext1, field1
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+			// Ext2, field1
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1001),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(30),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+			// Ext2, field2
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{2, pack.VarintType}, pack.Varint(20),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet with missing type_id",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			unknown(pack.Message{
+				pack.Tag{1, pack.StartGroupType},
+				pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{1, pack.VarintType}, pack.Varint(10),
+				}),
+				pack.Tag{1, pack.EndGroupType},
+			}.Marshal()),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(10),
+			}),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "MessageSet with missing message",
+		decodeTo: []proto.Message{build(
+			&messagesetpb.MessageSet{},
+			extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+}
diff --git a/proto/size.go b/proto/size.go
index 1266143..9947580 100644
--- a/proto/size.go
+++ b/proto/size.go
@@ -5,6 +5,7 @@
 package proto
 
 import (
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/runtime/protoiface"
@@ -28,6 +29,9 @@
 }
 
 func sizeMessageSlow(m protoreflect.Message) (size int) {
+	if messageset.IsMessageSet(m.Descriptor()) {
+		return sizeMessageSet(m)
+	}
 	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
 		size += sizeField(fd, v)
 		return true