blob: 109a413f9f562e719fb3bbc3898d19d414220489 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package proto_test
6
7import (
8 "testing"
9
10 "google.golang.org/protobuf/internal/encoding/pack"
11 "google.golang.org/protobuf/internal/scalar"
12 "google.golang.org/protobuf/proto"
13
14 testpb "google.golang.org/protobuf/internal/testprotos/test"
15)
16
17func TestMerge(t *testing.T) {
18 dst := new(testpb.TestAllTypes)
19 src := (*testpb.TestAllTypes)(nil)
20 proto.Merge(dst, src)
21
22 // Mutating the source should not affect dst.
23
24 tests := []struct {
25 desc string
26 dst proto.Message
27 src proto.Message
28 want proto.Message
29 mutator func(proto.Message) // if provided, is run on src after merging
30
31 skipMarshalUnmarshal bool // TODO: Remove this when proto.Unmarshal is fixed for messages in oneofs
32 }{{
33 desc: "merge from nil message",
34 dst: new(testpb.TestAllTypes),
35 src: (*testpb.TestAllTypes)(nil),
36 want: new(testpb.TestAllTypes),
37 }, {
38 desc: "clone a large message",
39 dst: new(testpb.TestAllTypes),
40 src: &testpb.TestAllTypes{
41 OptionalInt64: scalar.Int64(0),
42 OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
43 OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
44 A: scalar.Int32(100),
45 },
46 RepeatedSfixed32: []int32{1, 2, 3},
47 RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
48 {A: scalar.Int32(200)},
49 {A: scalar.Int32(300)},
50 },
51 MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
52 "fizz": 400,
53 "buzz": 500,
54 },
55 MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
56 "foo": {A: scalar.Int32(600)},
57 "bar": {A: scalar.Int32(700)},
58 },
59 OneofField: &testpb.TestAllTypes_OneofNestedMessage{
60 &testpb.TestAllTypes_NestedMessage{
61 A: scalar.Int32(800),
62 },
63 },
64 },
65 want: &testpb.TestAllTypes{
66 OptionalInt64: scalar.Int64(0),
67 OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
68 OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
69 A: scalar.Int32(100),
70 },
71 RepeatedSfixed32: []int32{1, 2, 3},
72 RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
73 {A: scalar.Int32(200)},
74 {A: scalar.Int32(300)},
75 },
76 MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
77 "fizz": 400,
78 "buzz": 500,
79 },
80 MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
81 "foo": {A: scalar.Int32(600)},
82 "bar": {A: scalar.Int32(700)},
83 },
84 OneofField: &testpb.TestAllTypes_OneofNestedMessage{
85 &testpb.TestAllTypes_NestedMessage{
86 A: scalar.Int32(800),
87 },
88 },
89 },
90 mutator: func(mi proto.Message) {
91 m := mi.(*testpb.TestAllTypes)
92 *m.OptionalInt64++
93 *m.OptionalNestedEnum++
94 *m.OptionalNestedMessage.A++
95 m.RepeatedSfixed32[0]++
96 *m.RepeatedNestedMessage[0].A++
97 delete(m.MapStringNestedEnum, "fizz")
98 *m.MapStringNestedMessage["foo"].A++
99 *m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.A++
100 },
101 }, {
102 desc: "merge bytes",
103 dst: &testpb.TestAllTypes{
104 OptionalBytes: []byte{1, 2, 3},
105 RepeatedBytes: [][]byte{{1, 2}, {3, 4}},
106 MapStringBytes: map[string][]byte{"alpha": {1, 2, 3}},
107 },
108 src: &testpb.TestAllTypes{
109 OptionalBytes: []byte{4, 5, 6},
110 RepeatedBytes: [][]byte{{5, 6}, {7, 8}},
111 MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
112 },
113 want: &testpb.TestAllTypes{
114 OptionalBytes: []byte{4, 5, 6},
115 RepeatedBytes: [][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}},
116 MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
117 },
118 mutator: func(mi proto.Message) {
119 m := mi.(*testpb.TestAllTypes)
120 m.OptionalBytes[0]++
121 m.RepeatedBytes[0][0]++
122 m.MapStringBytes["alpha"][0]++
123 },
124 }, {
125 desc: "merge singular fields",
126 dst: &testpb.TestAllTypes{
127 OptionalInt32: scalar.Int32(1),
128 OptionalInt64: scalar.Int64(1),
129 OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(10).Enum(),
130 OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
131 A: scalar.Int32(100),
132 Corecursive: &testpb.TestAllTypes{
133 OptionalInt64: scalar.Int64(1000),
134 },
135 },
136 },
137 src: &testpb.TestAllTypes{
138 OptionalInt64: scalar.Int64(2),
139 OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
140 OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
141 A: scalar.Int32(200),
142 },
143 },
144 want: &testpb.TestAllTypes{
145 OptionalInt32: scalar.Int32(1),
146 OptionalInt64: scalar.Int64(2),
147 OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
148 OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
149 A: scalar.Int32(200),
150 Corecursive: &testpb.TestAllTypes{
151 OptionalInt64: scalar.Int64(1000),
152 },
153 },
154 },
155 mutator: func(mi proto.Message) {
156 m := mi.(*testpb.TestAllTypes)
157 *m.OptionalInt64++
158 *m.OptionalNestedEnum++
159 *m.OptionalNestedMessage.A++
160 },
161 }, {
162 desc: "merge list fields",
163 dst: &testpb.TestAllTypes{
164 RepeatedSfixed32: []int32{1, 2, 3},
165 RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
166 {A: scalar.Int32(100)},
167 {A: scalar.Int32(200)},
168 },
169 },
170 src: &testpb.TestAllTypes{
171 RepeatedSfixed32: []int32{4, 5, 6},
172 RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
173 {A: scalar.Int32(300)},
174 {A: scalar.Int32(400)},
175 },
176 },
177 want: &testpb.TestAllTypes{
178 RepeatedSfixed32: []int32{1, 2, 3, 4, 5, 6},
179 RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
180 {A: scalar.Int32(100)},
181 {A: scalar.Int32(200)},
182 {A: scalar.Int32(300)},
183 {A: scalar.Int32(400)},
184 },
185 },
186 mutator: func(mi proto.Message) {
187 m := mi.(*testpb.TestAllTypes)
188 m.RepeatedSfixed32[0]++
189 *m.RepeatedNestedMessage[0].A++
190 },
191 }, {
192 desc: "merge map fields",
193 dst: &testpb.TestAllTypes{
194 MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
195 "fizz": 100,
196 "buzz": 200,
197 "guzz": 300,
198 },
199 MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
200 "foo": {A: scalar.Int32(400)},
201 },
202 },
203 src: &testpb.TestAllTypes{
204 MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
205 "fizz": 1000,
206 "buzz": 2000,
207 },
208 MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
209 "foo": {A: scalar.Int32(3000)},
210 "bar": {},
211 },
212 },
213 want: &testpb.TestAllTypes{
214 MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
215 "fizz": 1000,
216 "buzz": 2000,
217 "guzz": 300,
218 },
219 MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
220 "foo": {A: scalar.Int32(3000)},
221 "bar": {},
222 },
223 },
224 mutator: func(mi proto.Message) {
225 m := mi.(*testpb.TestAllTypes)
226 delete(m.MapStringNestedEnum, "fizz")
227 m.MapStringNestedMessage["bar"].A = scalar.Int32(1)
228 },
229 }, {
230 desc: "merge oneof message fields",
231 dst: &testpb.TestAllTypes{
232 OneofField: &testpb.TestAllTypes_OneofNestedMessage{
233 &testpb.TestAllTypes_NestedMessage{
234 A: scalar.Int32(100),
235 },
236 },
237 },
238 src: &testpb.TestAllTypes{
239 OneofField: &testpb.TestAllTypes_OneofNestedMessage{
240 &testpb.TestAllTypes_NestedMessage{
241 Corecursive: &testpb.TestAllTypes{
242 OptionalInt64: scalar.Int64(1000),
243 },
244 },
245 },
246 },
247 want: &testpb.TestAllTypes{
248 OneofField: &testpb.TestAllTypes_OneofNestedMessage{
249 &testpb.TestAllTypes_NestedMessage{
250 A: scalar.Int32(100),
251 Corecursive: &testpb.TestAllTypes{
252 OptionalInt64: scalar.Int64(1000),
253 },
254 },
255 },
256 },
257 mutator: func(mi proto.Message) {
258 m := mi.(*testpb.TestAllTypes)
259 *m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.Corecursive.OptionalInt64++
260 },
261 skipMarshalUnmarshal: true,
262 }, {
263 desc: "merge oneof scalar fields",
264 dst: &testpb.TestAllTypes{
265 OneofField: &testpb.TestAllTypes_OneofUint32{100},
266 },
267 src: &testpb.TestAllTypes{
268 OneofField: &testpb.TestAllTypes_OneofFloat{3.14152},
269 },
270 want: &testpb.TestAllTypes{
271 OneofField: &testpb.TestAllTypes_OneofFloat{3.14152},
272 },
273 mutator: func(mi proto.Message) {
274 m := mi.(*testpb.TestAllTypes)
275 m.OneofField.(*testpb.TestAllTypes_OneofFloat).OneofFloat++
276 },
277 }, {
278 desc: "merge extension fields",
279 dst: func() proto.Message {
280 m := new(testpb.TestAllExtensions)
281 m.ProtoReflect().Set(
282 testpb.E_OptionalInt32Extension.Type,
283 testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)),
284 )
285 m.ProtoReflect().Set(
286 testpb.E_OptionalNestedMessageExtension.Type,
287 testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
288 A: scalar.Int32(50),
289 }),
290 )
291 m.ProtoReflect().Set(
292 testpb.E_RepeatedFixed32Extension.Type,
293 testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3}),
294 )
295 return m
296 }(),
297 src: func() proto.Message {
298 m := new(testpb.TestAllExtensions)
299 m.ProtoReflect().Set(
300 testpb.E_OptionalInt64Extension.Type,
301 testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)),
302 )
303 m.ProtoReflect().Set(
304 testpb.E_OptionalNestedMessageExtension.Type,
305 testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
306 Corecursive: &testpb.TestAllTypes{
307 OptionalInt64: scalar.Int64(1000),
308 },
309 }),
310 )
311 m.ProtoReflect().Set(
312 testpb.E_RepeatedFixed32Extension.Type,
313 testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{4, 5, 6}),
314 )
315 return m
316 }(),
317 want: func() proto.Message {
318 m := new(testpb.TestAllExtensions)
319 m.ProtoReflect().Set(
320 testpb.E_OptionalInt32Extension.Type,
321 testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)),
322 )
323 m.ProtoReflect().Set(
324 testpb.E_OptionalInt64Extension.Type,
325 testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)),
326 )
327 m.ProtoReflect().Set(
328 testpb.E_OptionalNestedMessageExtension.Type,
329 testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
330 A: scalar.Int32(50),
331 Corecursive: &testpb.TestAllTypes{
332 OptionalInt64: scalar.Int64(1000),
333 },
334 }),
335 )
336 m.ProtoReflect().Set(
337 testpb.E_RepeatedFixed32Extension.Type,
338 testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3, 4, 5, 6}),
339 )
340 return m
341 }(),
342 }, {
343 desc: "merge unknown fields",
344 dst: func() proto.Message {
345 m := new(testpb.TestAllTypes)
346 m.ProtoReflect().SetUnknown(pack.Message{
347 pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
348 }.Marshal())
349 return m
350 }(),
351 src: func() proto.Message {
352 m := new(testpb.TestAllTypes)
353 m.ProtoReflect().SetUnknown(pack.Message{
354 pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
355 }.Marshal())
356 return m
357 }(),
358 want: func() proto.Message {
359 m := new(testpb.TestAllTypes)
360 m.ProtoReflect().SetUnknown(pack.Message{
361 pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
362 pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
363 }.Marshal())
364 return m
365 }(),
366 }}
367
368 for _, tt := range tests {
369 t.Run(tt.desc, func(t *testing.T) {
370 // Merge should be semantically equivalent to unmarshaling the
371 // encoded form of src into the current dst.
372 b1, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.dst)
373 if err != nil {
374 t.Fatalf("Marshal(dst) error: %v", err)
375 }
376 b2, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.src)
377 if err != nil {
378 t.Fatalf("Marshal(src) error: %v", err)
379 }
380 dst := tt.dst.ProtoReflect().New().Interface()
381 err = proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(append(b1, b2...), dst)
382 if err != nil {
383 t.Fatalf("Unmarshal() error: %v", err)
384 }
385 if !proto.Equal(dst, tt.want) && !tt.skipMarshalUnmarshal {
386 t.Fatalf("Unmarshal(Marshal(dst)+Marshal(src)) mismatch: got %v, want %v", dst, tt.want)
387 }
388
389 proto.Merge(tt.dst, tt.src)
390 if tt.mutator != nil {
391 tt.mutator(tt.src) // should not be observable by dst
392 }
393 if !proto.Equal(tt.dst, tt.want) {
394 t.Fatalf("Merge() mismatch: got %v, want %v", tt.dst, tt.want)
395 }
396 })
397 }
398}