blob: c8c3b0a5ba131f20253fe0d9a18e4053860ee561 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package proto_test
6
7import (
8 "testing"
9
10 "google.golang.org/protobuf/internal/encoding/pack"
11 "google.golang.org/protobuf/internal/scalar"
12 "google.golang.org/protobuf/proto"
13
14 testpb "google.golang.org/protobuf/internal/testprotos/test"
15)
16
17func TestMerge(t *testing.T) {
18 dst := new(testpb.TestAllTypes)
19 src := (*testpb.TestAllTypes)(nil)
20 proto.Merge(dst, src)
21
22 // Mutating the source should not affect dst.
23
24 tests := []struct {
25 desc string
26 dst proto.Message
27 src proto.Message
28 want proto.Message
29 mutator func(proto.Message) // if provided, is run on src after merging
Joe Tsai2fc306a2019-06-20 03:01:22 -070030 }{{
31 desc: "merge from nil message",
32 dst: new(testpb.TestAllTypes),
33 src: (*testpb.TestAllTypes)(nil),
34 want: new(testpb.TestAllTypes),
35 }, {
36 desc: "clone a large message",
37 dst: new(testpb.TestAllTypes),
38 src: &testpb.TestAllTypes{
39 OptionalInt64: scalar.Int64(0),
40 OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
41 OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
42 A: scalar.Int32(100),
43 },
44 RepeatedSfixed32: []int32{1, 2, 3},
45 RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
46 {A: scalar.Int32(200)},
47 {A: scalar.Int32(300)},
48 },
49 MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
50 "fizz": 400,
51 "buzz": 500,
52 },
53 MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
54 "foo": {A: scalar.Int32(600)},
55 "bar": {A: scalar.Int32(700)},
56 },
57 OneofField: &testpb.TestAllTypes_OneofNestedMessage{
58 &testpb.TestAllTypes_NestedMessage{
59 A: scalar.Int32(800),
60 },
61 },
62 },
63 want: &testpb.TestAllTypes{
64 OptionalInt64: scalar.Int64(0),
65 OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
66 OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
67 A: scalar.Int32(100),
68 },
69 RepeatedSfixed32: []int32{1, 2, 3},
70 RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
71 {A: scalar.Int32(200)},
72 {A: scalar.Int32(300)},
73 },
74 MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
75 "fizz": 400,
76 "buzz": 500,
77 },
78 MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
79 "foo": {A: scalar.Int32(600)},
80 "bar": {A: scalar.Int32(700)},
81 },
82 OneofField: &testpb.TestAllTypes_OneofNestedMessage{
83 &testpb.TestAllTypes_NestedMessage{
84 A: scalar.Int32(800),
85 },
86 },
87 },
88 mutator: func(mi proto.Message) {
89 m := mi.(*testpb.TestAllTypes)
90 *m.OptionalInt64++
91 *m.OptionalNestedEnum++
92 *m.OptionalNestedMessage.A++
93 m.RepeatedSfixed32[0]++
94 *m.RepeatedNestedMessage[0].A++
95 delete(m.MapStringNestedEnum, "fizz")
96 *m.MapStringNestedMessage["foo"].A++
97 *m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.A++
98 },
99 }, {
100 desc: "merge bytes",
101 dst: &testpb.TestAllTypes{
102 OptionalBytes: []byte{1, 2, 3},
103 RepeatedBytes: [][]byte{{1, 2}, {3, 4}},
104 MapStringBytes: map[string][]byte{"alpha": {1, 2, 3}},
105 },
106 src: &testpb.TestAllTypes{
107 OptionalBytes: []byte{4, 5, 6},
108 RepeatedBytes: [][]byte{{5, 6}, {7, 8}},
109 MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
110 },
111 want: &testpb.TestAllTypes{
112 OptionalBytes: []byte{4, 5, 6},
113 RepeatedBytes: [][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}},
114 MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
115 },
116 mutator: func(mi proto.Message) {
117 m := mi.(*testpb.TestAllTypes)
118 m.OptionalBytes[0]++
119 m.RepeatedBytes[0][0]++
120 m.MapStringBytes["alpha"][0]++
121 },
122 }, {
123 desc: "merge singular fields",
124 dst: &testpb.TestAllTypes{
125 OptionalInt32: scalar.Int32(1),
126 OptionalInt64: scalar.Int64(1),
127 OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(10).Enum(),
128 OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
129 A: scalar.Int32(100),
130 Corecursive: &testpb.TestAllTypes{
131 OptionalInt64: scalar.Int64(1000),
132 },
133 },
134 },
135 src: &testpb.TestAllTypes{
136 OptionalInt64: scalar.Int64(2),
137 OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
138 OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
139 A: scalar.Int32(200),
140 },
141 },
142 want: &testpb.TestAllTypes{
143 OptionalInt32: scalar.Int32(1),
144 OptionalInt64: scalar.Int64(2),
145 OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
146 OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
147 A: scalar.Int32(200),
148 Corecursive: &testpb.TestAllTypes{
149 OptionalInt64: scalar.Int64(1000),
150 },
151 },
152 },
153 mutator: func(mi proto.Message) {
154 m := mi.(*testpb.TestAllTypes)
155 *m.OptionalInt64++
156 *m.OptionalNestedEnum++
157 *m.OptionalNestedMessage.A++
158 },
159 }, {
160 desc: "merge list fields",
161 dst: &testpb.TestAllTypes{
162 RepeatedSfixed32: []int32{1, 2, 3},
163 RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
164 {A: scalar.Int32(100)},
165 {A: scalar.Int32(200)},
166 },
167 },
168 src: &testpb.TestAllTypes{
169 RepeatedSfixed32: []int32{4, 5, 6},
170 RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
171 {A: scalar.Int32(300)},
172 {A: scalar.Int32(400)},
173 },
174 },
175 want: &testpb.TestAllTypes{
176 RepeatedSfixed32: []int32{1, 2, 3, 4, 5, 6},
177 RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
178 {A: scalar.Int32(100)},
179 {A: scalar.Int32(200)},
180 {A: scalar.Int32(300)},
181 {A: scalar.Int32(400)},
182 },
183 },
184 mutator: func(mi proto.Message) {
185 m := mi.(*testpb.TestAllTypes)
186 m.RepeatedSfixed32[0]++
187 *m.RepeatedNestedMessage[0].A++
188 },
189 }, {
190 desc: "merge map fields",
191 dst: &testpb.TestAllTypes{
192 MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
193 "fizz": 100,
194 "buzz": 200,
195 "guzz": 300,
196 },
197 MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
198 "foo": {A: scalar.Int32(400)},
199 },
200 },
201 src: &testpb.TestAllTypes{
202 MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
203 "fizz": 1000,
204 "buzz": 2000,
205 },
206 MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
207 "foo": {A: scalar.Int32(3000)},
208 "bar": {},
209 },
210 },
211 want: &testpb.TestAllTypes{
212 MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
213 "fizz": 1000,
214 "buzz": 2000,
215 "guzz": 300,
216 },
217 MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
218 "foo": {A: scalar.Int32(3000)},
219 "bar": {},
220 },
221 },
222 mutator: func(mi proto.Message) {
223 m := mi.(*testpb.TestAllTypes)
224 delete(m.MapStringNestedEnum, "fizz")
225 m.MapStringNestedMessage["bar"].A = scalar.Int32(1)
226 },
227 }, {
228 desc: "merge oneof message fields",
229 dst: &testpb.TestAllTypes{
230 OneofField: &testpb.TestAllTypes_OneofNestedMessage{
231 &testpb.TestAllTypes_NestedMessage{
232 A: scalar.Int32(100),
233 },
234 },
235 },
236 src: &testpb.TestAllTypes{
237 OneofField: &testpb.TestAllTypes_OneofNestedMessage{
238 &testpb.TestAllTypes_NestedMessage{
239 Corecursive: &testpb.TestAllTypes{
240 OptionalInt64: scalar.Int64(1000),
241 },
242 },
243 },
244 },
245 want: &testpb.TestAllTypes{
246 OneofField: &testpb.TestAllTypes_OneofNestedMessage{
247 &testpb.TestAllTypes_NestedMessage{
248 A: scalar.Int32(100),
249 Corecursive: &testpb.TestAllTypes{
250 OptionalInt64: scalar.Int64(1000),
251 },
252 },
253 },
254 },
255 mutator: func(mi proto.Message) {
256 m := mi.(*testpb.TestAllTypes)
257 *m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.Corecursive.OptionalInt64++
258 },
Joe Tsai2fc306a2019-06-20 03:01:22 -0700259 }, {
260 desc: "merge oneof scalar fields",
261 dst: &testpb.TestAllTypes{
262 OneofField: &testpb.TestAllTypes_OneofUint32{100},
263 },
264 src: &testpb.TestAllTypes{
265 OneofField: &testpb.TestAllTypes_OneofFloat{3.14152},
266 },
267 want: &testpb.TestAllTypes{
268 OneofField: &testpb.TestAllTypes_OneofFloat{3.14152},
269 },
270 mutator: func(mi proto.Message) {
271 m := mi.(*testpb.TestAllTypes)
272 m.OneofField.(*testpb.TestAllTypes_OneofFloat).OneofFloat++
273 },
274 }, {
275 desc: "merge extension fields",
276 dst: func() proto.Message {
277 m := new(testpb.TestAllExtensions)
278 m.ProtoReflect().Set(
279 testpb.E_OptionalInt32Extension.Type,
280 testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)),
281 )
282 m.ProtoReflect().Set(
283 testpb.E_OptionalNestedMessageExtension.Type,
284 testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
285 A: scalar.Int32(50),
286 }),
287 )
288 m.ProtoReflect().Set(
289 testpb.E_RepeatedFixed32Extension.Type,
290 testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3}),
291 )
292 return m
293 }(),
294 src: func() proto.Message {
295 m := new(testpb.TestAllExtensions)
296 m.ProtoReflect().Set(
297 testpb.E_OptionalInt64Extension.Type,
298 testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)),
299 )
300 m.ProtoReflect().Set(
301 testpb.E_OptionalNestedMessageExtension.Type,
302 testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
303 Corecursive: &testpb.TestAllTypes{
304 OptionalInt64: scalar.Int64(1000),
305 },
306 }),
307 )
308 m.ProtoReflect().Set(
309 testpb.E_RepeatedFixed32Extension.Type,
310 testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{4, 5, 6}),
311 )
312 return m
313 }(),
314 want: func() proto.Message {
315 m := new(testpb.TestAllExtensions)
316 m.ProtoReflect().Set(
317 testpb.E_OptionalInt32Extension.Type,
318 testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)),
319 )
320 m.ProtoReflect().Set(
321 testpb.E_OptionalInt64Extension.Type,
322 testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)),
323 )
324 m.ProtoReflect().Set(
325 testpb.E_OptionalNestedMessageExtension.Type,
326 testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
327 A: scalar.Int32(50),
328 Corecursive: &testpb.TestAllTypes{
329 OptionalInt64: scalar.Int64(1000),
330 },
331 }),
332 )
333 m.ProtoReflect().Set(
334 testpb.E_RepeatedFixed32Extension.Type,
335 testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3, 4, 5, 6}),
336 )
337 return m
338 }(),
339 }, {
340 desc: "merge unknown fields",
341 dst: func() proto.Message {
342 m := new(testpb.TestAllTypes)
343 m.ProtoReflect().SetUnknown(pack.Message{
344 pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
345 }.Marshal())
346 return m
347 }(),
348 src: func() proto.Message {
349 m := new(testpb.TestAllTypes)
350 m.ProtoReflect().SetUnknown(pack.Message{
351 pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
352 }.Marshal())
353 return m
354 }(),
355 want: func() proto.Message {
356 m := new(testpb.TestAllTypes)
357 m.ProtoReflect().SetUnknown(pack.Message{
358 pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
359 pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
360 }.Marshal())
361 return m
362 }(),
363 }}
364
365 for _, tt := range tests {
366 t.Run(tt.desc, func(t *testing.T) {
367 // Merge should be semantically equivalent to unmarshaling the
368 // encoded form of src into the current dst.
369 b1, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.dst)
370 if err != nil {
371 t.Fatalf("Marshal(dst) error: %v", err)
372 }
373 b2, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.src)
374 if err != nil {
375 t.Fatalf("Marshal(src) error: %v", err)
376 }
377 dst := tt.dst.ProtoReflect().New().Interface()
378 err = proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(append(b1, b2...), dst)
379 if err != nil {
380 t.Fatalf("Unmarshal() error: %v", err)
381 }
Joe Tsai6c286742019-07-11 23:15:05 -0700382 if !proto.Equal(dst, tt.want) {
Joe Tsai2fc306a2019-06-20 03:01:22 -0700383 t.Fatalf("Unmarshal(Marshal(dst)+Marshal(src)) mismatch: got %v, want %v", dst, tt.want)
384 }
385
386 proto.Merge(tt.dst, tt.src)
387 if tt.mutator != nil {
388 tt.mutator(tt.src) // should not be observable by dst
389 }
390 if !proto.Equal(tt.dst, tt.want) {
391 t.Fatalf("Merge() mismatch: got %v, want %v", tt.dst, tt.want)
392 }
393 })
394 }
395}