proto: support message_set_wire_format
MessageSets are a deprecated proto1 feature, long since superseded by
extensions. Add disabled-by-default support behind flags.Proto1Legacy.
Change-Id: I7d3ace07f3b0efd59673034f3dc633b908345a88
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185538
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/proto/decode.go b/proto/decode.go
index 0b98366..f147e68 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -5,6 +5,7 @@
package proto
import (
+ "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/pragma"
@@ -68,8 +69,11 @@
}
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
- messageDesc := m.Descriptor()
- fieldDescs := messageDesc.Fields()
+ md := m.Descriptor()
+ if messageset.IsMessageSet(md) {
+ return unmarshalMessageSet(b, m, o)
+ }
+ fields := md.Fields()
for len(b) > 0 {
// Parse the tag (field number and wire type).
num, wtyp, tagLen := wire.ConsumeTag(b)
@@ -78,9 +82,9 @@
}
// Parse the field value.
- fd := fieldDescs.ByNumber(num)
- if fd == nil && messageDesc.ExtensionRanges().Has(num) {
- extType, err := o.Resolver.FindExtensionByNumber(messageDesc.FullName(), num)
+ fd := fields.ByNumber(num)
+ if fd == nil && md.ExtensionRanges().Has(num) {
+ extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
if err != nil && err != protoregistry.NotFound {
return err
}
diff --git a/proto/encode.go b/proto/encode.go
index bc7da0c..5511d16 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -7,6 +7,7 @@
import (
"sort"
+ "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/mapsort"
"google.golang.org/protobuf/internal/pragma"
@@ -111,6 +112,9 @@
}
func (o MarshalOptions) marshalMessageSlow(b []byte, m protoreflect.Message) ([]byte, error) {
+ if messageset.IsMessageSet(m.Descriptor()) {
+ return marshalMessageSet(b, m, o)
+ }
// There are many choices for what order we visit fields in. The default one here
// is chosen for reasonable efficiency and simplicity given the protoreflect API.
// It is not deterministic, since Message.Range does not return fields in any
diff --git a/proto/messageset.go b/proto/messageset.go
new file mode 100644
index 0000000..1c6ac29
--- /dev/null
+++ b/proto/messageset.go
@@ -0,0 +1,102 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+ "google.golang.org/protobuf/internal/encoding/messageset"
+ "google.golang.org/protobuf/internal/encoding/wire"
+ "google.golang.org/protobuf/internal/errors"
+ "google.golang.org/protobuf/internal/flags"
+ "google.golang.org/protobuf/reflect/protoreflect"
+ "google.golang.org/protobuf/reflect/protoregistry"
+)
+
+// sizeMessageSet reports how many bytes m occupies in MessageSet wire format.
+//
+// Every populated field is counted as one MessageSet item: the per-item
+// overhead reported by messageset.SizeField for the field number, plus the
+// tag and length-prefixed payload of the item's message field. Unknown
+// fields contribute their raw byte length.
+func sizeMessageSet(m protoreflect.Message) (size int) {
+	size = len(m.GetUnknown())
+	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+		framing := messageset.SizeField(fd.Number())
+		payload := wire.SizeBytes(sizeMessage(v.Message()))
+		size += framing + wire.SizeTag(messageset.FieldMessage) + payload
+		return true
+	})
+	return size
+}
+
+// marshalMessageSet appends the MessageSet wire encoding of m to b.
+//
+// It refuses to run unless flags.Proto1Legacy is enabled. Each field visited
+// by o.rangeFields is written as a single MessageSet item through
+// marshalMessageSetField. The first failure stops the walk, and that error is
+// returned together with whatever bytes were appended so far. Unknown fields
+// are appended unchanged after all items.
+func marshalMessageSet(b []byte, m protoreflect.Message, o MarshalOptions) ([]byte, error) {
+	if !flags.Proto1Legacy {
+		return b, errors.New("no support for message_set_wire_format")
+	}
+	var itemErr error
+	o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+		if b, itemErr = marshalMessageSetField(b, fd, v, o); itemErr != nil {
+			return false
+		}
+		return true
+	})
+	if itemErr != nil {
+		return b, itemErr
+	}
+	return append(b, m.GetUnknown()...), nil
+}
+
+// marshalMessageSetField appends to b one MessageSet item that carries the
+// field fd, whose value is the message held in value.
+//
+// The item is framed by messageset.AppendFieldStart and AppendFieldEnd. Inside
+// it, the message is written as a length-prefixed field numbered
+// messageset.FieldMessage.
+//
+// The length prefix comes from o.Size and is written before the payload is
+// marshaled. The function therefore checks that marshaling produced exactly
+// that many bytes. Without the check, a mismatch would leave a prefix that
+// disagrees with its payload and silently corrupt everything after it.
+func marshalMessageSetField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value, o MarshalOptions) ([]byte, error) {
+	b = messageset.AppendFieldStart(b, fd.Number())
+	b = wire.AppendTag(b, messageset.FieldMessage, wire.BytesType)
+	calculatedSize := o.Size(value.Message().Interface())
+	b = wire.AppendVarint(b, uint64(calculatedSize))
+	before := len(b)
+	b, err := o.marshalMessage(b, value.Message())
+	if err != nil {
+		return b, err
+	}
+	// Guard against a length prefix that does not describe the payload.
+	if measuredSize := len(b) - before; measuredSize != calculatedSize {
+		return b, errors.New("message_set_wire_format: size mismatch (calculated=%d, measured=%d)", calculatedSize, measuredSize)
+	}
+	b = messageset.AppendFieldEnd(b)
+	return b, nil
+}
+
+// unmarshalMessageSet parses b as a message in MessageSet wire format and
+// merges its contents into m.
+//
+// It refuses to run unless flags.Proto1Legacy is enabled. Each field in b is
+// handled in one of two ways:
+//
+//   - A well-formed MessageSet item is unmarshaled into the matching extension
+//     of m. This requires its type_id to lie in m's extension ranges and to
+//     resolve through o.Resolver.
+//   - Anything else is appended verbatim to m's unknown fields. That covers
+//     fields that are not MessageSet items, type_ids outside the extension
+//     ranges, and type_ids with no registered extension.
+//
+// Errors are returned for resolver failures other than
+// protoregistry.NotFound, for item payloads that fail to unmarshal, and for
+// fields that cannot even be skipped.
+func unmarshalMessageSet(b []byte, m protoreflect.Message, o UnmarshalOptions) error {
+	if !flags.Proto1Legacy {
+		return errors.New("no support for message_set_wire_format")
+	}
+	md := m.Descriptor()
+	for len(b) > 0 {
+		// decodeItem consumes one recognized MessageSet item from the front
+		// of b. It returns errUnknown when the next field should instead be
+		// preserved as an unknown field.
+		decodeItem := func() error {
+			num, payload, itemLen, err := messageset.ConsumeField(b)
+			// A parse error here only means the field is not a well-formed
+			// MessageSet item, so it is treated as unknown. A truly
+			// unparsable field is rejected when it is skipped below.
+			if err != nil || !md.ExtensionRanges().Has(num) {
+				return errUnknown
+			}
+			xt, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
+			switch {
+			case err == protoregistry.NotFound:
+				return errUnknown
+			case err != nil:
+				return err
+			}
+			if err := o.unmarshalMessage(payload, m.Mutable(xt).Message()); err != nil {
+				return err
+			}
+			b = b[itemLen:]
+			return nil
+		}
+		switch err := decodeItem(); {
+		case err == errUnknown:
+			_, _, skip := wire.ConsumeField(b)
+			if skip < 0 {
+				return wire.ParseError(skip)
+			}
+			m.SetUnknown(append(m.GetUnknown(), b[:skip]...))
+			b = b[skip:]
+		case err != nil:
+			return err
+		}
+	}
+	return nil
+}
diff --git a/proto/messageset_test.go b/proto/messageset_test.go
new file mode 100644
index 0000000..07b2d70
--- /dev/null
+++ b/proto/messageset_test.go
@@ -0,0 +1,190 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+package proto_test
+
+import (
+ "google.golang.org/protobuf/internal/encoding/pack"
+ "google.golang.org/protobuf/internal/flags"
+ "google.golang.org/protobuf/proto"
+
+ messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb"
+ msetextpb "google.golang.org/protobuf/internal/testprotos/messageset/msetextpb"
+)
+
+// init adds the MessageSet cases to the shared testProtos table, but only
+// when flags.Proto1Legacy is enabled.
+func init() {
+	if !flags.Proto1Legacy {
+		return
+	}
+	testProtos = append(testProtos, messageSetTestProtos...)
+}
+
+// messageSetTestProtos are decode cases for messagesetpb.MessageSet. init
+// appends them to testProtos only when flags.Proto1Legacy is set.
+//
+// In each wire encoding, a MessageSet item is a group with field number 1.
+// The group holds the extension's type_id as a varint in field 2 and the
+// extension's message as length-prefixed bytes in field 3. type_id 1000
+// decodes as the Ext1 extension and 1001 as Ext2. type_id 1002 is not
+// resolved and ends up in the unknown fields.
+var messageSetTestProtos = []testProto{
+ {
+ desc: "MessageSet type_id before message content",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+ Ext1Field1: proto.Int32(10),
+ }),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet type_id after message content",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+ Ext1Field1: proto.Int32(10),
+ }),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet preserves unknown field",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+ Ext1Field1: proto.Int32(10),
+ }),
+ unknown(pack.Message{
+ pack.Tag{4, pack.VarintType}, pack.Varint(30),
+ }.Marshal()),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ // Unknown field
+ pack.Tag{4, pack.VarintType}, pack.Varint(30),
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet with unknown type_id",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ unknown(pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1002),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal()),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1002),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet merges repeated message fields in item",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+ Ext1Field1: proto.Int32(10),
+ Ext1Field2: proto.Int32(20),
+ }),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{2, pack.VarintType}, pack.Varint(20),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet merges message fields in repeated items",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{
+ Ext1Field1: proto.Int32(10),
+ Ext1Field2: proto.Int32(20),
+ }),
+ extend(msetextpb.E_Ext2_MessageSetExtension, &msetextpb.Ext2{
+ Ext2Field1: proto.Int32(30),
+ }),
+ )},
+ wire: pack.Message{
+ // Ext1, field1
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ // Ext2, field1
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1001),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(30),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ // Ext1, field2 (a second item for type_id 1000, merged into the first)
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{2, pack.VarintType}, pack.Varint(20),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet with missing type_id",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ unknown(pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal()),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(10),
+ }),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "MessageSet with missing message",
+ decodeTo: []proto.Message{build(
+ &messagesetpb.MessageSet{},
+ extend(msetextpb.E_Ext1_MessageSetExtension, &msetextpb.Ext1{}),
+ )},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1000),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+}
diff --git a/proto/size.go b/proto/size.go
index 1266143..9947580 100644
--- a/proto/size.go
+++ b/proto/size.go
@@ -5,6 +5,7 @@
package proto
import (
+ "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"
@@ -28,6 +29,9 @@
}
func sizeMessageSlow(m protoreflect.Message) (size int) {
+ if messageset.IsMessageSet(m.Descriptor()) {
+ return sizeMessageSet(m)
+ }
m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
size += sizeField(fd, v)
return true