blob: b7b5477aa034b522fe908c5cf83ddd9ce878f9d6 [file] [log] [blame]
Damien Neil302cb322019-06-19 15:22:13 -07001// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package messageset encodes and decodes the obsolete MessageSet wire format.
6package messageset
7
8import (
9 "google.golang.org/protobuf/internal/encoding/wire"
10 "google.golang.org/protobuf/internal/errors"
11 pref "google.golang.org/protobuf/reflect/protoreflect"
Joe Tsai5ae10aa2019-07-11 18:23:08 -070012 preg "google.golang.org/protobuf/reflect/protoregistry"
Damien Neil302cb322019-06-19 15:22:13 -070013)
14
15// The MessageSet wire format is equivalent to a message defiend as follows,
16// where each Item defines an extension field with a field number of 'type_id'
17// and content of 'message'. MessageSet extensions must be non-repeated message
18// fields.
19//
20// message MessageSet {
21// repeated group Item = 1 {
22// required int32 type_id = 2;
23// required string message = 3;
24// }
25// }
26const (
27 FieldItem = wire.Number(1)
28 FieldTypeID = wire.Number(2)
29 FieldMessage = wire.Number(3)
30)
31
// ExtensionName is the field name for extensions of MessageSet.
//
// A valid MessageSet extension must be declared inside the message it carries,
// in the form:
//
//	message MyMessage {
//		extend proto2.bridge.MessageSet {
//			optional MyMessage message_set_extension = 1234;
//		}
//		...
//	}
const (
	ExtensionName = "message_set_extension"
)
42
Damien Neil302cb322019-06-19 15:22:13 -070043// IsMessageSet returns whether the message uses the MessageSet wire format.
44func IsMessageSet(md pref.MessageDescriptor) bool {
45 xmd, ok := md.(interface{ IsMessageSet() bool })
46 return ok && xmd.IsMessageSet()
47}
48
Joe Tsai5ae10aa2019-07-11 18:23:08 -070049// IsMessageSetExtension reports this field extends a MessageSet.
50func IsMessageSetExtension(fd pref.FieldDescriptor) bool {
51 if fd.Name() != ExtensionName {
52 return false
53 }
54 if fd.FullName().Parent() != fd.Message().FullName() {
55 return false
56 }
57 return IsMessageSet(fd.ContainingMessage())
58}
59
60// FindMessageSetExtension locates a MessageSet extension field by name.
61// In text and JSON formats, the extension name used is the message itself.
62// The extension field name is derived by appending ExtensionName.
63func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pref.ExtensionType, error) {
64 xt, err := r.FindExtensionByName(s.Append(ExtensionName))
65 if err != nil {
66 return nil, err
67 }
Damien Neil79bfdbe2019-08-28 11:08:22 -070068 if !IsMessageSetExtension(xt.TypeDescriptor()) {
Joe Tsai5ae10aa2019-07-11 18:23:08 -070069 return nil, preg.NotFound
70 }
71 return xt, nil
72}
73
Damien Neil302cb322019-06-19 15:22:13 -070074// SizeField returns the size of a MessageSet item field containing an extension
75// with the given field number, not counting the contents of the message subfield.
76func SizeField(num wire.Number) int {
77 return 2*wire.SizeTag(FieldItem) + wire.SizeTag(FieldTypeID) + wire.SizeVarint(uint64(num))
78}
79
80// ConsumeField parses a MessageSet item field and returns the contents of the
81// type_id and message subfields and the total item length.
82func ConsumeField(b []byte) (typeid wire.Number, message []byte, n int, err error) {
83 num, wtyp, n := wire.ConsumeTag(b)
84 if n < 0 {
85 return 0, nil, 0, wire.ParseError(n)
86 }
87 if num != FieldItem || wtyp != wire.StartGroupType {
88 return 0, nil, 0, errors.New("invalid MessageSet field number")
89 }
90 typeid, message, fieldLen, err := ConsumeFieldValue(b[n:], false)
91 if err != nil {
92 return 0, nil, 0, err
93 }
94 return typeid, message, n + fieldLen, nil
95}
96
97// ConsumeFieldValue parses b as a MessageSet item field value until and including
98// the trailing end group marker. It assumes the start group tag has already been parsed.
99// It returns the contents of the type_id and message subfields and the total
100// item length.
101//
102// If wantLen is true, the returned message value includes the length prefix.
103// This is ugly, but simplifies the fast-path decoder in internal/impl.
104func ConsumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) {
105 ilen := len(b)
106 for {
107 num, wtyp, n := wire.ConsumeTag(b)
108 if n < 0 {
109 return 0, nil, 0, wire.ParseError(n)
110 }
111 b = b[n:]
112 switch {
113 case num == FieldItem && wtyp == wire.EndGroupType:
114 if wantLen && len(message) == 0 {
115 // The message field was missing, which should never happen.
116 // Be prepared for this case anyway.
117 message = wire.AppendVarint(message, 0)
118 }
119 return typeid, message, ilen - len(b), nil
120 case num == FieldTypeID && wtyp == wire.VarintType:
121 v, n := wire.ConsumeVarint(b)
122 if n < 0 {
123 return 0, nil, 0, wire.ParseError(n)
124 }
125 b = b[n:]
126 typeid = wire.Number(v)
127 case num == FieldMessage && wtyp == wire.BytesType:
128 m, n := wire.ConsumeBytes(b)
129 if n < 0 {
130 return 0, nil, 0, wire.ParseError(n)
131 }
132 if message == nil {
133 if wantLen {
134 message = b[:n]
135 } else {
136 message = m
137 }
138 } else {
139 // This case should never happen in practice, but handle it for
140 // correctness: The MessageSet item contains multiple message
141 // fields, which need to be merged.
142 //
143 // In the case where we're returning the length, this becomes
144 // quite inefficient since we need to strip the length off
145 // the existing data and reconstruct it with the combined length.
146 if wantLen {
147 _, nn := wire.ConsumeVarint(message)
148 m0 := message[nn:]
149 message = message[:0]
150 message = wire.AppendVarint(message, uint64(len(m0)+len(m)))
151 message = append(message, m0...)
152 message = append(message, m...)
153 } else {
154 message = append(message, m...)
155 }
156 }
157 b = b[n:]
158 }
159 }
160}
161
162// AppendFieldStart appends the start of a MessageSet item field containing
163// an extension with the given number. The caller must add the message
164// subfield (including the tag).
165func AppendFieldStart(b []byte, num wire.Number) []byte {
166 b = wire.AppendTag(b, FieldItem, wire.StartGroupType)
167 b = wire.AppendTag(b, FieldTypeID, wire.VarintType)
168 b = wire.AppendVarint(b, uint64(num))
169 return b
170}
171
172// AppendFieldEnd appends the trailing end group marker for a MessageSet item field.
173func AppendFieldEnd(b []byte) []byte {
174 return wire.AppendTag(b, FieldItem, wire.EndGroupType)
175}