blob: 7b513bdef4af5a83d1d65dd3f0216cba93194981 [file] [log] [blame]
Damien Neile6f060f2019-04-23 17:11:02 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "bytes"
Joe Tsai6bd33b62019-07-15 13:08:00 -07009 "math"
Joe Tsai378c1322019-04-25 23:48:08 -070010 "reflect"
Damien Neile6f060f2019-04-23 17:11:02 -070011
Joe Tsai378c1322019-04-25 23:48:08 -070012 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neile89e6242019-05-13 23:55:40 -070013 pref "google.golang.org/protobuf/reflect/protoreflect"
Damien Neile6f060f2019-04-23 17:11:02 -070014)
15
Joe Tsai378c1322019-04-25 23:48:08 -070016// Equal reports whether two messages are equal.
Joe Tsai6bd33b62019-07-15 13:08:00 -070017// If two messages marshal to the same bytes under deterministic serialization,
18// then Equal is guaranteed to report true.
Damien Neile6f060f2019-04-23 17:11:02 -070019//
Joe Tsai378c1322019-04-25 23:48:08 -070020// Two messages are equal if they belong to the same message descriptor,
21// have the same set of populated known and extension field values,
22// and the same set of unknown fields values.
23//
24// Scalar values are compared with the equivalent of the == operator in Go,
Joe Tsai6bd33b62019-07-15 13:08:00 -070025// except bytes values which are compared using bytes.Equal and
26// floating point values which specially treat NaNs as equal.
Joe Tsai378c1322019-04-25 23:48:08 -070027// Message values are compared by recursively calling Equal.
28// Lists are equal if each element value is also equal.
29// Maps are equal if they have the same set of keys, where the pair of values
30// for each key is also equal.
31func Equal(x, y Message) bool {
Joe Tsaif2c4ddc2019-09-19 21:28:52 -070032 if x == nil || y == nil {
33 return x == nil && y == nil
34 }
Joe Tsai378c1322019-04-25 23:48:08 -070035 return equalMessage(x.ProtoReflect(), y.ProtoReflect())
Damien Neile6f060f2019-04-23 17:11:02 -070036}
37
38// equalMessage compares two messages.
Joe Tsai378c1322019-04-25 23:48:08 -070039func equalMessage(mx, my pref.Message) bool {
40 if mx.Descriptor() != my.Descriptor() {
Damien Neile6f060f2019-04-23 17:11:02 -070041 return false
42 }
43
Damien Neila9940822019-06-24 12:58:17 -070044 nx := 0
Damien Neile6f060f2019-04-23 17:11:02 -070045 equal := true
Joe Tsai378c1322019-04-25 23:48:08 -070046 mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
Damien Neila9940822019-06-24 12:58:17 -070047 nx++
Joe Tsai378c1322019-04-25 23:48:08 -070048 vy := my.Get(fd)
49 equal = my.Has(fd) && equalField(fd, vx, vy)
50 return equal
Damien Neile6f060f2019-04-23 17:11:02 -070051 })
52 if !equal {
53 return false
54 }
Damien Neila9940822019-06-24 12:58:17 -070055 ny := 0
56 my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
57 ny++
58 return true
59 })
60 if nx != ny {
61 return false
62 }
Damien Neile6f060f2019-04-23 17:11:02 -070063
Joe Tsai378c1322019-04-25 23:48:08 -070064 return equalUnknown(mx.GetUnknown(), my.GetUnknown())
Damien Neile6f060f2019-04-23 17:11:02 -070065}
66
Joe Tsai378c1322019-04-25 23:48:08 -070067// equalField compares two fields.
68func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool {
Damien Neile6f060f2019-04-23 17:11:02 -070069 switch {
Joe Tsaiac31a352019-05-13 14:32:56 -070070 case fd.IsList():
Joe Tsai378c1322019-04-25 23:48:08 -070071 return equalList(fd, x.List(), y.List())
Damien Neile6f060f2019-04-23 17:11:02 -070072 case fd.IsMap():
Joe Tsai378c1322019-04-25 23:48:08 -070073 return equalMap(fd, x.Map(), y.Map())
Damien Neile6f060f2019-04-23 17:11:02 -070074 default:
Joe Tsai378c1322019-04-25 23:48:08 -070075 return equalValue(fd, x, y)
Damien Neile6f060f2019-04-23 17:11:02 -070076 }
77}
78
Joe Tsai378c1322019-04-25 23:48:08 -070079// equalMap compares two maps.
80func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool {
81 if x.Len() != y.Len() {
Damien Neile6f060f2019-04-23 17:11:02 -070082 return false
83 }
84 equal := true
Joe Tsai378c1322019-04-25 23:48:08 -070085 x.Range(func(k pref.MapKey, vx pref.Value) bool {
86 vy := y.Get(k)
87 equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
88 return equal
Damien Neile6f060f2019-04-23 17:11:02 -070089 })
90 return equal
91}
92
Joe Tsai378c1322019-04-25 23:48:08 -070093// equalList compares two lists.
94func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
95 if x.Len() != y.Len() {
Damien Neile6f060f2019-04-23 17:11:02 -070096 return false
97 }
Joe Tsai378c1322019-04-25 23:48:08 -070098 for i := x.Len() - 1; i >= 0; i-- {
99 if !equalValue(fd, x.Get(i), y.Get(i)) {
Damien Neile6f060f2019-04-23 17:11:02 -0700100 return false
101 }
102 }
103 return true
104}
105
Joe Tsai378c1322019-04-25 23:48:08 -0700106// equalValue compares two singular values.
107func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
Damien Neile6f060f2019-04-23 17:11:02 -0700108 switch {
109 case fd.Message() != nil:
Joe Tsai378c1322019-04-25 23:48:08 -0700110 return equalMessage(x.Message(), y.Message())
Damien Neile6f060f2019-04-23 17:11:02 -0700111 case fd.Kind() == pref.BytesKind:
Joe Tsai378c1322019-04-25 23:48:08 -0700112 return bytes.Equal(x.Bytes(), y.Bytes())
Joe Tsai6bd33b62019-07-15 13:08:00 -0700113 case fd.Kind() == pref.FloatKind, fd.Kind() == pref.DoubleKind:
114 fx := x.Float()
115 fy := y.Float()
116 if math.IsNaN(fx) || math.IsNaN(fy) {
117 return math.IsNaN(fx) && math.IsNaN(fy)
118 }
119 return fx == fy
Damien Neile6f060f2019-04-23 17:11:02 -0700120 default:
Joe Tsai378c1322019-04-25 23:48:08 -0700121 return x.Interface() == y.Interface()
Damien Neile6f060f2019-04-23 17:11:02 -0700122 }
123}
Joe Tsai378c1322019-04-25 23:48:08 -0700124
125// equalUnknown compares unknown fields by direct comparison on the raw bytes
126// of each individual field number.
127func equalUnknown(x, y pref.RawFields) bool {
128 if len(x) != len(y) {
129 return false
130 }
131 if bytes.Equal([]byte(x), []byte(y)) {
132 return true
133 }
134
135 mx := make(map[pref.FieldNumber]pref.RawFields)
136 my := make(map[pref.FieldNumber]pref.RawFields)
137 for len(x) > 0 {
138 fnum, _, n := wire.ConsumeField(x)
139 mx[fnum] = append(mx[fnum], x[:n]...)
140 x = x[n:]
141 }
142 for len(y) > 0 {
143 fnum, _, n := wire.ConsumeField(y)
144 my[fnum] = append(my[fnum], y[:n]...)
145 y = y[n:]
146 }
147 return reflect.DeepEqual(mx, my)
148}