internal/fileinit: prevent map entry descriptors from implementing MessageType
The protobuf type system hacks the representation of map entries into that
of a pseudo-message descriptor.
Previously, we made all message descriptors implement MessageType,
where the descriptors for map entries had a GoType method that simply returned nil.
Unfortunately, this violates a useful invariant of the Go API:
being able to type-assert a descriptor to MessageType should guarantee
that Go type information is truly associated with that descriptor.
This CL makes it such that message descriptors for map entries
do not implement MessageType.
Change-Id: I23873cb71fe0ab3c0befd8052830ea6e53c97ca9
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/168399
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/fileinit/desc_init.go b/internal/fileinit/desc_init.go
index 7d14998..d1310a4 100644
--- a/internal/fileinit/desc_init.go
+++ b/internal/fileinit/desc_init.go
@@ -19,6 +19,13 @@
file.initDecls(len(fb.EnumOutputTypes), len(fb.MessageOutputTypes), len(fb.ExtensionOutputTypes))
file.unmarshalSeed(fb.RawDescriptor)
+ // Determine which message descriptors represent map entries based on the
+ // lack of an associated Go type.
+ messageDecls := file.GoTypes[len(file.allEnums):]
+ for i := range file.allMessages {
+ file.allMessages[i].isMapEntry = messageDecls[i] == nil
+ }
+
// Extended message dependencies are eagerly handled since registration
// needs this information at program init time.
for i := range file.allExtensions {
@@ -31,7 +38,7 @@
}
// initDecls pre-allocates slices for the exact number of enums, messages
-// (excluding map entries), and extensions declared in the proto file.
+// (including map entries), and extensions declared in the proto file.
// This is done to avoid regrowing the slice, which would change the address
// for any previously seen declaration.
//
@@ -279,7 +286,7 @@
for i := range md.enums.list {
_, n := wire.ConsumeVarint(b)
v, m := wire.ConsumeBytes(b[n:])
- md.enums.list[i].unmarshalSeed(v, nb, pf, md, i)
+ md.enums.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
b = b[n+m:]
}
}
@@ -288,7 +295,7 @@
for i := range md.messages.list {
_, n := wire.ConsumeVarint(b)
v, m := wire.ConsumeBytes(b[n:])
- md.messages.list[i].unmarshalSeed(v, nb, pf, md, i)
+ md.messages.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
b = b[n+m:]
}
}
@@ -297,7 +304,7 @@
for i := range md.extensions.list {
_, n := wire.ConsumeVarint(b)
v, m := wire.ConsumeBytes(b[n:])
- md.extensions.list[i].unmarshalSeed(v, nb, pf, md, i)
+ md.extensions.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
b = b[n+m:]
}
}