blob: 1b3c86825f63da693499943324372bbfec3e5072 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "bytes"
9
10 pref "github.com/golang/protobuf/v2/reflect/protoreflect"
11)
12
13// Equal returns true of two messages are equal.
14//
15// Two messages are equal if they have identical types and registered extension fields,
16// marshal to the same bytes under deterministic serialization,
17// and contain no floating point NaNs.
18func Equal(a, b Message) bool {
19 return equalMessage(a.ProtoReflect(), b.ProtoReflect())
20}
21
22// equalMessage compares two messages.
23func equalMessage(a, b pref.Message) bool {
Joe Tsai0fc49f82019-05-01 12:29:25 -070024 mda, mdb := a.Descriptor(), b.Descriptor()
Damien Neile6f060f2019-04-23 17:11:02 -070025 if mda != mdb && mda.FullName() != mdb.FullName() {
26 return false
27 }
28
29 // TODO: The v1 says that a nil message is not equal to an empty one.
30 // Decide what to do about this when v1 wraps v2.
31
32 knowna, knownb := a.KnownFields(), b.KnownFields()
33
34 fields := mda.Fields()
35 for i, flen := 0, fields.Len(); i < flen; i++ {
36 fd := fields.Get(i)
37 num := fd.Number()
38 hasa, hasb := knowna.Has(num), knownb.Has(num)
39 if !hasa && !hasb {
40 continue
41 }
42 if hasa != hasb || !equalFields(fd, knowna.Get(num), knownb.Get(num)) {
43 return false
44 }
45 }
46 equal := true
47
48 unknowna, unknownb := a.UnknownFields(), b.UnknownFields()
49 ulen := unknowna.Len()
50 if ulen != unknownb.Len() {
51 return false
52 }
53 unknowna.Range(func(num pref.FieldNumber, ra pref.RawFields) bool {
54 rb := unknownb.Get(num)
55 if !bytes.Equal([]byte(ra), []byte(rb)) {
56 equal = false
57 return false
58 }
59 return true
60 })
61 if !equal {
62 return false
63 }
64
65 // If the set of extension types is not identical for both messages, we report
66 // a inequality.
67 //
68 // This requirement is stringent. Registering an extension type for a message
69 // without setting a value for the extension will cause that message to compare
70 // as inequal to the same message without the registration.
71 //
72 // TODO: Revisit this behavior after eager decoding of extensions is implemented.
73 xtypesa, xtypesb := knowna.ExtensionTypes(), knownb.ExtensionTypes()
74 if la, lb := xtypesa.Len(), xtypesb.Len(); la != lb {
75 return false
76 } else if la == 0 {
77 return true
78 }
79 xtypesa.Range(func(xt pref.ExtensionType) bool {
Joe Tsai0fc49f82019-05-01 12:29:25 -070080 num := xt.Descriptor().Number()
Damien Neile6f060f2019-04-23 17:11:02 -070081 if xtypesb.ByNumber(num) != xt {
82 equal = false
83 return false
84 }
85 hasa, hasb := knowna.Has(num), knownb.Has(num)
86 if !hasa && !hasb {
87 return true
88 }
Joe Tsai0fc49f82019-05-01 12:29:25 -070089 if hasa != hasb || !equalFields(xt.Descriptor(), knowna.Get(num), knownb.Get(num)) {
Damien Neile6f060f2019-04-23 17:11:02 -070090 equal = false
91 return false
92 }
93 return true
94 })
95 return equal
96}
97
98// equalFields compares two fields.
99func equalFields(fd pref.FieldDescriptor, a, b pref.Value) bool {
100 switch {
101 case fd.IsMap():
102 return equalMap(fd, a.Map(), b.Map())
103 case fd.Cardinality() == pref.Repeated:
104 return equalList(fd, a.List(), b.List())
105 default:
106 return equalValue(fd, a, b)
107 }
108}
109
110// equalMap compares a map field.
111func equalMap(fd pref.FieldDescriptor, a, b pref.Map) bool {
112 fdv := fd.Message().Fields().ByNumber(2)
113 alen := a.Len()
114 if alen != b.Len() {
115 return false
116 }
117 equal := true
118 a.Range(func(k pref.MapKey, va pref.Value) bool {
119 vb := b.Get(k)
120 if !vb.IsValid() || !equalValue(fdv, va, vb) {
121 equal = false
122 return false
123 }
124 return true
125 })
126 return equal
127}
128
129// equalList compares a non-map repeated field.
130func equalList(fd pref.FieldDescriptor, a, b pref.List) bool {
131 alen := a.Len()
132 if alen != b.Len() {
133 return false
134 }
135 for i := 0; i < alen; i++ {
136 if !equalValue(fd, a.Get(i), b.Get(i)) {
137 return false
138 }
139 }
140 return true
141}
142
143// equalValue compares the scalar value type of a field.
144func equalValue(fd pref.FieldDescriptor, a, b pref.Value) bool {
145 switch {
146 case fd.Message() != nil:
147 return equalMessage(a.Message(), b.Message())
148 case fd.Kind() == pref.BytesKind:
149 return bytes.Equal(a.Bytes(), b.Bytes())
150 default:
151 return a.Interface() == b.Interface()
152 }
153}