blob: bf7308a81dfe48bdb11cea404480624d86458bf0 [file] [log] [blame]
Damien Neilc37adef2019-04-01 13:49:56 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
import (
	"math"
	"sort"
	"sync/atomic"

	proto "google.golang.org/protobuf/proto"
	pref "google.golang.org/protobuf/reflect/protoreflect"
	piface "google.golang.org/protobuf/runtime/protoiface"
)
15
// marshalOptions packs the boolean settings of MarshalOptions into a compact
// bit set that is cheap to pass through the fast-path coders.
//
// AllowPartial is intentionally not represented: fast-path (un)marshal
// operations always allow partial messages.
type marshalOptions uint

// Bit flags stored in a marshalOptions value.
const (
	marshalDeterministic marshalOptions = 1 << 0 // mirrors MarshalOptions.Deterministic
	marshalUseCachedSize marshalOptions = 1 << 1 // mirrors MarshalOptions.UseCachedSize
)
26
27func newMarshalOptions(opts piface.MarshalOptions) marshalOptions {
28 var o marshalOptions
Damien Neilc37adef2019-04-01 13:49:56 -070029 if opts.Deterministic {
30 o |= marshalDeterministic
31 }
32 if opts.UseCachedSize {
33 o |= marshalUseCachedSize
34 }
35 return o
36}
37
38func (o marshalOptions) Options() proto.MarshalOptions {
39 return proto.MarshalOptions{
Damien Neile91877d2019-06-27 10:54:42 -070040 AllowPartial: true,
Damien Neilc37adef2019-04-01 13:49:56 -070041 Deterministic: o.Deterministic(),
42 UseCachedSize: o.UseCachedSize(),
43 }
44}
45
Damien Neilc37adef2019-04-01 13:49:56 -070046func (o marshalOptions) Deterministic() bool { return o&marshalDeterministic != 0 }
47func (o marshalOptions) UseCachedSize() bool { return o&marshalUseCachedSize != 0 }
48
49// size is protoreflect.Methods.Size.
Joe Tsai4fe96632019-05-22 05:12:36 -040050func (mi *MessageInfo) size(msg pref.ProtoMessage) (size int) {
Damien Neilc37adef2019-04-01 13:49:56 -070051 return mi.sizePointer(pointerOfIface(msg), 0)
52}
53
Joe Tsai4fe96632019-05-22 05:12:36 -040054func (mi *MessageInfo) sizePointer(p pointer, opts marshalOptions) (size int) {
Damien Neilc37adef2019-04-01 13:49:56 -070055 mi.init()
56 if p.IsNil() {
57 return 0
58 }
59 if opts.UseCachedSize() && mi.sizecacheOffset.IsValid() {
60 return int(atomic.LoadInt32(p.Apply(mi.sizecacheOffset).Int32()))
61 }
62 return mi.sizePointerSlow(p, opts)
63}
64
Joe Tsai4fe96632019-05-22 05:12:36 -040065func (mi *MessageInfo) sizePointerSlow(p pointer, opts marshalOptions) (size int) {
Damien Neilc37adef2019-04-01 13:49:56 -070066 if mi.extensionOffset.IsValid() {
67 e := p.Apply(mi.extensionOffset).Extensions()
68 size += mi.sizeExtensions(e, opts)
69 }
Damien Neil4ae30bb2019-06-20 10:12:23 -070070 for _, f := range mi.orderedCoderFields {
Damien Neilc37adef2019-04-01 13:49:56 -070071 fptr := p.Apply(f.offset)
72 if f.isPointer && fptr.Elem().IsNil() {
73 continue
74 }
75 if f.funcs.size == nil {
76 continue
77 }
78 size += f.funcs.size(fptr, f.tagsize, opts)
79 }
80 if mi.unknownOffset.IsValid() {
81 u := *p.Apply(mi.unknownOffset).Bytes()
82 size += len(u)
83 }
84 if mi.sizecacheOffset.IsValid() {
85 atomic.StoreInt32(p.Apply(mi.sizecacheOffset).Int32(), int32(size))
86 }
87 return size
88}
89
90// marshalAppend is protoreflect.Methods.MarshalAppend.
Joe Tsai4fe96632019-05-22 05:12:36 -040091func (mi *MessageInfo) marshalAppend(b []byte, msg pref.ProtoMessage, opts piface.MarshalOptions) ([]byte, error) {
Damien Neilc37adef2019-04-01 13:49:56 -070092 return mi.marshalAppendPointer(b, pointerOfIface(msg), newMarshalOptions(opts))
93}
94
Joe Tsai4fe96632019-05-22 05:12:36 -040095func (mi *MessageInfo) marshalAppendPointer(b []byte, p pointer, opts marshalOptions) ([]byte, error) {
Damien Neilc37adef2019-04-01 13:49:56 -070096 mi.init()
97 if p.IsNil() {
98 return b, nil
99 }
100 var err error
Damien Neilc37adef2019-04-01 13:49:56 -0700101 // The old marshaler encodes extensions at beginning.
102 if mi.extensionOffset.IsValid() {
103 e := p.Apply(mi.extensionOffset).Extensions()
104 // TODO: Special handling for MessageSet?
105 b, err = mi.appendExtensions(b, e, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700106 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700107 return b, err
108 }
109 }
Damien Neil4ae30bb2019-06-20 10:12:23 -0700110 for _, f := range mi.orderedCoderFields {
Damien Neilc37adef2019-04-01 13:49:56 -0700111 fptr := p.Apply(f.offset)
112 if f.isPointer && fptr.Elem().IsNil() {
113 continue
114 }
115 if f.funcs.marshal == nil {
116 continue
117 }
118 b, err = f.funcs.marshal(b, fptr, f.wiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700119 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700120 return b, err
121 }
122 }
123 if mi.unknownOffset.IsValid() {
124 u := *p.Apply(mi.unknownOffset).Bytes()
125 b = append(b, u...)
126 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700127 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700128}
129
Joe Tsai89d49632019-06-04 16:20:00 -0700130func (mi *MessageInfo) sizeExtensions(ext *map[int32]ExtensionField, opts marshalOptions) (n int) {
Damien Neilc37adef2019-04-01 13:49:56 -0700131 if ext == nil {
132 return 0
133 }
Joe Tsai89d49632019-06-04 16:20:00 -0700134 for _, x := range *ext {
135 xi := mi.extensionFieldInfo(x.GetType())
136 if xi.funcs.size == nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700137 continue
138 }
Joe Tsai89d49632019-06-04 16:20:00 -0700139 n += xi.funcs.size(x.GetValue(), xi.tagsize, opts)
Damien Neilc37adef2019-04-01 13:49:56 -0700140 }
141 return n
142}
143
Joe Tsai89d49632019-06-04 16:20:00 -0700144func (mi *MessageInfo) appendExtensions(b []byte, ext *map[int32]ExtensionField, opts marshalOptions) ([]byte, error) {
Damien Neilc37adef2019-04-01 13:49:56 -0700145 if ext == nil {
146 return b, nil
147 }
148
149 switch len(*ext) {
150 case 0:
151 return b, nil
152 case 1:
153 // Fast-path for one extension: Don't bother sorting the keys.
154 var err error
Joe Tsai89d49632019-06-04 16:20:00 -0700155 for _, x := range *ext {
156 xi := mi.extensionFieldInfo(x.GetType())
157 b, err = xi.funcs.marshal(b, x.GetValue(), xi.wiretag, opts)
Damien Neilc37adef2019-04-01 13:49:56 -0700158 }
159 return b, err
160 default:
161 // Sort the keys to provide a deterministic encoding.
162 // Not sure this is required, but the old code does it.
163 keys := make([]int, 0, len(*ext))
164 for k := range *ext {
165 keys = append(keys, int(k))
166 }
167 sort.Ints(keys)
168 var err error
Damien Neilc37adef2019-04-01 13:49:56 -0700169 for _, k := range keys {
Joe Tsai89d49632019-06-04 16:20:00 -0700170 x := (*ext)[int32(k)]
171 xi := mi.extensionFieldInfo(x.GetType())
172 b, err = xi.funcs.marshal(b, x.GetValue(), xi.wiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700173 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700174 return b, err
175 }
176 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700177 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700178 }
179}