blob: 010d412f83a4dab232dbe80a743c2e0ea72fa051 [file] [log] [blame]
Damien Neilc37adef2019-04-01 13:49:56 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
import (
	"math"
	"sort"
	"sync/atomic"

	proto "google.golang.org/protobuf/proto"
	pref "google.golang.org/protobuf/reflect/protoreflect"
	piface "google.golang.org/protobuf/runtime/protoiface"
)
15
// marshalOptions is a compact bit-set form of the MarshalOptions fields that
// the fast-path encoder consults.
//
// AllowPartial is intentionally not represented: fast-path (un)marshal
// operations always allow partial messages.
type marshalOptions uint

// Bit flags stored in a marshalOptions value.
const (
	marshalDeterministic marshalOptions = 1 << 0 // mirrors MarshalOptions.Deterministic
	marshalUseCachedSize marshalOptions = 1 << 1 // mirrors MarshalOptions.UseCachedSize
)
26
27func newMarshalOptions(opts piface.MarshalOptions) marshalOptions {
28 var o marshalOptions
Damien Neilc37adef2019-04-01 13:49:56 -070029 if opts.Deterministic {
30 o |= marshalDeterministic
31 }
32 if opts.UseCachedSize {
33 o |= marshalUseCachedSize
34 }
35 return o
36}
37
38func (o marshalOptions) Options() proto.MarshalOptions {
39 return proto.MarshalOptions{
Damien Neile91877d2019-06-27 10:54:42 -070040 AllowPartial: true,
Damien Neilc37adef2019-04-01 13:49:56 -070041 Deterministic: o.Deterministic(),
42 UseCachedSize: o.UseCachedSize(),
43 }
44}
45
Damien Neilc37adef2019-04-01 13:49:56 -070046func (o marshalOptions) Deterministic() bool { return o&marshalDeterministic != 0 }
47func (o marshalOptions) UseCachedSize() bool { return o&marshalUseCachedSize != 0 }
48
49// size is protoreflect.Methods.Size.
Joe Tsai0f81b382019-07-10 23:14:31 -070050func (mi *MessageInfo) size(m pref.Message) (size int) {
51 var p pointer
52 if ms, ok := m.(*messageState); ok {
53 p = ms.pointer()
54 } else {
55 p = m.(*messageReflectWrapper).pointer()
56 }
57 return mi.sizePointer(p, 0)
Damien Neilc37adef2019-04-01 13:49:56 -070058}
59
Joe Tsai4fe96632019-05-22 05:12:36 -040060func (mi *MessageInfo) sizePointer(p pointer, opts marshalOptions) (size int) {
Damien Neilc37adef2019-04-01 13:49:56 -070061 mi.init()
62 if p.IsNil() {
63 return 0
64 }
65 if opts.UseCachedSize() && mi.sizecacheOffset.IsValid() {
66 return int(atomic.LoadInt32(p.Apply(mi.sizecacheOffset).Int32()))
67 }
68 return mi.sizePointerSlow(p, opts)
69}
70
Joe Tsai4fe96632019-05-22 05:12:36 -040071func (mi *MessageInfo) sizePointerSlow(p pointer, opts marshalOptions) (size int) {
Damien Neilc37adef2019-04-01 13:49:56 -070072 if mi.extensionOffset.IsValid() {
73 e := p.Apply(mi.extensionOffset).Extensions()
74 size += mi.sizeExtensions(e, opts)
75 }
Damien Neil4ae30bb2019-06-20 10:12:23 -070076 for _, f := range mi.orderedCoderFields {
Damien Neilc37adef2019-04-01 13:49:56 -070077 fptr := p.Apply(f.offset)
78 if f.isPointer && fptr.Elem().IsNil() {
79 continue
80 }
81 if f.funcs.size == nil {
82 continue
83 }
84 size += f.funcs.size(fptr, f.tagsize, opts)
85 }
86 if mi.unknownOffset.IsValid() {
87 u := *p.Apply(mi.unknownOffset).Bytes()
88 size += len(u)
89 }
90 if mi.sizecacheOffset.IsValid() {
91 atomic.StoreInt32(p.Apply(mi.sizecacheOffset).Int32(), int32(size))
92 }
93 return size
94}
95
96// marshalAppend is protoreflect.Methods.MarshalAppend.
Joe Tsai0f81b382019-07-10 23:14:31 -070097func (mi *MessageInfo) marshalAppend(b []byte, m pref.Message, opts piface.MarshalOptions) ([]byte, error) {
98 var p pointer
99 if ms, ok := m.(*messageState); ok {
100 p = ms.pointer()
101 } else {
102 p = m.(*messageReflectWrapper).pointer()
103 }
104 return mi.marshalAppendPointer(b, p, newMarshalOptions(opts))
Damien Neilc37adef2019-04-01 13:49:56 -0700105}
106
Joe Tsai4fe96632019-05-22 05:12:36 -0400107func (mi *MessageInfo) marshalAppendPointer(b []byte, p pointer, opts marshalOptions) ([]byte, error) {
Damien Neilc37adef2019-04-01 13:49:56 -0700108 mi.init()
109 if p.IsNil() {
110 return b, nil
111 }
112 var err error
Damien Neilc37adef2019-04-01 13:49:56 -0700113 // The old marshaler encodes extensions at beginning.
114 if mi.extensionOffset.IsValid() {
115 e := p.Apply(mi.extensionOffset).Extensions()
116 // TODO: Special handling for MessageSet?
117 b, err = mi.appendExtensions(b, e, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700118 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700119 return b, err
120 }
121 }
Damien Neil4ae30bb2019-06-20 10:12:23 -0700122 for _, f := range mi.orderedCoderFields {
Damien Neilc37adef2019-04-01 13:49:56 -0700123 fptr := p.Apply(f.offset)
124 if f.isPointer && fptr.Elem().IsNil() {
125 continue
126 }
127 if f.funcs.marshal == nil {
128 continue
129 }
130 b, err = f.funcs.marshal(b, fptr, f.wiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700131 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700132 return b, err
133 }
134 }
135 if mi.unknownOffset.IsValid() {
136 u := *p.Apply(mi.unknownOffset).Bytes()
137 b = append(b, u...)
138 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700139 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700140}
141
Joe Tsai89d49632019-06-04 16:20:00 -0700142func (mi *MessageInfo) sizeExtensions(ext *map[int32]ExtensionField, opts marshalOptions) (n int) {
Damien Neilc37adef2019-04-01 13:49:56 -0700143 if ext == nil {
144 return 0
145 }
Joe Tsai89d49632019-06-04 16:20:00 -0700146 for _, x := range *ext {
147 xi := mi.extensionFieldInfo(x.GetType())
148 if xi.funcs.size == nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700149 continue
150 }
Joe Tsai89d49632019-06-04 16:20:00 -0700151 n += xi.funcs.size(x.GetValue(), xi.tagsize, opts)
Damien Neilc37adef2019-04-01 13:49:56 -0700152 }
153 return n
154}
155
Joe Tsai89d49632019-06-04 16:20:00 -0700156func (mi *MessageInfo) appendExtensions(b []byte, ext *map[int32]ExtensionField, opts marshalOptions) ([]byte, error) {
Damien Neilc37adef2019-04-01 13:49:56 -0700157 if ext == nil {
158 return b, nil
159 }
160
161 switch len(*ext) {
162 case 0:
163 return b, nil
164 case 1:
165 // Fast-path for one extension: Don't bother sorting the keys.
166 var err error
Joe Tsai89d49632019-06-04 16:20:00 -0700167 for _, x := range *ext {
168 xi := mi.extensionFieldInfo(x.GetType())
169 b, err = xi.funcs.marshal(b, x.GetValue(), xi.wiretag, opts)
Damien Neilc37adef2019-04-01 13:49:56 -0700170 }
171 return b, err
172 default:
173 // Sort the keys to provide a deterministic encoding.
174 // Not sure this is required, but the old code does it.
175 keys := make([]int, 0, len(*ext))
176 for k := range *ext {
177 keys = append(keys, int(k))
178 }
179 sort.Ints(keys)
180 var err error
Damien Neilc37adef2019-04-01 13:49:56 -0700181 for _, k := range keys {
Joe Tsai89d49632019-06-04 16:20:00 -0700182 x := (*ext)[int32(k)]
183 xi := mi.extensionFieldInfo(x.GetType())
184 b, err = xi.funcs.marshal(b, x.GetValue(), xi.wiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700185 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700186 return b, err
187 }
188 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700189 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700190 }
191}