blob: 1c6ac299bc2b99103c1d51122bb6ae07ffea5e55 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "google.golang.org/protobuf/internal/encoding/messageset"
9 "google.golang.org/protobuf/internal/encoding/wire"
10 "google.golang.org/protobuf/internal/errors"
11 "google.golang.org/protobuf/internal/flags"
12 "google.golang.org/protobuf/reflect/protoreflect"
13 "google.golang.org/protobuf/reflect/protoregistry"
14)
15
16func sizeMessageSet(m protoreflect.Message) (size int) {
17 m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
18 size += messageset.SizeField(fd.Number())
19 size += wire.SizeTag(messageset.FieldMessage)
20 size += wire.SizeBytes(sizeMessage(v.Message()))
21 return true
22 })
23 size += len(m.GetUnknown())
24 return size
25}
26
27func marshalMessageSet(b []byte, m protoreflect.Message, o MarshalOptions) ([]byte, error) {
28 if !flags.Proto1Legacy {
29 return b, errors.New("no support for message_set_wire_format")
30 }
31 var err error
32 o.rangeFields(m, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
33 b, err = marshalMessageSetField(b, fd, v, o)
34 return err == nil
35 })
36 if err != nil {
37 return b, err
38 }
39 b = append(b, m.GetUnknown()...)
40 return b, nil
41}
42
43func marshalMessageSetField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value, o MarshalOptions) ([]byte, error) {
44 b = messageset.AppendFieldStart(b, fd.Number())
45 b = wire.AppendTag(b, messageset.FieldMessage, wire.BytesType)
46 b = wire.AppendVarint(b, uint64(o.Size(value.Message().Interface())))
47 b, err := o.marshalMessage(b, value.Message())
48 if err != nil {
49 return b, err
50 }
51 b = messageset.AppendFieldEnd(b)
52 return b, nil
53}
54
55func unmarshalMessageSet(b []byte, m protoreflect.Message, o UnmarshalOptions) error {
56 if !flags.Proto1Legacy {
57 return errors.New("no support for message_set_wire_format")
58 }
59 md := m.Descriptor()
60 for len(b) > 0 {
61 err := func() error {
62 num, v, n, err := messageset.ConsumeField(b)
63 if err != nil {
64 // Not a message set field.
65 //
66 // Return errUnknown to try to add this to the unknown fields.
67 // If the field is completely unparsable, we'll catch it
68 // when trying to skip the field.
69 return errUnknown
70 }
71 if !md.ExtensionRanges().Has(num) {
72 return errUnknown
73 }
74 fd, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
75 if err == protoregistry.NotFound {
76 return errUnknown
77 }
78 if err != nil {
79 return err
80 }
81 if err := o.unmarshalMessage(v, m.Mutable(fd).Message()); err != nil {
82 // Contents cannot be unmarshaled.
83 return err
84 }
85 b = b[n:]
86 return nil
87 }()
88 if err == errUnknown {
89 _, _, n := wire.ConsumeField(b)
90 if n < 0 {
91 return wire.ParseError(n)
92 }
93 m.SetUnknown(append(m.GetUnknown(), b[:n]...))
94 b = b[n:]
95 continue
96 }
97 if err != nil {
98 return err
99 }
100 }
101 return nil
102}