blob: 44f840a89aba569d4732d79fb5bf75c3bcfecc19 [file] [log] [blame]
Damien Neile6f060f2019-04-23 17:11:02 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "bytes"
Joe Tsai6bd33b62019-07-15 13:08:00 -07009 "math"
Joe Tsai378c1322019-04-25 23:48:08 -070010 "reflect"
Damien Neile6f060f2019-04-23 17:11:02 -070011
Joe Tsai378c1322019-04-25 23:48:08 -070012 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neile89e6242019-05-13 23:55:40 -070013 pref "google.golang.org/protobuf/reflect/protoreflect"
Damien Neile6f060f2019-04-23 17:11:02 -070014)
15
Joe Tsai378c1322019-04-25 23:48:08 -070016// Equal reports whether two messages are equal.
Joe Tsai6bd33b62019-07-15 13:08:00 -070017// If two messages marshal to the same bytes under deterministic serialization,
18// then Equal is guaranteed to report true.
Damien Neile6f060f2019-04-23 17:11:02 -070019//
Joe Tsai378c1322019-04-25 23:48:08 -070020// Two messages are equal if they belong to the same message descriptor,
21// have the same set of populated known and extension field values,
22// and the same set of unknown fields values.
23//
24// Scalar values are compared with the equivalent of the == operator in Go,
Joe Tsai6bd33b62019-07-15 13:08:00 -070025// except bytes values which are compared using bytes.Equal and
26// floating point values which specially treat NaNs as equal.
Joe Tsai378c1322019-04-25 23:48:08 -070027// Message values are compared by recursively calling Equal.
28// Lists are equal if each element value is also equal.
29// Maps are equal if they have the same set of keys, where the pair of values
30// for each key is also equal.
31func Equal(x, y Message) bool {
32 return equalMessage(x.ProtoReflect(), y.ProtoReflect())
Damien Neile6f060f2019-04-23 17:11:02 -070033}
34
35// equalMessage compares two messages.
Joe Tsai378c1322019-04-25 23:48:08 -070036func equalMessage(mx, my pref.Message) bool {
37 if mx.Descriptor() != my.Descriptor() {
Damien Neile6f060f2019-04-23 17:11:02 -070038 return false
39 }
40
Damien Neila9940822019-06-24 12:58:17 -070041 nx := 0
Damien Neile6f060f2019-04-23 17:11:02 -070042 equal := true
Joe Tsai378c1322019-04-25 23:48:08 -070043 mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
Damien Neila9940822019-06-24 12:58:17 -070044 nx++
Joe Tsai378c1322019-04-25 23:48:08 -070045 vy := my.Get(fd)
46 equal = my.Has(fd) && equalField(fd, vx, vy)
47 return equal
Damien Neile6f060f2019-04-23 17:11:02 -070048 })
49 if !equal {
50 return false
51 }
Damien Neila9940822019-06-24 12:58:17 -070052 ny := 0
53 my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
54 ny++
55 return true
56 })
57 if nx != ny {
58 return false
59 }
Damien Neile6f060f2019-04-23 17:11:02 -070060
Joe Tsai378c1322019-04-25 23:48:08 -070061 return equalUnknown(mx.GetUnknown(), my.GetUnknown())
Damien Neile6f060f2019-04-23 17:11:02 -070062}
63
Joe Tsai378c1322019-04-25 23:48:08 -070064// equalField compares two fields.
65func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool {
Damien Neile6f060f2019-04-23 17:11:02 -070066 switch {
Joe Tsaiac31a352019-05-13 14:32:56 -070067 case fd.IsList():
Joe Tsai378c1322019-04-25 23:48:08 -070068 return equalList(fd, x.List(), y.List())
Damien Neile6f060f2019-04-23 17:11:02 -070069 case fd.IsMap():
Joe Tsai378c1322019-04-25 23:48:08 -070070 return equalMap(fd, x.Map(), y.Map())
Damien Neile6f060f2019-04-23 17:11:02 -070071 default:
Joe Tsai378c1322019-04-25 23:48:08 -070072 return equalValue(fd, x, y)
Damien Neile6f060f2019-04-23 17:11:02 -070073 }
74}
75
Joe Tsai378c1322019-04-25 23:48:08 -070076// equalMap compares two maps.
77func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool {
78 if x.Len() != y.Len() {
Damien Neile6f060f2019-04-23 17:11:02 -070079 return false
80 }
81 equal := true
Joe Tsai378c1322019-04-25 23:48:08 -070082 x.Range(func(k pref.MapKey, vx pref.Value) bool {
83 vy := y.Get(k)
84 equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
85 return equal
Damien Neile6f060f2019-04-23 17:11:02 -070086 })
87 return equal
88}
89
Joe Tsai378c1322019-04-25 23:48:08 -070090// equalList compares two lists.
91func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
92 if x.Len() != y.Len() {
Damien Neile6f060f2019-04-23 17:11:02 -070093 return false
94 }
Joe Tsai378c1322019-04-25 23:48:08 -070095 for i := x.Len() - 1; i >= 0; i-- {
96 if !equalValue(fd, x.Get(i), y.Get(i)) {
Damien Neile6f060f2019-04-23 17:11:02 -070097 return false
98 }
99 }
100 return true
101}
102
Joe Tsai378c1322019-04-25 23:48:08 -0700103// equalValue compares two singular values.
104func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
Damien Neile6f060f2019-04-23 17:11:02 -0700105 switch {
106 case fd.Message() != nil:
Joe Tsai378c1322019-04-25 23:48:08 -0700107 return equalMessage(x.Message(), y.Message())
Damien Neile6f060f2019-04-23 17:11:02 -0700108 case fd.Kind() == pref.BytesKind:
Joe Tsai378c1322019-04-25 23:48:08 -0700109 return bytes.Equal(x.Bytes(), y.Bytes())
Joe Tsai6bd33b62019-07-15 13:08:00 -0700110 case fd.Kind() == pref.FloatKind, fd.Kind() == pref.DoubleKind:
111 fx := x.Float()
112 fy := y.Float()
113 if math.IsNaN(fx) || math.IsNaN(fy) {
114 return math.IsNaN(fx) && math.IsNaN(fy)
115 }
116 return fx == fy
Damien Neile6f060f2019-04-23 17:11:02 -0700117 default:
Joe Tsai378c1322019-04-25 23:48:08 -0700118 return x.Interface() == y.Interface()
Damien Neile6f060f2019-04-23 17:11:02 -0700119 }
120}
Joe Tsai378c1322019-04-25 23:48:08 -0700121
122// equalUnknown compares unknown fields by direct comparison on the raw bytes
123// of each individual field number.
124func equalUnknown(x, y pref.RawFields) bool {
125 if len(x) != len(y) {
126 return false
127 }
128 if bytes.Equal([]byte(x), []byte(y)) {
129 return true
130 }
131
132 mx := make(map[pref.FieldNumber]pref.RawFields)
133 my := make(map[pref.FieldNumber]pref.RawFields)
134 for len(x) > 0 {
135 fnum, _, n := wire.ConsumeField(x)
136 mx[fnum] = append(mx[fnum], x[:n]...)
137 x = x[n:]
138 }
139 for len(y) > 0 {
140 fnum, _, n := wire.ConsumeField(y)
141 my[fnum] = append(my[fnum], y[:n]...)
142 y = y[n:]
143 }
144 return reflect.DeepEqual(mx, my)
145}