proto: support message_set_wire_format
MessageSets are a deprecated proto1 feature, long since superseded by
extensions. Add disabled-by-default support behind flags.Proto1Legacy.
Change-Id: I7d3ace07f3b0efd59673034f3dc633b908345a88
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185538
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go
new file mode 100644
index 0000000..a8b6c0d
--- /dev/null
+++ b/internal/encoding/messageset/messageset.go
@@ -0,0 +1,138 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+// Package messageset encodes and decodes the obsolete MessageSet wire format.
+package messageset
+
+import (
+ "google.golang.org/protobuf/internal/encoding/wire"
+ "google.golang.org/protobuf/internal/errors"
+ pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+// The MessageSet wire format is equivalent to a message defiend as follows,
+// where each Item defines an extension field with a field number of 'type_id'
+// and content of 'message'. MessageSet extensions must be non-repeated message
+// fields.
+//
+// message MessageSet {
+// repeated group Item = 1 {
+// required int32 type_id = 2;
+// required string message = 3;
+// }
+// }
+const (
+ FieldItem = wire.Number(1)
+ FieldTypeID = wire.Number(2)
+ FieldMessage = wire.Number(3)
+)
+
+// IsMessageSet returns whether the message uses the MessageSet wire format.
+func IsMessageSet(md pref.MessageDescriptor) bool {
+ xmd, ok := md.(interface{ IsMessageSet() bool })
+ return ok && xmd.IsMessageSet()
+}
+
+// SizeField returns the size of a MessageSet item field containing an extension
+// with the given field number, not counting the contents of the message subfield.
+func SizeField(num wire.Number) int {
+ return 2*wire.SizeTag(FieldItem) + wire.SizeTag(FieldTypeID) + wire.SizeVarint(uint64(num))
+}
+
+// ConsumeField parses a MessageSet item field and returns the contents of the
+// type_id and message subfields and the total item length.
+func ConsumeField(b []byte) (typeid wire.Number, message []byte, n int, err error) {
+ num, wtyp, n := wire.ConsumeTag(b)
+ if n < 0 {
+ return 0, nil, 0, wire.ParseError(n)
+ }
+ if num != FieldItem || wtyp != wire.StartGroupType {
+ return 0, nil, 0, errors.New("invalid MessageSet field number")
+ }
+ typeid, message, fieldLen, err := ConsumeFieldValue(b[n:], false)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ return typeid, message, n + fieldLen, nil
+}
+
+// ConsumeFieldValue parses b as a MessageSet item field value until and including
+// the trailing end group marker. It assumes the start group tag has already been parsed.
+// It returns the contents of the type_id and message subfields and the total
+// item length.
+//
+// If wantLen is true, the returned message value includes the length prefix.
+// This is ugly, but simplifies the fast-path decoder in internal/impl.
+func ConsumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) {
+ ilen := len(b)
+ for {
+ num, wtyp, n := wire.ConsumeTag(b)
+ if n < 0 {
+ return 0, nil, 0, wire.ParseError(n)
+ }
+ b = b[n:]
+ switch {
+ case num == FieldItem && wtyp == wire.EndGroupType:
+ if wantLen && len(message) == 0 {
+ // The message field was missing, which should never happen.
+ // Be prepared for this case anyway.
+ message = wire.AppendVarint(message, 0)
+ }
+ return typeid, message, ilen - len(b), nil
+ case num == FieldTypeID && wtyp == wire.VarintType:
+ v, n := wire.ConsumeVarint(b)
+ if n < 0 {
+ return 0, nil, 0, wire.ParseError(n)
+ }
+ b = b[n:]
+ typeid = wire.Number(v)
+ case num == FieldMessage && wtyp == wire.BytesType:
+ m, n := wire.ConsumeBytes(b)
+ if n < 0 {
+ return 0, nil, 0, wire.ParseError(n)
+ }
+ if message == nil {
+ if wantLen {
+ message = b[:n]
+ } else {
+ message = m
+ }
+ } else {
+ // This case should never happen in practice, but handle it for
+ // correctness: The MessageSet item contains multiple message
+ // fields, which need to be merged.
+ //
+ // In the case where we're returning the length, this becomes
+ // quite inefficient since we need to strip the length off
+ // the existing data and reconstruct it with the combined length.
+ if wantLen {
+ _, nn := wire.ConsumeVarint(message)
+ m0 := message[nn:]
+ message = message[:0]
+ message = wire.AppendVarint(message, uint64(len(m0)+len(m)))
+ message = append(message, m0...)
+ message = append(message, m...)
+ } else {
+ message = append(message, m...)
+ }
+ }
+ b = b[n:]
+ }
+ }
+}
+
+// AppendFieldStart appends the start of a MessageSet item field containing
+// an extension with the given number. The caller must add the message
+// subfield (including the tag).
+func AppendFieldStart(b []byte, num wire.Number) []byte {
+ b = wire.AppendTag(b, FieldItem, wire.StartGroupType)
+ b = wire.AppendTag(b, FieldTypeID, wire.VarintType)
+ b = wire.AppendVarint(b, uint64(num))
+ return b
+}
+
+// AppendFieldEnd appends the trailing end group marker for a MessageSet item field.
+func AppendFieldEnd(b []byte) []byte {
+ return wire.AppendTag(b, FieldItem, wire.EndGroupType)
+}
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index 2b8b047..e43c812 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -5,9 +5,11 @@
package impl
import (
+ "fmt"
"reflect"
"sort"
+ "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
pref "google.golang.org/protobuf/reflect/protoreflect"
piface "google.golang.org/protobuf/runtime/protoiface"
@@ -77,6 +79,21 @@
mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
mi.coderFields[cf.num] = cf
}
+ if messageset.IsMessageSet(mi.PBType.Descriptor()) {
+ if !mi.extensionOffset.IsValid() {
+ panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.PBType.FullName()))
+ }
+ cf := &coderFieldInfo{
+ num: messageset.FieldItem,
+ offset: si.extensionOffset,
+ isPointer: true,
+ funcs: makeMessageSetFieldCoder(mi),
+ }
+ mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
+ mi.coderFields[cf.num] = cf
+ // Invalidate the extension offset, since the field codec handles extensions.
+ mi.extensionOffset = invalidOffset
+ }
sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
return mi.orderedCoderFields[i].num < mi.orderedCoderFields[j].num
})
diff --git a/internal/impl/codec_messageset.go b/internal/impl/codec_messageset.go
new file mode 100644
index 0000000..073c4dc
--- /dev/null
+++ b/internal/impl/codec_messageset.go
@@ -0,0 +1,119 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+ "sort"
+
+ "google.golang.org/protobuf/internal/encoding/messageset"
+ "google.golang.org/protobuf/internal/encoding/wire"
+ "google.golang.org/protobuf/internal/errors"
+ "google.golang.org/protobuf/internal/flags"
+)
+
+func makeMessageSetFieldCoder(mi *MessageInfo) pointerCoderFuncs {
+ return pointerCoderFuncs{
+ size: func(p pointer, tagsize int, opts marshalOptions) int {
+ return sizeMessageSet(mi, p, tagsize, opts)
+ },
+ marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
+ return marshalMessageSet(mi, b, p, wiretag, opts)
+ },
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ return unmarshalMessageSet(mi, b, p, wtyp, opts)
+ },
+ }
+}
+
+func sizeMessageSet(mi *MessageInfo, p pointer, tagsize int, opts marshalOptions) (n int) {
+ ext := *p.Extensions()
+ if ext == nil {
+ return 0
+ }
+ for _, x := range ext {
+ xi := mi.extensionFieldInfo(x.GetType())
+ if xi.funcs.size == nil {
+ continue
+ }
+ num, _ := wire.DecodeTag(xi.wiretag)
+ n += messageset.SizeField(num)
+ n += xi.funcs.size(x.GetValue(), wire.SizeTag(messageset.FieldMessage), opts)
+ }
+ return n
+}
+
+func marshalMessageSet(mi *MessageInfo, b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
+ if !flags.Proto1Legacy {
+ return b, errors.New("no support for message_set_wire_format")
+ }
+ ext := *p.Extensions()
+ if ext == nil {
+ return b, nil
+ }
+ switch len(ext) {
+ case 0:
+ return b, nil
+ case 1:
+ // Fast-path for one extension: Don't bother sorting the keys.
+ for _, x := range ext {
+ var err error
+ b, err = marshalMessageSetField(mi, b, x, opts)
+ if err != nil {
+ return b, err
+ }
+ }
+ return b, nil
+ default:
+ // Sort the keys to provide a deterministic encoding.
+ // Not sure this is required, but the old code does it.
+ keys := make([]int, 0, len(ext))
+ for k := range ext {
+ keys = append(keys, int(k))
+ }
+ sort.Ints(keys)
+ for _, k := range keys {
+ var err error
+ b, err = marshalMessageSetField(mi, b, ext[int32(k)], opts)
+ if err != nil {
+ return b, err
+ }
+ }
+ return b, nil
+ }
+}
+
+func marshalMessageSetField(mi *MessageInfo, b []byte, x ExtensionField, opts marshalOptions) ([]byte, error) {
+ xi := mi.extensionFieldInfo(x.GetType())
+ num, _ := wire.DecodeTag(xi.wiretag)
+ b = messageset.AppendFieldStart(b, num)
+ b, err := xi.funcs.marshal(b, x.GetValue(), wire.EncodeTag(messageset.FieldMessage, wire.BytesType), opts)
+ if err != nil {
+ return b, err
+ }
+ b = messageset.AppendFieldEnd(b)
+ return b, nil
+}
+
+func unmarshalMessageSet(mi *MessageInfo, b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ if !flags.Proto1Legacy {
+ return 0, errors.New("no support for message_set_wire_format")
+ }
+ if wtyp != wire.StartGroupType {
+ return 0, errUnknown
+ }
+ ep := p.Extensions()
+ if *ep == nil {
+ *ep = make(map[int32]ExtensionField)
+ }
+ ext := *ep
+ num, v, n, err := messageset.ConsumeFieldValue(b, true)
+ if err != nil {
+ return 0, err
+ }
+ if _, err := mi.unmarshalExtension(v, num, wire.BytesType, ext, opts); err != nil {
+ return 0, err
+ }
+ return n, nil
+}
diff --git a/internal/testprotos/messageset/messagesetpb/message_set.pb.go b/internal/testprotos/messageset/messagesetpb/message_set.pb.go
new file mode 100644
index 0000000..aaa14da
--- /dev/null
+++ b/internal/testprotos/messageset/messagesetpb/message_set.pb.go
@@ -0,0 +1,140 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: messageset/messagesetpb/message_set.proto
+
+package messagesetpb
+
+import (
+ protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+ protoiface "google.golang.org/protobuf/runtime/protoiface"
+ protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+ sync "sync"
+)
+
+const (
+ // Verify that runtime/protoimpl is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 0)
+ // Verify that this generated code is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(0 - protoimpl.MinVersion)
+)
+
+type MessageSet struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+ extensionFields protoimpl.ExtensionFields
+}
+
+func (x *MessageSet) Reset() {
+ *x = MessageSet{}
+}
+
+func (x *MessageSet) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*MessageSet) ProtoMessage() {}
+
+func (x *MessageSet) ProtoReflect() protoreflect.Message {
+ mi := &file_messageset_messagesetpb_message_set_proto_msgTypes[0]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use MessageSet.ProtoReflect.Type instead.
+func (*MessageSet) Descriptor() ([]byte, []int) {
+ return file_messageset_messagesetpb_message_set_proto_rawDescGZIP(), []int{0}
+}
+
+var extRange_MessageSet = []protoiface.ExtensionRangeV1{
+ {Start: 4, End: 2147483646},
+}
+
+// Deprecated: Use MessageSet.ProtoReflect.Type.ExtensionRanges instead.
+func (*MessageSet) ExtensionRangeArray() []protoiface.ExtensionRangeV1 {
+ return extRange_MessageSet
+}
+
+var File_messageset_messagesetpb_message_set_proto protoreflect.FileDescriptor
+
+var file_messageset_messagesetpb_message_set_proto_rawDesc = []byte{
+ 0x0a, 0x29, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f, 0x6d, 0x65, 0x73,
+ 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x70, 0x62, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
+ 0x65, 0x5f, 0x73, 0x65, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x18, 0x67, 0x6f, 0x70,
+ 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61,
+ 0x67, 0x65, 0x73, 0x65, 0x74, 0x22, 0x1a, 0x0a, 0x0a, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
+ 0x53, 0x65, 0x74, 0x2a, 0x08, 0x08, 0x04, 0x10, 0xff, 0xff, 0xff, 0xff, 0x07, 0x3a, 0x02, 0x08,
+ 0x01, 0x42, 0x48, 0x5a, 0x46, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x67, 0x6f, 0x6c, 0x61,
+ 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f,
+ 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f,
+ 0x74, 0x6f, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f, 0x6d,
+ 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x70, 0x62,
+}
+
+var (
+ file_messageset_messagesetpb_message_set_proto_rawDescOnce sync.Once
+ file_messageset_messagesetpb_message_set_proto_rawDescData = file_messageset_messagesetpb_message_set_proto_rawDesc
+)
+
+func file_messageset_messagesetpb_message_set_proto_rawDescGZIP() []byte {
+ file_messageset_messagesetpb_message_set_proto_rawDescOnce.Do(func() {
+ file_messageset_messagesetpb_message_set_proto_rawDescData = protoimpl.X.CompressGZIP(file_messageset_messagesetpb_message_set_proto_rawDescData)
+ })
+ return file_messageset_messagesetpb_message_set_proto_rawDescData
+}
+
+var file_messageset_messagesetpb_message_set_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
+var file_messageset_messagesetpb_message_set_proto_goTypes = []interface{}{
+ (*MessageSet)(nil), // 0: goproto.proto.messageset.MessageSet
+}
+var file_messageset_messagesetpb_message_set_proto_depIdxs = []int32{
+ 0, // starting offset of method output_type sub-list
+ 0, // starting offset of method input_type sub-list
+ 0, // starting offset of extension type_name sub-list
+ 0, // starting offset of extension extendee sub-list
+ 0, // starting offset of field type_name sub-list
+}
+
+func init() { file_messageset_messagesetpb_message_set_proto_init() }
+func file_messageset_messagesetpb_message_set_proto_init() {
+ if File_messageset_messagesetpb_message_set_proto != nil {
+ return
+ }
+ if !protoimpl.UnsafeEnabled {
+ file_messageset_messagesetpb_message_set_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*MessageSet); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ case 3:
+ return &v.extensionFields
+ default:
+ return nil
+ }
+ }
+ }
+ out := protoimpl.TypeBuilder{
+ File: protoimpl.DescBuilder{
+ RawDescriptor: file_messageset_messagesetpb_message_set_proto_rawDesc,
+ NumEnums: 0,
+ NumMessages: 1,
+ NumExtensions: 0,
+ NumServices: 0,
+ },
+ GoTypes: file_messageset_messagesetpb_message_set_proto_goTypes,
+ DependencyIndexes: file_messageset_messagesetpb_message_set_proto_depIdxs,
+ MessageInfos: file_messageset_messagesetpb_message_set_proto_msgTypes,
+ }.Build()
+ File_messageset_messagesetpb_message_set_proto = out.File
+ file_messageset_messagesetpb_message_set_proto_rawDesc = nil
+ file_messageset_messagesetpb_message_set_proto_goTypes = nil
+ file_messageset_messagesetpb_message_set_proto_depIdxs = nil
+}
diff --git a/internal/testprotos/messageset/messagesetpb/message_set.proto b/internal/testprotos/messageset/messagesetpb/message_set.proto
new file mode 100644
index 0000000..08f7a4d
--- /dev/null
+++ b/internal/testprotos/messageset/messagesetpb/message_set.proto
@@ -0,0 +1,14 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+syntax = "proto2";
+
+package goproto.proto.messageset;
+
+option go_package = "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb";
+
+message MessageSet {
+ option message_set_wire_format = true;
+ extensions 4 to max;
+}
diff --git a/internal/testprotos/messageset/msetextpb/msetextpb.pb.go b/internal/testprotos/messageset/msetextpb/msetextpb.pb.go
new file mode 100644
index 0000000..9af5304
--- /dev/null
+++ b/internal/testprotos/messageset/msetextpb/msetextpb.pb.go
@@ -0,0 +1,253 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: messageset/msetextpb/msetextpb.proto
+
+package msetextpb
+
+import (
+ messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb"
+ protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+ protoiface "google.golang.org/protobuf/runtime/protoiface"
+ protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+ sync "sync"
+)
+
+const (
+ // Verify that runtime/protoimpl is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 0)
+ // Verify that this generated code is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(0 - protoimpl.MinVersion)
+)
+
+type Ext1 struct {
+ state protoimpl.MessageState
+ Ext1Field1 *int32 `protobuf:"varint,1,opt,name=ext1_field1,json=ext1Field1" json:"ext1_field1,omitempty"`
+ Ext1Field2 *int32 `protobuf:"varint,2,opt,name=ext1_field2,json=ext1Field2" json:"ext1_field2,omitempty"`
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+}
+
+func (x *Ext1) Reset() {
+ *x = Ext1{}
+}
+
+func (x *Ext1) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*Ext1) ProtoMessage() {}
+
+func (x *Ext1) ProtoReflect() protoreflect.Message {
+ mi := &file_messageset_msetextpb_msetextpb_proto_msgTypes[0]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use Ext1.ProtoReflect.Type instead.
+func (*Ext1) Descriptor() ([]byte, []int) {
+ return file_messageset_msetextpb_msetextpb_proto_rawDescGZIP(), []int{0}
+}
+
+func (x *Ext1) GetExt1Field1() int32 {
+ if x != nil && x.Ext1Field1 != nil {
+ return *x.Ext1Field1
+ }
+ return 0
+}
+
+func (x *Ext1) GetExt1Field2() int32 {
+ if x != nil && x.Ext1Field2 != nil {
+ return *x.Ext1Field2
+ }
+ return 0
+}
+
+type Ext2 struct {
+ state protoimpl.MessageState
+ Ext2Field1 *int32 `protobuf:"varint,1,opt,name=ext2_field1,json=ext2Field1" json:"ext2_field1,omitempty"`
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+}
+
+func (x *Ext2) Reset() {
+ *x = Ext2{}
+}
+
+func (x *Ext2) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*Ext2) ProtoMessage() {}
+
+func (x *Ext2) ProtoReflect() protoreflect.Message {
+ mi := &file_messageset_msetextpb_msetextpb_proto_msgTypes[1]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use Ext2.ProtoReflect.Type instead.
+func (*Ext2) Descriptor() ([]byte, []int) {
+ return file_messageset_msetextpb_msetextpb_proto_rawDescGZIP(), []int{1}
+}
+
+func (x *Ext2) GetExt2Field1() int32 {
+ if x != nil && x.Ext2Field1 != nil {
+ return *x.Ext2Field1
+ }
+ return 0
+}
+
+var file_messageset_msetextpb_msetextpb_proto_extDescs = []protoiface.ExtensionDescV1{
+ {
+ ExtendedType: (*messagesetpb.MessageSet)(nil),
+ ExtensionType: (*Ext1)(nil),
+ Field: 1000,
+ Name: "goproto.proto.messageset.Ext1.message_set_extension",
+ Tag: "bytes,1000,opt,name=message_set_extension",
+ Filename: "messageset/msetextpb/msetextpb.proto",
+ },
+ {
+ ExtendedType: (*messagesetpb.MessageSet)(nil),
+ ExtensionType: (*Ext2)(nil),
+ Field: 1001,
+ Name: "goproto.proto.messageset.Ext2.message_set_extension",
+ Tag: "bytes,1001,opt,name=message_set_extension",
+ Filename: "messageset/msetextpb/msetextpb.proto",
+ },
+}
+var (
+ // extend goproto.proto.messageset.MessageSet { optional goproto.proto.messageset.Ext1 message_set_extension = 1000; }
+ E_Ext1_MessageSetExtension = &file_messageset_msetextpb_msetextpb_proto_extDescs[0]
+
+ // extend goproto.proto.messageset.MessageSet { optional goproto.proto.messageset.Ext2 message_set_extension = 1001; }
+ E_Ext2_MessageSetExtension = &file_messageset_msetextpb_msetextpb_proto_extDescs[1]
+)
+var File_messageset_msetextpb_msetextpb_proto protoreflect.FileDescriptor
+
+var file_messageset_msetextpb_msetextpb_proto_rawDesc = []byte{
+ 0x0a, 0x24, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f, 0x6d, 0x73, 0x65,
+ 0x74, 0x65, 0x78, 0x74, 0x70, 0x62, 0x2f, 0x6d, 0x73, 0x65, 0x74, 0x65, 0x78, 0x74, 0x70, 0x62,
+ 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x18, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
+ 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74,
+ 0x1a, 0x29, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2f, 0x6d, 0x65, 0x73,
+ 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x70, 0x62, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
+ 0x65, 0x5f, 0x73, 0x65, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xc3, 0x01, 0x0a, 0x04,
+ 0x45, 0x78, 0x74, 0x31, 0x12, 0x1f, 0x0a, 0x0b, 0x65, 0x78, 0x74, 0x31, 0x5f, 0x66, 0x69, 0x65,
+ 0x6c, 0x64, 0x31, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x65, 0x78, 0x74, 0x31, 0x46,
+ 0x69, 0x65, 0x6c, 0x64, 0x31, 0x12, 0x1f, 0x0a, 0x0b, 0x65, 0x78, 0x74, 0x31, 0x5f, 0x66, 0x69,
+ 0x65, 0x6c, 0x64, 0x32, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x65, 0x78, 0x74, 0x31,
+ 0x46, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x32, 0x79, 0x0a, 0x15, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67,
+ 0x65, 0x5f, 0x73, 0x65, 0x74, 0x5f, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x12,
+ 0x24, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
+ 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61,
+ 0x67, 0x65, 0x53, 0x65, 0x74, 0x18, 0xe8, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x67,
+ 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73,
+ 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x45, 0x78, 0x74, 0x31, 0x52, 0x13, 0x6d, 0x65,
+ 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f,
+ 0x6e, 0x22, 0xa2, 0x01, 0x0a, 0x04, 0x45, 0x78, 0x74, 0x32, 0x12, 0x1f, 0x0a, 0x0b, 0x65, 0x78,
+ 0x74, 0x32, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52,
+ 0x0a, 0x65, 0x78, 0x74, 0x32, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x31, 0x32, 0x79, 0x0a, 0x15, 0x6d,
+ 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x74, 0x5f, 0x65, 0x78, 0x74, 0x65, 0x6e,
+ 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x24, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70,
+ 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e,
+ 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x18, 0xe9, 0x07, 0x20, 0x01, 0x28,
+ 0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x6f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74,
+ 0x6f, 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x65, 0x74, 0x2e, 0x45, 0x78, 0x74,
+ 0x32, 0x52, 0x13, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x53, 0x65, 0x74, 0x45, 0x78, 0x74,
+ 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x42, 0x45, 0x5a, 0x43, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
+ 0x2e, 0x67, 0x6f, 0x6c, 0x61, 0x6e, 0x67, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x70, 0x72, 0x6f, 0x74,
+ 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65,
+ 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
+ 0x73, 0x65, 0x74, 0x2f, 0x6d, 0x73, 0x65, 0x74, 0x65, 0x78, 0x74, 0x70, 0x62,
+}
+
+var (
+ file_messageset_msetextpb_msetextpb_proto_rawDescOnce sync.Once
+ file_messageset_msetextpb_msetextpb_proto_rawDescData = file_messageset_msetextpb_msetextpb_proto_rawDesc
+)
+
+func file_messageset_msetextpb_msetextpb_proto_rawDescGZIP() []byte {
+ file_messageset_msetextpb_msetextpb_proto_rawDescOnce.Do(func() {
+ file_messageset_msetextpb_msetextpb_proto_rawDescData = protoimpl.X.CompressGZIP(file_messageset_msetextpb_msetextpb_proto_rawDescData)
+ })
+ return file_messageset_msetextpb_msetextpb_proto_rawDescData
+}
+
+var file_messageset_msetextpb_msetextpb_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
+var file_messageset_msetextpb_msetextpb_proto_goTypes = []interface{}{
+ (*Ext1)(nil), // 0: goproto.proto.messageset.Ext1
+ (*Ext2)(nil), // 1: goproto.proto.messageset.Ext2
+ (*messagesetpb.MessageSet)(nil), // 2: goproto.proto.messageset.MessageSet
+}
+var file_messageset_msetextpb_msetextpb_proto_depIdxs = []int32{
+ 2, // goproto.proto.messageset.Ext1.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
+ 2, // goproto.proto.messageset.Ext2.message_set_extension:extendee -> goproto.proto.messageset.MessageSet
+ 0, // goproto.proto.messageset.Ext1.message_set_extension:type_name -> goproto.proto.messageset.Ext1
+ 1, // goproto.proto.messageset.Ext2.message_set_extension:type_name -> goproto.proto.messageset.Ext2
+ 4, // starting offset of method output_type sub-list
+ 4, // starting offset of method input_type sub-list
+ 2, // starting offset of extension type_name sub-list
+ 0, // starting offset of extension extendee sub-list
+ 0, // starting offset of field type_name sub-list
+}
+
+func init() { file_messageset_msetextpb_msetextpb_proto_init() }
+func file_messageset_msetextpb_msetextpb_proto_init() {
+ if File_messageset_msetextpb_msetextpb_proto != nil {
+ return
+ }
+ if !protoimpl.UnsafeEnabled {
+ file_messageset_msetextpb_msetextpb_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*Ext1); i {
+ case 0:
+ return &v.state
+ case 3:
+ return &v.sizeCache
+ case 4:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_messageset_msetextpb_msetextpb_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*Ext2); i {
+ case 0:
+ return &v.state
+ case 2:
+ return &v.sizeCache
+ case 3:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ }
+ out := protoimpl.TypeBuilder{
+ File: protoimpl.DescBuilder{
+ RawDescriptor: file_messageset_msetextpb_msetextpb_proto_rawDesc,
+ NumEnums: 0,
+ NumMessages: 2,
+ NumExtensions: 2,
+ NumServices: 0,
+ },
+ GoTypes: file_messageset_msetextpb_msetextpb_proto_goTypes,
+ DependencyIndexes: file_messageset_msetextpb_msetextpb_proto_depIdxs,
+ MessageInfos: file_messageset_msetextpb_msetextpb_proto_msgTypes,
+ LegacyExtensions: file_messageset_msetextpb_msetextpb_proto_extDescs,
+ }.Build()
+ File_messageset_msetextpb_msetextpb_proto = out.File
+ file_messageset_msetextpb_msetextpb_proto_rawDesc = nil
+ file_messageset_msetextpb_msetextpb_proto_goTypes = nil
+ file_messageset_msetextpb_msetextpb_proto_depIdxs = nil
+}
diff --git a/internal/testprotos/messageset/msetextpb/msetextpb.proto b/internal/testprotos/messageset/msetextpb/msetextpb.proto
new file mode 100644
index 0000000..5d1bf08
--- /dev/null
+++ b/internal/testprotos/messageset/msetextpb/msetextpb.proto
@@ -0,0 +1,26 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+syntax = "proto2";
+
+package goproto.proto.messageset;
+
+option go_package = "google.golang.org/protobuf/internal/testprotos/messageset/msetextpb";
+
+import "messageset/messagesetpb/message_set.proto";
+
+message Ext1 {
+ extend MessageSet {
+ optional Ext1 message_set_extension = 1000;
+ }
+ optional int32 ext1_field1 = 1;
+ optional int32 ext1_field2 = 2;
+}
+
+message Ext2 {
+ extend MessageSet {
+ optional Ext2 message_set_extension = 1001;
+ }
+ optional int32 ext2_field1 = 1;
+}
diff --git a/proto/decode.go b/proto/decode.go
index 0b98366..f147e68 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -5,6 +5,7 @@
package proto
import (
+ "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/pragma"
@@ -68,8 +69,11 @@
}
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
- messageDesc := m.Descriptor()
- fieldDescs := messageDesc.Fields()
+ md := m.Descriptor()
+ if messageset.IsMessageSet(md) {
+ return unmarshalMessageSet(b, m, o)
+ }
+ fields := md.Fields()
for len(b) > 0 {
// Parse the tag (field number and wire type).
num, wtyp, tagLen := wire.ConsumeTag(b)
@@ -78,9 +82,9 @@
}
// Parse the field value.
- fd := fieldDescs.ByNumber(num)
- if fd == nil && messageDesc.ExtensionRanges().Has(num) {
- extType, err := o.Resolver.FindExtensionByNumber(messageDesc.FullName(), num)
+ fd := fields.ByNumber(num)
+ if fd == nil && md.ExtensionRanges().Has(num) {
+ extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
if err != nil && err != protoregistry.NotFound {
return err
}
diff --git a/proto/encode.go b/proto/encode.go
index bc7da0c..5511d16 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -7,6 +7,7 @@
import (
"sort"
+ "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/mapsort"
"google.golang.org/protobuf/internal/pragma"
@@ -111,6 +112,9 @@
}
func (o MarshalOptions) marshalMessageSlow(b []byte, m protoreflect.Message) ([]byte, error) {
+ if messageset.IsMessageSet(m.Descriptor()) {
+ return marshalMessageSet(b, m, o)
+ }
// There are many choices for what order we visit fields in. The default one here
// is chosen for reasonable efficiency and simplicity given the protoreflect API.
// It is not deterministic, since Message.Range does not return fields in any
diff --git a/proto/messageset.go b/proto/messageset.go
new file mode 100644
index 0000000..1c6ac29
--- /dev/null
+++ b/proto/messageset.go
@@ -0,0 +1,102 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+ "google.golang.org/protobuf/internal/encoding/messageset"
+ "google.golang.org/protobuf/internal/encoding/wire"
+ "google.golang.org/protobuf/internal/errors"
+ "google.golang.org/protobuf/internal/flags"
+ "google.golang.org/protobuf/reflect/protoreflect"
+ "google.golang.org/protobuf/reflect/protoregistry"
+)
+
+func sizeMessageSet(m protoreflect.Message) (size int) {
+ m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+ size += messageset.SizeField(fd.Number())
+ size += wire.SizeTag(messageset.FieldMessage)
+ size += wire.SizeBytes(sizeMessage(v.Message()))
+ return true
+ })
+ size += len(m.GetUnknown())
+ return size
+}
+
+func marshalMessageSet(b []byte, m protoreflect.Message, o MarshalOptions) ([]byte, error) {
+ if !flags.Proto1Legacy {
+ return b, errors.New("no support for message_set_wire_format")
+ }
+ var err error
+ o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+ b, err = marshalMessageSetField(b, fd, v, o)
+ return err == nil
+ })
+ if err != nil {
+ return b, err
+ }
+ b = append(b, m.GetUnknown()...)
+ return b, nil
+}
+
+func marshalMessageSetField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value, o MarshalOptions) ([]byte, error) {
+ b = messageset.AppendFieldStart(b, fd.Number())
+ b = wire.AppendTag(b, messageset.FieldMessage, wire.BytesType)
+ b = wire.AppendVarint(b, uint64(o.Size(value.Message().Interface())))
+ b, err := o.marshalMessage(b, value.Message())
+ if err != nil {
+ return b, err
+ }
+ b = messageset.AppendFieldEnd(b)
+ return b, nil
+}
+
+func unmarshalMessageSet(b []byte, m protoreflect.Message, o UnmarshalOptions) error {
+ if !flags.Proto1Legacy {
+ return errors.New("no support for message_set_wire_format")
+ }
+ md := m.Descriptor()
+ for len(b) > 0 {
+ err := func() error {
+ num, v, n, err := messageset.ConsumeField(b)
+ if err != nil {
+ // Not a message set field.
+ //
+ // Return errUnknown to try to add this to the unknown fields.
+ // If the field is completely unparsable, we'll catch it
+ // when trying to skip the field.
+ return errUnknown
+ }
+ if !md.ExtensionRanges().Has(num) {
+ return errUnknown
+ }
+ fd, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
+ if err == protoregistry.NotFound {
+ return errUnknown
+ }
+ if err != nil {
+ return err
+ }
+ if err := o.unmarshalMessage(v, m.Mutable(fd).Message()); err != nil {
+ // Contents cannot be unmarshaled.
+ return err
+ }
+ b = b[n:]
+ return nil
+ }()
+ if err == errUnknown {
+ _, _, n := wire.ConsumeField(b)
+ if n < 0 {
+ return wire.ParseError(n)
+ }
+ m.SetUnknown(append(m.GetUnknown(), b[:n]...))
+ b = b[n:]
+ continue
+ }
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/proto/messageset_test.go b/proto/messageset_test.go
new file mode 100644
index 0000000..07b2d70
--- /dev/null
+++ b/proto/messageset_test.go
@@ -0,0 +1,190 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package proto_test
+
+import (
+ "google.golang.org/protobuf/internal/encoding/pack"
+ "google.golang.org/protobuf/internal/flags"
+ "google.golang.org/protobuf/proto"
+
+ messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb"
+ msetextpb "google.golang.org/protobuf/internal/testprotos/messageset/msetextpb"
+)
+
+func init() {
+ if flags.Proto1Legacy {
+ testProtos = append(testProtos, messageSetTestProtos...)
+ }
+}
+
+var messageSetTestProtos = []testProto{
+ {
+ desc: "MessageSet type_id before message content",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+ Ext1Field1: proto.Int32(10),
+ }),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet type_id after message content",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+ Ext1Field1: proto.Int32(10),
+ }),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet preserves unknown field",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+ Ext1Field1: proto.Int32(10),
+ }),
+ unknown(pack.Message{
+ pack.Tag{4, pack.VarintType}, pack.Varint(30),
+ }.Marshal()),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ // Unknown field
+ pack.Tag{4, pack.VarintType}, pack.Varint(30),
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet with unknown type_id",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ unknown(pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1002),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal()),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1002),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet merges repeated message fields in item",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+ Ext1Field1: proto.Int32(10),
+ Ext1Field2: proto.Int32(20),
+ }),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{2, pack.VarintType}, pack.Varint(20),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet merges message fields in repeated items",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+ Ext1Field1: proto.Int32(10),
+ Ext1Field2: proto.Int32(20),
+ }),
+ extend(msetextpb.E_Ext2_MessageSetExtension, &msetextpb.Ext2{
+ Ext2Field1: proto.Int32(30),
+ }),
+ )},
+ wire: pack.Message{
+ // Ext1, field1
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ // Ext2, field1
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1001),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(30),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ // Ext2, field2
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{2, pack.VarintType}, pack.Varint(20),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet with missing type_id",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ unknown(pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal()),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet with missing message",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{}),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+}
diff --git a/proto/size.go b/proto/size.go
index 1266143..9947580 100644
--- a/proto/size.go
+++ b/proto/size.go
@@ -5,6 +5,7 @@
package proto
import (
+ "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
@@ -28,6 +29,9 @@
}
func sizeMessageSlow(m protoreflect.Message) (size int) {
+ if messageset.IsMessageSet(m.Descriptor()) {
+ return sizeMessageSet(m)
+ }
m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
size += sizeField(fd, v)
return true