Damien Neil | e6f060f | 2019-04-23 17:11:02 -0700 | [diff] [blame] | 1 | // Copyright 2019 The Go Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | package proto |
| 6 | |
| 7 | import ( |
| 8 | "bytes" |
| 9 | |
Damien Neil | e89e624 | 2019-05-13 23:55:40 -0700 | [diff] [blame^] | 10 | pref "google.golang.org/protobuf/reflect/protoreflect" |
Damien Neil | e6f060f | 2019-04-23 17:11:02 -0700 | [diff] [blame] | 11 | ) |
| 12 | |
| 13 | // Equal returns true of two messages are equal. |
| 14 | // |
| 15 | // Two messages are equal if they have identical types and registered extension fields, |
| 16 | // marshal to the same bytes under deterministic serialization, |
| 17 | // and contain no floating point NaNs. |
| 18 | func Equal(a, b Message) bool { |
| 19 | return equalMessage(a.ProtoReflect(), b.ProtoReflect()) |
| 20 | } |
| 21 | |
| 22 | // equalMessage compares two messages. |
| 23 | func equalMessage(a, b pref.Message) bool { |
Joe Tsai | 0fc49f8 | 2019-05-01 12:29:25 -0700 | [diff] [blame] | 24 | mda, mdb := a.Descriptor(), b.Descriptor() |
Damien Neil | e6f060f | 2019-04-23 17:11:02 -0700 | [diff] [blame] | 25 | if mda != mdb && mda.FullName() != mdb.FullName() { |
| 26 | return false |
| 27 | } |
| 28 | |
| 29 | // TODO: The v1 says that a nil message is not equal to an empty one. |
| 30 | // Decide what to do about this when v1 wraps v2. |
| 31 | |
| 32 | knowna, knownb := a.KnownFields(), b.KnownFields() |
| 33 | |
| 34 | fields := mda.Fields() |
| 35 | for i, flen := 0, fields.Len(); i < flen; i++ { |
| 36 | fd := fields.Get(i) |
| 37 | num := fd.Number() |
| 38 | hasa, hasb := knowna.Has(num), knownb.Has(num) |
| 39 | if !hasa && !hasb { |
| 40 | continue |
| 41 | } |
| 42 | if hasa != hasb || !equalFields(fd, knowna.Get(num), knownb.Get(num)) { |
| 43 | return false |
| 44 | } |
| 45 | } |
| 46 | equal := true |
| 47 | |
| 48 | unknowna, unknownb := a.UnknownFields(), b.UnknownFields() |
| 49 | ulen := unknowna.Len() |
| 50 | if ulen != unknownb.Len() { |
| 51 | return false |
| 52 | } |
| 53 | unknowna.Range(func(num pref.FieldNumber, ra pref.RawFields) bool { |
| 54 | rb := unknownb.Get(num) |
| 55 | if !bytes.Equal([]byte(ra), []byte(rb)) { |
| 56 | equal = false |
| 57 | return false |
| 58 | } |
| 59 | return true |
| 60 | }) |
| 61 | if !equal { |
| 62 | return false |
| 63 | } |
| 64 | |
| 65 | // If the set of extension types is not identical for both messages, we report |
| 66 | // a inequality. |
| 67 | // |
| 68 | // This requirement is stringent. Registering an extension type for a message |
| 69 | // without setting a value for the extension will cause that message to compare |
| 70 | // as inequal to the same message without the registration. |
| 71 | // |
| 72 | // TODO: Revisit this behavior after eager decoding of extensions is implemented. |
| 73 | xtypesa, xtypesb := knowna.ExtensionTypes(), knownb.ExtensionTypes() |
| 74 | if la, lb := xtypesa.Len(), xtypesb.Len(); la != lb { |
| 75 | return false |
| 76 | } else if la == 0 { |
| 77 | return true |
| 78 | } |
| 79 | xtypesa.Range(func(xt pref.ExtensionType) bool { |
Joe Tsai | 0fc49f8 | 2019-05-01 12:29:25 -0700 | [diff] [blame] | 80 | num := xt.Descriptor().Number() |
Damien Neil | e6f060f | 2019-04-23 17:11:02 -0700 | [diff] [blame] | 81 | if xtypesb.ByNumber(num) != xt { |
| 82 | equal = false |
| 83 | return false |
| 84 | } |
| 85 | hasa, hasb := knowna.Has(num), knownb.Has(num) |
| 86 | if !hasa && !hasb { |
| 87 | return true |
| 88 | } |
Joe Tsai | 0fc49f8 | 2019-05-01 12:29:25 -0700 | [diff] [blame] | 89 | if hasa != hasb || !equalFields(xt.Descriptor(), knowna.Get(num), knownb.Get(num)) { |
Damien Neil | e6f060f | 2019-04-23 17:11:02 -0700 | [diff] [blame] | 90 | equal = false |
| 91 | return false |
| 92 | } |
| 93 | return true |
| 94 | }) |
| 95 | return equal |
| 96 | } |
| 97 | |
| 98 | // equalFields compares two fields. |
| 99 | func equalFields(fd pref.FieldDescriptor, a, b pref.Value) bool { |
| 100 | switch { |
Joe Tsai | ac31a35 | 2019-05-13 14:32:56 -0700 | [diff] [blame] | 101 | case fd.IsList(): |
| 102 | return equalList(fd, a.List(), b.List()) |
Damien Neil | e6f060f | 2019-04-23 17:11:02 -0700 | [diff] [blame] | 103 | case fd.IsMap(): |
| 104 | return equalMap(fd, a.Map(), b.Map()) |
Damien Neil | e6f060f | 2019-04-23 17:11:02 -0700 | [diff] [blame] | 105 | default: |
| 106 | return equalValue(fd, a, b) |
| 107 | } |
| 108 | } |
| 109 | |
| 110 | // equalMap compares a map field. |
| 111 | func equalMap(fd pref.FieldDescriptor, a, b pref.Map) bool { |
Damien Neil | e6f060f | 2019-04-23 17:11:02 -0700 | [diff] [blame] | 112 | alen := a.Len() |
| 113 | if alen != b.Len() { |
| 114 | return false |
| 115 | } |
| 116 | equal := true |
| 117 | a.Range(func(k pref.MapKey, va pref.Value) bool { |
| 118 | vb := b.Get(k) |
Joe Tsai | ac31a35 | 2019-05-13 14:32:56 -0700 | [diff] [blame] | 119 | if !vb.IsValid() || !equalValue(fd.MapValue(), va, vb) { |
Damien Neil | e6f060f | 2019-04-23 17:11:02 -0700 | [diff] [blame] | 120 | equal = false |
| 121 | return false |
| 122 | } |
| 123 | return true |
| 124 | }) |
| 125 | return equal |
| 126 | } |
| 127 | |
| 128 | // equalList compares a non-map repeated field. |
| 129 | func equalList(fd pref.FieldDescriptor, a, b pref.List) bool { |
| 130 | alen := a.Len() |
| 131 | if alen != b.Len() { |
| 132 | return false |
| 133 | } |
| 134 | for i := 0; i < alen; i++ { |
| 135 | if !equalValue(fd, a.Get(i), b.Get(i)) { |
| 136 | return false |
| 137 | } |
| 138 | } |
| 139 | return true |
| 140 | } |
| 141 | |
| 142 | // equalValue compares the scalar value type of a field. |
| 143 | func equalValue(fd pref.FieldDescriptor, a, b pref.Value) bool { |
| 144 | switch { |
| 145 | case fd.Message() != nil: |
| 146 | return equalMessage(a.Message(), b.Message()) |
| 147 | case fd.Kind() == pref.BytesKind: |
| 148 | return bytes.Equal(a.Bytes(), b.Bytes()) |
| 149 | default: |
| 150 | return a.Interface() == b.Interface() |
| 151 | } |
| 152 | } |