blob: c9f7c7aec325f526de11d83221c326d1bfe90015 [file] [log] [blame]
Damien Neil5ec58b92019-04-30 11:36:39 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package prototest exercises protobuf reflection.
6package prototest
7
8import (
9 "bytes"
10 "fmt"
11 "math"
12 "sort"
13 "testing"
14
Damien Neile89e6242019-05-13 23:55:40 -070015 textpb "google.golang.org/protobuf/encoding/textpb"
16 "google.golang.org/protobuf/proto"
17 pref "google.golang.org/protobuf/reflect/protoreflect"
Damien Neil5ec58b92019-04-30 11:36:39 -070018)
19
20// TestMessage runs the provided message through a series of tests
21// exercising the protobuf reflection API.
22func TestMessage(t testing.TB, message proto.Message) {
Joe Tsai0fc49f82019-05-01 12:29:25 -070023 md := message.ProtoReflect().Descriptor()
Damien Neil5ec58b92019-04-30 11:36:39 -070024
Joe Tsai0fc49f82019-05-01 12:29:25 -070025 m := message.ProtoReflect().New()
Damien Neil5ec58b92019-04-30 11:36:39 -070026 for i := 0; i < md.Fields().Len(); i++ {
27 fd := md.Fields().Get(i)
28 switch {
Joe Tsaiac31a352019-05-13 14:32:56 -070029 case fd.IsList():
30 testFieldList(t, m, fd)
Damien Neil5ec58b92019-04-30 11:36:39 -070031 case fd.IsMap():
32 testFieldMap(t, m, fd)
Damien Neil5ec58b92019-04-30 11:36:39 -070033 case fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind:
34 testFieldFloat(t, m, fd)
35 }
36 testField(t, m, fd)
37 }
38 for i := 0; i < md.Oneofs().Len(); i++ {
39 testOneof(t, m, md.Oneofs().Get(i))
40 }
41
42 // Test has/get/clear on a non-existent field.
43 for num := pref.FieldNumber(1); ; num++ {
44 if md.Fields().ByNumber(num) != nil {
45 continue
46 }
47 if md.ExtensionRanges().Has(num) {
48 continue
49 }
50 // Field num does not exist.
51 if m.KnownFields().Has(num) {
52 t.Errorf("non-existent field: Has(%v) = true, want false", num)
53 }
54 if v := m.KnownFields().Get(num); v.IsValid() {
55 t.Errorf("non-existent field: Get(%v) = %v, want invalid", num, formatValue(v))
56 }
57 m.KnownFields().Clear(num) // noop
58 break
59 }
60
61 // Test WhichOneof on a non-existent oneof.
62 const invalidName = "invalid-name"
63 if got, want := m.KnownFields().WhichOneof(invalidName), pref.FieldNumber(0); got != want {
64 t.Errorf("non-existent oneof: WhichOneof(%q) = %v, want %v", invalidName, got, want)
65 }
66
67 // TODO: Extensions, unknown fields.
68
69 // Test round-trip marshal/unmarshal.
Joe Tsai0fc49f82019-05-01 12:29:25 -070070 m1 := message.ProtoReflect().New().Interface()
Damien Neil5ec58b92019-04-30 11:36:39 -070071 populateMessage(m1.ProtoReflect(), 1, nil)
72 b, err := proto.Marshal(m1)
73 if err != nil {
74 t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m1))
75 }
Joe Tsai0fc49f82019-05-01 12:29:25 -070076 m2 := message.ProtoReflect().New().Interface()
Damien Neil5ec58b92019-04-30 11:36:39 -070077 if err := proto.Unmarshal(b, m2); err != nil {
78 t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m1))
79 }
80 if !proto.Equal(m1, m2) {
81 t.Errorf("round-trip marshal/unmarshal did not preserve message.\nOriginal:\n%v\nNew:\n%v", marshalText(m1), marshalText(m2))
82 }
83}
84
85func marshalText(m proto.Message) string {
86 b, _ := textpb.MarshalOptions{Indent: " "}.Marshal(m)
87 return string(b)
88}
89
90// testField exericises set/get/has/clear of a field.
91func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
92 num := fd.Number()
93 name := fd.FullName()
94 known := m.KnownFields()
95
96 // Set to a non-zero value, the zero value, different non-zero values.
97 for _, n := range []seed{1, 0, minVal, maxVal} {
98 v := newValue(m, fd, n, nil)
99 known.Set(num, v)
100 wantHas := true
101 if n == 0 {
102 if fd.Syntax() == pref.Proto3 && fd.Message() == nil {
103 wantHas = false
104 }
105 if fd.Cardinality() == pref.Repeated {
106 wantHas = false
107 }
Joe Tsaiac31a352019-05-13 14:32:56 -0700108 if fd.ContainingOneof() != nil {
Damien Neil5ec58b92019-04-30 11:36:39 -0700109 wantHas = true
110 }
111 }
112 if got, want := known.Has(num), wantHas; got != want {
113 t.Errorf("after setting %q to %v:\nHas(%v) = %v, want %v", name, formatValue(v), num, got, want)
114 }
115 if got, want := known.Get(num), v; !valueEqual(got, want) {
116 t.Errorf("after setting %q:\nGet(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
117 }
118 }
119
120 known.Clear(num)
121 if got, want := known.Has(num), false; got != want {
122 t.Errorf("after clearing %q:\nHas(%v) = %v, want %v", name, num, got, want)
123 }
124 switch {
Joe Tsaiac31a352019-05-13 14:32:56 -0700125 case fd.IsList():
126 if got := known.Get(num); got.List().Len() != 0 {
Damien Neil5ec58b92019-04-30 11:36:39 -0700127 t.Errorf("after clearing %q:\nGet(%v) = %v, want empty list", name, num, formatValue(got))
128 }
Joe Tsaiac31a352019-05-13 14:32:56 -0700129 case fd.IsMap():
130 if got := known.Get(num); got.Map().Len() != 0 {
Damien Neil5ec58b92019-04-30 11:36:39 -0700131 t.Errorf("after clearing %q:\nGet(%v) = %v, want empty list", name, num, formatValue(got))
132 }
133 default:
134 if got, want := known.Get(num), fd.Default(); !valueEqual(got, want) {
135 t.Errorf("after clearing %q:\nGet(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
136 }
137 }
138}
139
140// testFieldMap tests set/get/has/clear of entries in a map field.
141func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
142 num := fd.Number()
143 name := fd.FullName()
144 known := m.KnownFields()
145 known.Clear(num) // start with an empty map
146 mapv := known.Get(num).Map()
147
148 // Add values.
149 want := make(testMap)
150 for i, n := range []seed{1, 0, minVal, maxVal} {
151 if got, want := known.Has(num), i > 0; got != want {
152 t.Errorf("after inserting %d elements to %q:\nHas(%v) = %v, want %v", i, name, num, got, want)
153 }
154
155 k := newMapKey(fd, n)
156 v := newMapValue(fd, mapv, n, nil)
157 mapv.Set(k, v)
158 want.Set(k, v)
159 if got, want := known.Get(num), pref.ValueOf(want); !valueEqual(got, want) {
160 t.Errorf("after inserting %d elements to %q:\nGet(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
161 }
162 }
163
164 // Set values.
165 want.Range(func(k pref.MapKey, v pref.Value) bool {
166 nv := newMapValue(fd, mapv, 10, nil)
167 mapv.Set(k, nv)
168 want.Set(k, nv)
169 if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
170 t.Errorf("after setting element %v of %q:\nGet(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want))
171 }
172 return true
173 })
174
175 // Clear values.
176 want.Range(func(k pref.MapKey, v pref.Value) bool {
177 mapv.Clear(k)
178 want.Clear(k)
179 if got, want := known.Has(num), want.Len() > 0; got != want {
180 t.Errorf("after clearing elements of %q:\nHas(%v) = %v, want %v", name, num, got, want)
181 }
182 if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
183 t.Errorf("after clearing elements of %q:\nGet(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
184 }
185 return true
186 })
187
188 // Non-existent map keys.
189 missingKey := newMapKey(fd, 1)
190 if got, want := mapv.Has(missingKey), false; got != want {
191 t.Errorf("non-existent map key in %q: Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want)
192 }
193 if got, want := mapv.Get(missingKey).IsValid(), false; got != want {
194 t.Errorf("non-existent map key in %q: Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want)
195 }
196 mapv.Clear(missingKey) // noop
197}
198
199type testMap map[interface{}]pref.Value
200
201func (m testMap) Get(k pref.MapKey) pref.Value { return m[k.Interface()] }
202func (m testMap) Set(k pref.MapKey, v pref.Value) { m[k.Interface()] = v }
203func (m testMap) Has(k pref.MapKey) bool { return m.Get(k).IsValid() }
204func (m testMap) Clear(k pref.MapKey) { delete(m, k.Interface()) }
205func (m testMap) Len() int { return len(m) }
206func (m testMap) NewMessage() pref.Message { panic("unimplemented") }
207func (m testMap) Range(f func(pref.MapKey, pref.Value) bool) {
208 for k, v := range m {
209 if !f(pref.ValueOf(k).MapKey(), v) {
210 return
211 }
212 }
213}
214
215// testFieldList exercises set/get/append/truncate of values in a list.
216func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
217 num := fd.Number()
218 name := fd.FullName()
219 known := m.KnownFields()
220 known.Clear(num) // start with an empty list
221 list := known.Get(num).List()
222
223 // Append values.
224 var want pref.List = &testList{}
225 for i, n := range []seed{1, 0, minVal, maxVal} {
226 if got, want := known.Has(num), i > 0; got != want {
227 t.Errorf("after appending %d elements to %q:\nHas(%v) = %v, want %v", i, name, num, got, want)
228 }
229 v := newListElement(fd, list, n, nil)
230 want.Append(v)
231 list.Append(v)
232
233 if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
234 t.Errorf("after appending %d elements to %q:\nGet(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want))
235 }
236 }
237
238 // Set values.
239 for i := 0; i < want.Len(); i++ {
240 v := newListElement(fd, list, seed(i+10), nil)
241 want.Set(i, v)
242 list.Set(i, v)
243 if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
244 t.Errorf("after setting element %d of %q:\nGet(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
245 }
246 }
247
248 // Truncate.
249 for want.Len() > 0 {
250 n := want.Len() - 1
251 want.Truncate(n)
252 list.Truncate(n)
253 if got, want := known.Has(num), want.Len() > 0; got != want {
254 t.Errorf("after truncating %q to %d:\nHas(%v) = %v, want %v", name, n, num, got, want)
255 }
256 if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
257 t.Errorf("after truncating %q to %d:\nGet(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want))
258 }
259 }
260}
261
262type testList struct {
263 a []pref.Value
264}
265
266func (l *testList) Append(v pref.Value) { l.a = append(l.a, v) }
267func (l *testList) Get(n int) pref.Value { return l.a[n] }
268func (l *testList) Len() int { return len(l.a) }
269func (l *testList) Set(n int, v pref.Value) { l.a[n] = v }
270func (l *testList) Truncate(n int) { l.a = l.a[:n] }
271func (l *testList) NewMessage() pref.Message { panic("unimplemented") }
272
273// testFieldFloat exercises some interesting floating-point scalar field values.
274func testFieldFloat(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
275 num := fd.Number()
276 name := fd.FullName()
277 known := m.KnownFields()
278 for _, v := range []float64{math.Inf(-1), math.Inf(1), math.NaN(), math.Copysign(0, -1)} {
279 var val pref.Value
280 if fd.Kind() == pref.FloatKind {
281 val = pref.ValueOf(float32(v))
282 } else {
283 val = pref.ValueOf(v)
284 }
285 known.Set(num, val)
286 // Note that Has is true for -0.
287 if got, want := known.Has(num), true; got != want {
288 t.Errorf("after setting %v to %v: Get(%v) = %v, want %v", name, v, num, got, want)
289 }
290 if got, want := known.Get(num), val; !valueEqual(got, want) {
291 t.Errorf("after setting %v: Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
292 }
293 }
294}
295
296// testOneof tests the behavior of fields in a oneof.
297func testOneof(t testing.TB, m pref.Message, od pref.OneofDescriptor) {
298 known := m.KnownFields()
299 for i := 0; i < od.Fields().Len(); i++ {
300 fda := od.Fields().Get(i)
301 known.Set(fda.Number(), newValue(m, fda, 1, nil))
302 if got, want := known.WhichOneof(od.Name()), fda.Number(); got != want {
303 t.Errorf("after setting oneof field %q:\nWhichOneof(%q) = %v, want %v", fda.FullName(), fda.Name(), got, want)
304 }
305 for j := 0; j < od.Fields().Len(); j++ {
306 fdb := od.Fields().Get(j)
307 if got, want := known.Has(fdb.Number()), i == j; got != want {
308 t.Errorf("after setting oneof field %q:\nGet(%q) = %v, want %v", fda.FullName(), fdb.FullName(), got, want)
309 }
310 }
311 }
312}
313
314func formatValue(v pref.Value) string {
315 switch v := v.Interface().(type) {
316 case pref.List:
317 var buf bytes.Buffer
318 buf.WriteString("list[")
319 for i := 0; i < v.Len(); i++ {
320 if i > 0 {
321 buf.WriteString(" ")
322 }
323 buf.WriteString(formatValue(v.Get(i)))
324 }
325 buf.WriteString("]")
326 return buf.String()
327 case pref.Map:
328 var buf bytes.Buffer
329 buf.WriteString("map[")
330 var keys []pref.MapKey
331 v.Range(func(k pref.MapKey, v pref.Value) bool {
332 keys = append(keys, k)
333 return true
334 })
335 sort.Slice(keys, func(i, j int) bool {
336 return keys[i].String() < keys[j].String()
337 })
338 for i, k := range keys {
339 if i > 0 {
340 buf.WriteString(" ")
341 }
342 buf.WriteString(formatValue(k.Value()))
343 buf.WriteString(":")
344 buf.WriteString(formatValue(v.Get(k)))
345 }
346 buf.WriteString("]")
347 return buf.String()
348 case pref.Message:
349 b, err := textpb.Marshal(v.Interface())
350 if err != nil {
351 return fmt.Sprintf("<%v>", err)
352 }
Joe Tsai0fc49f82019-05-01 12:29:25 -0700353 return fmt.Sprintf("%v{%v}", v.Descriptor().FullName(), string(b))
Damien Neil5ec58b92019-04-30 11:36:39 -0700354 case string:
355 return fmt.Sprintf("%q", v)
356 default:
357 return fmt.Sprint(v)
358 }
359}
360
361func valueEqual(a, b pref.Value) bool {
362 ai, bi := a.Interface(), b.Interface()
363 switch ai.(type) {
364 case pref.Message:
365 return proto.Equal(
366 a.Message().Interface(),
367 b.Message().Interface(),
368 )
369 case pref.List:
370 lista, listb := a.List(), b.List()
371 if lista.Len() != listb.Len() {
372 return false
373 }
374 for i := 0; i < lista.Len(); i++ {
375 if !valueEqual(lista.Get(i), listb.Get(i)) {
376 return false
377 }
378 }
379 return true
380 case pref.Map:
381 mapa, mapb := a.Map(), b.Map()
382 if mapa.Len() != mapb.Len() {
383 return false
384 }
385 equal := true
386 mapa.Range(func(k pref.MapKey, v pref.Value) bool {
387 if !valueEqual(v, mapb.Get(k)) {
388 equal = false
389 return false
390 }
391 return true
392 })
393 return equal
394 case []byte:
395 return bytes.Equal(a.Bytes(), b.Bytes())
396 case float32, float64:
397 // NaNs are equal, but must be the same NaN.
398 return math.Float64bits(a.Float()) == math.Float64bits(a.Float())
399 default:
400 return ai == bi
401 }
402}
403
// A seed selects the content of a generated value; different seeds produce
// different values.
//
// Seed 0 is the zero value. Messages have no zero value as such, so a
// message generated from seed 0 is left unpopulated.
//
// The negative sentinels minVal and maxVal request boundary values of the
// value type: the least and greatest value respectively (see newScalarValue
// for exactly which value each kind uses).
type seed int

// Sentinel seeds requesting boundary values.
const (
	minVal seed = -1 // least value of the type
	maxVal seed = -2 // greatest value of the type
)
416
417// newValue returns a new value assignable to a field.
418//
419// The stack parameter is used to avoid infinite recursion when populating circular
420// data structures.
Joe Tsai0fc49f82019-05-01 12:29:25 -0700421func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
Damien Neil5ec58b92019-04-30 11:36:39 -0700422 num := fd.Number()
423 switch {
Joe Tsaiac31a352019-05-13 14:32:56 -0700424 case fd.IsList():
425 list := m.New().KnownFields().Get(num).List()
426 if n == 0 {
427 return pref.ValueOf(list)
428 }
429 list.Append(newListElement(fd, list, 0, stack))
430 list.Append(newListElement(fd, list, minVal, stack))
431 list.Append(newListElement(fd, list, maxVal, stack))
432 list.Append(newListElement(fd, list, n, stack))
433 return pref.ValueOf(list)
Damien Neil5ec58b92019-04-30 11:36:39 -0700434 case fd.IsMap():
Joe Tsai0fc49f82019-05-01 12:29:25 -0700435 mapv := m.New().KnownFields().Get(num).Map()
Damien Neil5ec58b92019-04-30 11:36:39 -0700436 if n == 0 {
437 return pref.ValueOf(mapv)
438 }
439 mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
440 mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
441 mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
442 mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, 10*n, stack))
443 return pref.ValueOf(mapv)
Damien Neil5ec58b92019-04-30 11:36:39 -0700444 case fd.Message() != nil:
445 return populateMessage(m.KnownFields().NewMessage(num), n, stack)
446 default:
447 return newScalarValue(fd, n)
448 }
449}
450
Joe Tsai0fc49f82019-05-01 12:29:25 -0700451func newListElement(fd pref.FieldDescriptor, list pref.List, n seed, stack []pref.MessageDescriptor) pref.Value {
Damien Neil5ec58b92019-04-30 11:36:39 -0700452 if fd.Message() == nil {
453 return newScalarValue(fd, n)
454 }
455 return populateMessage(list.NewMessage(), n, stack)
456}
457
458func newMapKey(fd pref.FieldDescriptor, n seed) pref.MapKey {
Joe Tsaiac31a352019-05-13 14:32:56 -0700459 kd := fd.MapKey()
Damien Neil5ec58b92019-04-30 11:36:39 -0700460 return newScalarValue(kd, n).MapKey()
461}
462
Joe Tsai0fc49f82019-05-01 12:29:25 -0700463func newMapValue(fd pref.FieldDescriptor, mapv pref.Map, n seed, stack []pref.MessageDescriptor) pref.Value {
Joe Tsaiac31a352019-05-13 14:32:56 -0700464 vd := fd.MapValue()
Damien Neil5ec58b92019-04-30 11:36:39 -0700465 if vd.Message() == nil {
466 return newScalarValue(vd, n)
467 }
468 return populateMessage(mapv.NewMessage(), n, stack)
469}
470
471func newScalarValue(fd pref.FieldDescriptor, n seed) pref.Value {
472 switch fd.Kind() {
473 case pref.BoolKind:
474 return pref.ValueOf(n != 0)
475 case pref.EnumKind:
476 // TODO use actual value
477 return pref.ValueOf(pref.EnumNumber(n))
478 case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
479 switch n {
480 case minVal:
481 return pref.ValueOf(int32(math.MinInt32))
482 case maxVal:
483 return pref.ValueOf(int32(math.MaxInt32))
484 default:
485 return pref.ValueOf(int32(n))
486 }
487 case pref.Uint32Kind, pref.Fixed32Kind:
488 switch n {
489 case minVal:
490 // Only use 0 for the zero value.
491 return pref.ValueOf(uint32(1))
492 case maxVal:
493 return pref.ValueOf(uint32(math.MaxInt32))
494 default:
495 return pref.ValueOf(uint32(n))
496 }
497 case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
498 switch n {
499 case minVal:
500 return pref.ValueOf(int64(math.MinInt64))
501 case maxVal:
502 return pref.ValueOf(int64(math.MaxInt64))
503 default:
504 return pref.ValueOf(int64(n))
505 }
506 case pref.Uint64Kind, pref.Fixed64Kind:
507 switch n {
508 case minVal:
509 // Only use 0 for the zero value.
510 return pref.ValueOf(uint64(1))
511 case maxVal:
512 return pref.ValueOf(uint64(math.MaxInt64))
513 default:
514 return pref.ValueOf(uint64(n))
515 }
516 case pref.FloatKind:
517 switch n {
518 case minVal:
519 return pref.ValueOf(float32(math.SmallestNonzeroFloat32))
520 case maxVal:
521 return pref.ValueOf(float32(math.MaxFloat32))
522 default:
523 return pref.ValueOf(1.5 * float32(n))
524 }
525 case pref.DoubleKind:
526 switch n {
527 case minVal:
528 return pref.ValueOf(float64(math.SmallestNonzeroFloat64))
529 case maxVal:
530 return pref.ValueOf(float64(math.MaxFloat64))
531 default:
532 return pref.ValueOf(1.5 * float64(n))
533 }
534 case pref.StringKind:
535 if n == 0 {
536 return pref.ValueOf("")
537 }
538 return pref.ValueOf(fmt.Sprintf("%d", n))
539 case pref.BytesKind:
540 if n == 0 {
541 return pref.ValueOf([]byte(nil))
542 }
543 return pref.ValueOf([]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)})
544 }
545 panic("unhandled kind")
546}
547
Joe Tsai0fc49f82019-05-01 12:29:25 -0700548func populateMessage(m pref.Message, n seed, stack []pref.MessageDescriptor) pref.Value {
Damien Neil5ec58b92019-04-30 11:36:39 -0700549 if n == 0 {
550 return pref.ValueOf(m)
551 }
Joe Tsai0fc49f82019-05-01 12:29:25 -0700552 md := m.Descriptor()
Damien Neil5ec58b92019-04-30 11:36:39 -0700553 for _, x := range stack {
554 if md == x {
555 return pref.ValueOf(m)
556 }
557 }
558 stack = append(stack, md)
559 known := m.KnownFields()
560 for i := 0; i < md.Fields().Len(); i++ {
561 fd := md.Fields().Get(i)
562 if fd.IsWeak() {
563 continue
564 }
565 known.Set(fd.Number(), newValue(m, fd, 10*n+seed(i), stack))
566 }
567 return pref.ValueOf(m)
568}