blob: ede6ca34159a02770f5d4754f8b784ccf7750f48 [file] [log] [blame]
Damien Neil5ec58b92019-04-30 11:36:39 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package prototest exercises protobuf reflection.
6package prototest
7
8import (
9 "bytes"
10 "fmt"
11 "math"
12 "sort"
13 "testing"
14
15 textpb "github.com/golang/protobuf/v2/encoding/textpb"
16 "github.com/golang/protobuf/v2/proto"
17 pref "github.com/golang/protobuf/v2/reflect/protoreflect"
18)
19
20// TestMessage runs the provided message through a series of tests
21// exercising the protobuf reflection API.
22func TestMessage(t testing.TB, message proto.Message) {
23 md := message.ProtoReflect().Type()
24
25 m := md.New()
26 for i := 0; i < md.Fields().Len(); i++ {
27 fd := md.Fields().Get(i)
28 switch {
29 case fd.IsWeak():
30 // TODO: Weak field support.
31 continue
32 case fd.IsMap():
33 testFieldMap(t, m, fd)
34 case fd.Cardinality() == pref.Repeated:
35 testFieldList(t, m, fd)
36 case fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind:
37 testFieldFloat(t, m, fd)
38 }
39 testField(t, m, fd)
40 }
41 for i := 0; i < md.Oneofs().Len(); i++ {
42 testOneof(t, m, md.Oneofs().Get(i))
43 }
44
45 // Test has/get/clear on a non-existent field.
46 for num := pref.FieldNumber(1); ; num++ {
47 if md.Fields().ByNumber(num) != nil {
48 continue
49 }
50 if md.ExtensionRanges().Has(num) {
51 continue
52 }
53 // Field num does not exist.
54 if m.KnownFields().Has(num) {
55 t.Errorf("non-existent field: Has(%v) = true, want false", num)
56 }
57 if v := m.KnownFields().Get(num); v.IsValid() {
58 t.Errorf("non-existent field: Get(%v) = %v, want invalid", num, formatValue(v))
59 }
60 m.KnownFields().Clear(num) // noop
61 break
62 }
63
64 // Test WhichOneof on a non-existent oneof.
65 const invalidName = "invalid-name"
66 if got, want := m.KnownFields().WhichOneof(invalidName), pref.FieldNumber(0); got != want {
67 t.Errorf("non-existent oneof: WhichOneof(%q) = %v, want %v", invalidName, got, want)
68 }
69
70 // TODO: Extensions, unknown fields.
71
72 // Test round-trip marshal/unmarshal.
73 m1 := md.New().Interface()
74 populateMessage(m1.ProtoReflect(), 1, nil)
75 b, err := proto.Marshal(m1)
76 if err != nil {
77 t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m1))
78 }
79 m2 := md.New().Interface()
80 if err := proto.Unmarshal(b, m2); err != nil {
81 t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m1))
82 }
83 if !proto.Equal(m1, m2) {
84 t.Errorf("round-trip marshal/unmarshal did not preserve message.\nOriginal:\n%v\nNew:\n%v", marshalText(m1), marshalText(m2))
85 }
86}
87
88func marshalText(m proto.Message) string {
89 b, _ := textpb.MarshalOptions{Indent: " "}.Marshal(m)
90 return string(b)
91}
92
93// testField exericises set/get/has/clear of a field.
94func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
95 num := fd.Number()
96 name := fd.FullName()
97 known := m.KnownFields()
98
99 // Set to a non-zero value, the zero value, different non-zero values.
100 for _, n := range []seed{1, 0, minVal, maxVal} {
101 v := newValue(m, fd, n, nil)
102 known.Set(num, v)
103 wantHas := true
104 if n == 0 {
105 if fd.Syntax() == pref.Proto3 && fd.Message() == nil {
106 wantHas = false
107 }
108 if fd.Cardinality() == pref.Repeated {
109 wantHas = false
110 }
111 if fd.Oneof() != nil {
112 wantHas = true
113 }
114 }
115 if got, want := known.Has(num), wantHas; got != want {
116 t.Errorf("after setting %q to %v:\nHas(%v) = %v, want %v", name, formatValue(v), num, got, want)
117 }
118 if got, want := known.Get(num), v; !valueEqual(got, want) {
119 t.Errorf("after setting %q:\nGet(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
120 }
121 }
122
123 known.Clear(num)
124 if got, want := known.Has(num), false; got != want {
125 t.Errorf("after clearing %q:\nHas(%v) = %v, want %v", name, num, got, want)
126 }
127 switch {
128 case fd.IsMap():
129 if got := known.Get(num); got.Map().Len() != 0 {
130 t.Errorf("after clearing %q:\nGet(%v) = %v, want empty list", name, num, formatValue(got))
131 }
132 case fd.Cardinality() == pref.Repeated:
133 if got := known.Get(num); got.List().Len() != 0 {
134 t.Errorf("after clearing %q:\nGet(%v) = %v, want empty list", name, num, formatValue(got))
135 }
136 default:
137 if got, want := known.Get(num), fd.Default(); !valueEqual(got, want) {
138 t.Errorf("after clearing %q:\nGet(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
139 }
140 }
141}
142
143// testFieldMap tests set/get/has/clear of entries in a map field.
144func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
145 num := fd.Number()
146 name := fd.FullName()
147 known := m.KnownFields()
148 known.Clear(num) // start with an empty map
149 mapv := known.Get(num).Map()
150
151 // Add values.
152 want := make(testMap)
153 for i, n := range []seed{1, 0, minVal, maxVal} {
154 if got, want := known.Has(num), i > 0; got != want {
155 t.Errorf("after inserting %d elements to %q:\nHas(%v) = %v, want %v", i, name, num, got, want)
156 }
157
158 k := newMapKey(fd, n)
159 v := newMapValue(fd, mapv, n, nil)
160 mapv.Set(k, v)
161 want.Set(k, v)
162 if got, want := known.Get(num), pref.ValueOf(want); !valueEqual(got, want) {
163 t.Errorf("after inserting %d elements to %q:\nGet(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
164 }
165 }
166
167 // Set values.
168 want.Range(func(k pref.MapKey, v pref.Value) bool {
169 nv := newMapValue(fd, mapv, 10, nil)
170 mapv.Set(k, nv)
171 want.Set(k, nv)
172 if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
173 t.Errorf("after setting element %v of %q:\nGet(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want))
174 }
175 return true
176 })
177
178 // Clear values.
179 want.Range(func(k pref.MapKey, v pref.Value) bool {
180 mapv.Clear(k)
181 want.Clear(k)
182 if got, want := known.Has(num), want.Len() > 0; got != want {
183 t.Errorf("after clearing elements of %q:\nHas(%v) = %v, want %v", name, num, got, want)
184 }
185 if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
186 t.Errorf("after clearing elements of %q:\nGet(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
187 }
188 return true
189 })
190
191 // Non-existent map keys.
192 missingKey := newMapKey(fd, 1)
193 if got, want := mapv.Has(missingKey), false; got != want {
194 t.Errorf("non-existent map key in %q: Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want)
195 }
196 if got, want := mapv.Get(missingKey).IsValid(), false; got != want {
197 t.Errorf("non-existent map key in %q: Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want)
198 }
199 mapv.Clear(missingKey) // noop
200}
201
202type testMap map[interface{}]pref.Value
203
204func (m testMap) Get(k pref.MapKey) pref.Value { return m[k.Interface()] }
205func (m testMap) Set(k pref.MapKey, v pref.Value) { m[k.Interface()] = v }
206func (m testMap) Has(k pref.MapKey) bool { return m.Get(k).IsValid() }
207func (m testMap) Clear(k pref.MapKey) { delete(m, k.Interface()) }
208func (m testMap) Len() int { return len(m) }
209func (m testMap) NewMessage() pref.Message { panic("unimplemented") }
210func (m testMap) Range(f func(pref.MapKey, pref.Value) bool) {
211 for k, v := range m {
212 if !f(pref.ValueOf(k).MapKey(), v) {
213 return
214 }
215 }
216}
217
218// testFieldList exercises set/get/append/truncate of values in a list.
219func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
220 num := fd.Number()
221 name := fd.FullName()
222 known := m.KnownFields()
223 known.Clear(num) // start with an empty list
224 list := known.Get(num).List()
225
226 // Append values.
227 var want pref.List = &testList{}
228 for i, n := range []seed{1, 0, minVal, maxVal} {
229 if got, want := known.Has(num), i > 0; got != want {
230 t.Errorf("after appending %d elements to %q:\nHas(%v) = %v, want %v", i, name, num, got, want)
231 }
232 v := newListElement(fd, list, n, nil)
233 want.Append(v)
234 list.Append(v)
235
236 if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
237 t.Errorf("after appending %d elements to %q:\nGet(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want))
238 }
239 }
240
241 // Set values.
242 for i := 0; i < want.Len(); i++ {
243 v := newListElement(fd, list, seed(i+10), nil)
244 want.Set(i, v)
245 list.Set(i, v)
246 if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
247 t.Errorf("after setting element %d of %q:\nGet(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
248 }
249 }
250
251 // Truncate.
252 for want.Len() > 0 {
253 n := want.Len() - 1
254 want.Truncate(n)
255 list.Truncate(n)
256 if got, want := known.Has(num), want.Len() > 0; got != want {
257 t.Errorf("after truncating %q to %d:\nHas(%v) = %v, want %v", name, n, num, got, want)
258 }
259 if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
260 t.Errorf("after truncating %q to %d:\nGet(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want))
261 }
262 }
263}
264
265type testList struct {
266 a []pref.Value
267}
268
269func (l *testList) Append(v pref.Value) { l.a = append(l.a, v) }
270func (l *testList) Get(n int) pref.Value { return l.a[n] }
271func (l *testList) Len() int { return len(l.a) }
272func (l *testList) Set(n int, v pref.Value) { l.a[n] = v }
273func (l *testList) Truncate(n int) { l.a = l.a[:n] }
274func (l *testList) NewMessage() pref.Message { panic("unimplemented") }
275
276// testFieldFloat exercises some interesting floating-point scalar field values.
277func testFieldFloat(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
278 num := fd.Number()
279 name := fd.FullName()
280 known := m.KnownFields()
281 for _, v := range []float64{math.Inf(-1), math.Inf(1), math.NaN(), math.Copysign(0, -1)} {
282 var val pref.Value
283 if fd.Kind() == pref.FloatKind {
284 val = pref.ValueOf(float32(v))
285 } else {
286 val = pref.ValueOf(v)
287 }
288 known.Set(num, val)
289 // Note that Has is true for -0.
290 if got, want := known.Has(num), true; got != want {
291 t.Errorf("after setting %v to %v: Get(%v) = %v, want %v", name, v, num, got, want)
292 }
293 if got, want := known.Get(num), val; !valueEqual(got, want) {
294 t.Errorf("after setting %v: Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
295 }
296 }
297}
298
299// testOneof tests the behavior of fields in a oneof.
300func testOneof(t testing.TB, m pref.Message, od pref.OneofDescriptor) {
301 known := m.KnownFields()
302 for i := 0; i < od.Fields().Len(); i++ {
303 fda := od.Fields().Get(i)
304 known.Set(fda.Number(), newValue(m, fda, 1, nil))
305 if got, want := known.WhichOneof(od.Name()), fda.Number(); got != want {
306 t.Errorf("after setting oneof field %q:\nWhichOneof(%q) = %v, want %v", fda.FullName(), fda.Name(), got, want)
307 }
308 for j := 0; j < od.Fields().Len(); j++ {
309 fdb := od.Fields().Get(j)
310 if got, want := known.Has(fdb.Number()), i == j; got != want {
311 t.Errorf("after setting oneof field %q:\nGet(%q) = %v, want %v", fda.FullName(), fdb.FullName(), got, want)
312 }
313 }
314 }
315}
316
317func formatValue(v pref.Value) string {
318 switch v := v.Interface().(type) {
319 case pref.List:
320 var buf bytes.Buffer
321 buf.WriteString("list[")
322 for i := 0; i < v.Len(); i++ {
323 if i > 0 {
324 buf.WriteString(" ")
325 }
326 buf.WriteString(formatValue(v.Get(i)))
327 }
328 buf.WriteString("]")
329 return buf.String()
330 case pref.Map:
331 var buf bytes.Buffer
332 buf.WriteString("map[")
333 var keys []pref.MapKey
334 v.Range(func(k pref.MapKey, v pref.Value) bool {
335 keys = append(keys, k)
336 return true
337 })
338 sort.Slice(keys, func(i, j int) bool {
339 return keys[i].String() < keys[j].String()
340 })
341 for i, k := range keys {
342 if i > 0 {
343 buf.WriteString(" ")
344 }
345 buf.WriteString(formatValue(k.Value()))
346 buf.WriteString(":")
347 buf.WriteString(formatValue(v.Get(k)))
348 }
349 buf.WriteString("]")
350 return buf.String()
351 case pref.Message:
352 b, err := textpb.Marshal(v.Interface())
353 if err != nil {
354 return fmt.Sprintf("<%v>", err)
355 }
356 return fmt.Sprintf("%v{%v}", v.Type().FullName(), string(b))
357 case string:
358 return fmt.Sprintf("%q", v)
359 default:
360 return fmt.Sprint(v)
361 }
362}
363
364func valueEqual(a, b pref.Value) bool {
365 ai, bi := a.Interface(), b.Interface()
366 switch ai.(type) {
367 case pref.Message:
368 return proto.Equal(
369 a.Message().Interface(),
370 b.Message().Interface(),
371 )
372 case pref.List:
373 lista, listb := a.List(), b.List()
374 if lista.Len() != listb.Len() {
375 return false
376 }
377 for i := 0; i < lista.Len(); i++ {
378 if !valueEqual(lista.Get(i), listb.Get(i)) {
379 return false
380 }
381 }
382 return true
383 case pref.Map:
384 mapa, mapb := a.Map(), b.Map()
385 if mapa.Len() != mapb.Len() {
386 return false
387 }
388 equal := true
389 mapa.Range(func(k pref.MapKey, v pref.Value) bool {
390 if !valueEqual(v, mapb.Get(k)) {
391 equal = false
392 return false
393 }
394 return true
395 })
396 return equal
397 case []byte:
398 return bytes.Equal(a.Bytes(), b.Bytes())
399 case float32, float64:
400 // NaNs are equal, but must be the same NaN.
401 return math.Float64bits(a.Float()) == math.Float64bits(a.Float())
402 default:
403 return ai == bi
404 }
405}
406
// A seed is used to vary the content of a value.
//
// A seed of 0 produces the zero value. Messages have no zero value of their
// own, so a 0-seeded message is left unpopulated.
//
// The sentinel seeds minVal and maxVal request extreme scalar values; see
// newScalarValue for exactly what each kind produces. Signed integer kinds get
// their least/greatest value. Unsigned kinds get 1 for minVal (seed 0 alone
// yields zero) and math.MaxInt32/math.MaxInt64 for maxVal. Floating-point
// kinds get the smallest positive nonzero value and the largest finite value.
type seed int

// Sentinel seed values; both are negative and so never collide with the
// non-negative seeds used elsewhere.
const (
	minVal, maxVal seed = -1, -2
)
419
420// newValue returns a new value assignable to a field.
421//
422// The stack parameter is used to avoid infinite recursion when populating circular
423// data structures.
424func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageType) pref.Value {
425 num := fd.Number()
426 switch {
427 case fd.IsMap():
428 mapv := m.Type().New().KnownFields().Get(num).Map()
429 if n == 0 {
430 return pref.ValueOf(mapv)
431 }
432 mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
433 mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
434 mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
435 mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, 10*n, stack))
436 return pref.ValueOf(mapv)
437 case fd.Cardinality() == pref.Repeated:
438 list := m.Type().New().KnownFields().Get(num).List()
439 if n == 0 {
440 return pref.ValueOf(list)
441 }
442 list.Append(newListElement(fd, list, 0, stack))
443 list.Append(newListElement(fd, list, minVal, stack))
444 list.Append(newListElement(fd, list, maxVal, stack))
445 list.Append(newListElement(fd, list, n, stack))
446 return pref.ValueOf(list)
447 case fd.Message() != nil:
448 return populateMessage(m.KnownFields().NewMessage(num), n, stack)
449 default:
450 return newScalarValue(fd, n)
451 }
452}
453
454func newListElement(fd pref.FieldDescriptor, list pref.List, n seed, stack []pref.MessageType) pref.Value {
455 if fd.Message() == nil {
456 return newScalarValue(fd, n)
457 }
458 return populateMessage(list.NewMessage(), n, stack)
459}
460
461func newMapKey(fd pref.FieldDescriptor, n seed) pref.MapKey {
462 kd := fd.Message().Fields().ByNumber(1)
463 return newScalarValue(kd, n).MapKey()
464}
465
466func newMapValue(fd pref.FieldDescriptor, mapv pref.Map, n seed, stack []pref.MessageType) pref.Value {
467 vd := fd.Message().Fields().ByNumber(2)
468 if vd.Message() == nil {
469 return newScalarValue(vd, n)
470 }
471 return populateMessage(mapv.NewMessage(), n, stack)
472}
473
474func newScalarValue(fd pref.FieldDescriptor, n seed) pref.Value {
475 switch fd.Kind() {
476 case pref.BoolKind:
477 return pref.ValueOf(n != 0)
478 case pref.EnumKind:
479 // TODO use actual value
480 return pref.ValueOf(pref.EnumNumber(n))
481 case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
482 switch n {
483 case minVal:
484 return pref.ValueOf(int32(math.MinInt32))
485 case maxVal:
486 return pref.ValueOf(int32(math.MaxInt32))
487 default:
488 return pref.ValueOf(int32(n))
489 }
490 case pref.Uint32Kind, pref.Fixed32Kind:
491 switch n {
492 case minVal:
493 // Only use 0 for the zero value.
494 return pref.ValueOf(uint32(1))
495 case maxVal:
496 return pref.ValueOf(uint32(math.MaxInt32))
497 default:
498 return pref.ValueOf(uint32(n))
499 }
500 case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
501 switch n {
502 case minVal:
503 return pref.ValueOf(int64(math.MinInt64))
504 case maxVal:
505 return pref.ValueOf(int64(math.MaxInt64))
506 default:
507 return pref.ValueOf(int64(n))
508 }
509 case pref.Uint64Kind, pref.Fixed64Kind:
510 switch n {
511 case minVal:
512 // Only use 0 for the zero value.
513 return pref.ValueOf(uint64(1))
514 case maxVal:
515 return pref.ValueOf(uint64(math.MaxInt64))
516 default:
517 return pref.ValueOf(uint64(n))
518 }
519 case pref.FloatKind:
520 switch n {
521 case minVal:
522 return pref.ValueOf(float32(math.SmallestNonzeroFloat32))
523 case maxVal:
524 return pref.ValueOf(float32(math.MaxFloat32))
525 default:
526 return pref.ValueOf(1.5 * float32(n))
527 }
528 case pref.DoubleKind:
529 switch n {
530 case minVal:
531 return pref.ValueOf(float64(math.SmallestNonzeroFloat64))
532 case maxVal:
533 return pref.ValueOf(float64(math.MaxFloat64))
534 default:
535 return pref.ValueOf(1.5 * float64(n))
536 }
537 case pref.StringKind:
538 if n == 0 {
539 return pref.ValueOf("")
540 }
541 return pref.ValueOf(fmt.Sprintf("%d", n))
542 case pref.BytesKind:
543 if n == 0 {
544 return pref.ValueOf([]byte(nil))
545 }
546 return pref.ValueOf([]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)})
547 }
548 panic("unhandled kind")
549}
550
551func populateMessage(m pref.Message, n seed, stack []pref.MessageType) pref.Value {
552 if n == 0 {
553 return pref.ValueOf(m)
554 }
555 md := m.Type()
556 for _, x := range stack {
557 if md == x {
558 return pref.ValueOf(m)
559 }
560 }
561 stack = append(stack, md)
562 known := m.KnownFields()
563 for i := 0; i < md.Fields().Len(); i++ {
564 fd := md.Fields().Get(i)
565 if fd.IsWeak() {
566 continue
567 }
568 known.Set(fd.Number(), newValue(m, fd, 10*n+seed(i), stack))
569 }
570 return pref.ValueOf(m)
571}