reflect/protoreflect: add helper methods to FieldDescriptor
Added API:
FieldDescriptor.IsExtension
FieldDescriptor.IsList
FieldDescriptor.MapKey
FieldDescriptor.MapValue
FieldDescriptor.ContainingOneof
FieldDescriptor.ContainingMessage
Deprecated API (to be removed in subsequent CL):
FieldDescriptor.Oneof (use FieldDescriptor.ContainingOneof instead)
FieldDescriptor.Extendee (use FieldDescriptor.ContainingMessage instead)
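These methods help clean up several common usage patterns. For example, callers can distinguish list fields from map fields with IsList/IsMap instead of checking Cardinality() == Repeated and then IsMap. They can also obtain a map entry's key and value descriptors with MapKey/MapValue instead of looking up fields 1 and 2 of the entry message by number.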
Change-Id: I9a3ffabc2edb2173c536509b22f330f98bba7cf3
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/176977
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/encoding/bench_test.go b/encoding/bench_test.go
index 88f62b7..fdb33d1 100644
--- a/encoding/bench_test.go
+++ b/encoding/bench_test.go
@@ -49,13 +49,12 @@
for i := 0; i < fieldDescs.Len(); i++ {
fd := fieldDescs.Get(i)
num := fd.Number()
- if cardinality := fd.Cardinality(); cardinality == pref.Repeated {
- if !fd.IsMap() {
- setList(knownFields.Get(num).List(), fd, level)
- } else {
- setMap(knownFields.Get(num).Map(), fd, level)
- }
- } else {
+ switch {
+ case fd.IsList():
+ setList(knownFields.Get(num).List(), fd, level)
+ case fd.IsMap():
+ setMap(knownFields.Get(num).Map(), fd, level)
+ default:
setScalarField(knownFields, fd, level)
}
}
diff --git a/encoding/jsonpb/decode.go b/encoding/jsonpb/decode.go
index 7e142c5..1b5a5af 100644
--- a/encoding/jsonpb/decode.go
+++ b/encoding/jsonpb/decode.go
@@ -244,14 +244,20 @@
continue
}
- if cardinality := fd.Cardinality(); cardinality == pref.Repeated {
- // Map or list fields have cardinality of repeated.
- if err := o.unmarshalRepeated(knownFields, fd); !nerr.Merge(err) {
+ switch {
+ case fd.IsList():
+ list := knownFields.Get(fd.Number()).List()
+ if err := o.unmarshalList(list, fd); !nerr.Merge(err) {
return errors.New("%v|%q: %v", fd.FullName(), name, err)
}
- } else {
+ case fd.IsMap():
+ mmap := knownFields.Get(fd.Number()).Map()
+ if err := o.unmarshalMap(mmap, fd); !nerr.Merge(err) {
+ return errors.New("%v|%q: %v", fd.FullName(), name, err)
+ }
+ default:
// If field is a oneof, check if it has already been set.
- if od := fd.Oneof(); od != nil {
+ if od := fd.ContainingOneof(); od != nil {
idx := uint64(od.Index())
if seenOneofs.Has(idx) {
return errors.New("%v: oneof is already set", od.FullName())
@@ -548,24 +554,6 @@
return pref.Value{}, unexpectedJSONError{jval}
}
-// unmarshalRepeated unmarshals into a repeated field.
-func (o UnmarshalOptions) unmarshalRepeated(knownFields pref.KnownFields, fd pref.FieldDescriptor) error {
- var nerr errors.NonFatal
- num := fd.Number()
- val := knownFields.Get(num)
- if !fd.IsMap() {
- if err := o.unmarshalList(val.List(), fd); !nerr.Merge(err) {
- return err
- }
- } else {
- if err := o.unmarshalMap(val.Map(), fd); !nerr.Merge(err) {
- return err
- }
- }
- return nerr.E
-}
-
-// unmarshalList unmarshals into given protoreflect.List.
func (o UnmarshalOptions) unmarshalList(list pref.List, fd pref.FieldDescriptor) error {
var nerr errors.NonFatal
jval, err := o.decoder.Read()
@@ -610,10 +598,8 @@
return nerr.E
}
-// unmarshalMap unmarshals into given protoreflect.Map.
func (o UnmarshalOptions) unmarshalMap(mmap pref.Map, fd pref.FieldDescriptor) error {
var nerr errors.NonFatal
-
jval, err := o.decoder.Read()
if !nerr.Merge(err) {
return err
@@ -622,17 +608,11 @@
return unexpectedJSONError{jval}
}
- fields := fd.Message().Fields()
- keyDesc := fields.ByNumber(1)
- valDesc := fields.ByNumber(2)
-
// Determine ahead whether map entry is a scalar type or a message type in
// order to call the appropriate unmarshalMapValue func inside the for loop
// below.
- unmarshalMapValue := func() (pref.Value, error) {
- return o.unmarshalScalar(valDesc)
- }
- switch valDesc.Kind() {
+ var unmarshalMapValue func() (pref.Value, error)
+ switch fd.MapValue().Kind() {
case pref.MessageKind, pref.GroupKind:
unmarshalMapValue = func() (pref.Value, error) {
var nerr errors.NonFatal
@@ -642,6 +622,10 @@
}
return pref.ValueOf(m), nerr.E
}
+ default:
+ unmarshalMapValue = func() (pref.Value, error) {
+ return o.unmarshalScalar(fd.MapValue())
+ }
}
Loop:
@@ -666,7 +650,7 @@
}
// Unmarshal field name.
- pkey, err := unmarshalMapKey(name, keyDesc)
+ pkey, err := unmarshalMapKey(name, fd.MapKey())
if !nerr.Merge(err) {
return err
}
diff --git a/encoding/jsonpb/encode.go b/encoding/jsonpb/encode.go
index 90767ce..256cd74 100644
--- a/encoding/jsonpb/encode.go
+++ b/encoding/jsonpb/encode.go
@@ -118,25 +118,14 @@
// marshalValue marshals the given protoreflect.Value.
func (o MarshalOptions) marshalValue(val pref.Value, fd pref.FieldDescriptor) error {
- var nerr errors.NonFatal
- if fd.Cardinality() == pref.Repeated {
- // Map or repeated fields.
- if fd.IsMap() {
- if err := o.marshalMap(val.Map(), fd); !nerr.Merge(err) {
- return err
- }
- } else {
- if err := o.marshalList(val.List(), fd); !nerr.Merge(err) {
- return err
- }
- }
- } else {
- // Required or optional fields.
- if err := o.marshalSingular(val, fd); !nerr.Merge(err) {
- return err
- }
+ switch {
+ case fd.IsList():
+ return o.marshalList(val.List(), fd)
+ case fd.IsMap():
+ return o.marshalMap(val.Map(), fd)
+ default:
+ return o.marshalSingular(val, fd)
}
- return nerr.E
}
// marshalSingular marshals the given non-repeated field value. This includes
@@ -226,17 +215,13 @@
o.encoder.StartObject()
defer o.encoder.EndObject()
- msgFields := fd.Message().Fields()
- keyType := msgFields.ByNumber(1)
- valType := msgFields.ByNumber(2)
-
// Get a sorted list based on keyType first.
entries := make([]mapEntry, 0, mmap.Len())
mmap.Range(func(key pref.MapKey, val pref.Value) bool {
entries = append(entries, mapEntry{key: key, value: val})
return true
})
- sortMap(keyType.Kind(), entries)
+ sortMap(fd.MapKey().Kind(), entries)
// Write out sorted list.
var nerr errors.NonFatal
@@ -244,7 +229,7 @@
if err := o.encoder.WriteName(entry.key.String()); !nerr.Merge(err) {
return err
}
- if err := o.marshalSingular(entry.value, valType); !nerr.Merge(err) {
+ if err := o.marshalSingular(entry.value, fd.MapValue()); !nerr.Merge(err) {
return err
}
}
@@ -333,6 +318,6 @@
if xd.FullName().Parent() != md.FullName() {
return false
}
- xmd, ok := xd.Extendee().(interface{ IsMessageSet() bool })
+ xmd, ok := xd.ContainingMessage().(interface{ IsMessageSet() bool })
return ok && xmd.IsMessageSet()
}
diff --git a/encoding/textpb/decode.go b/encoding/textpb/decode.go
index 38092ac..c586331 100644
--- a/encoding/textpb/decode.go
+++ b/encoding/textpb/decode.go
@@ -159,14 +159,36 @@
return errors.New("%v contains unknown field: %v", messageDesc.FullName(), tkey)
}
- if cardinality := fd.Cardinality(); cardinality == pref.Repeated {
- // Map or list fields have cardinality of repeated.
- if err := o.unmarshalRepeated(tval, fd, knownFields); !nerr.Merge(err) {
+ switch {
+ case fd.IsList():
+ // If input is not a list, turn it into a list.
+ var items []text.Value
+ if tval.Type() != text.List {
+ items = []text.Value{tval}
+ } else {
+ items = tval.List()
+ }
+
+ list := knownFields.Get(fd.Number()).List()
+ if err := o.unmarshalList(items, fd, list); !nerr.Merge(err) {
return err
}
- } else {
+ case fd.IsMap():
+ // If input is not a list, turn it into a list.
+ var items []text.Value
+ if tval.Type() != text.List {
+ items = []text.Value{tval}
+ } else {
+ items = tval.List()
+ }
+
+ mmap := knownFields.Get(fd.Number()).Map()
+ if err := o.unmarshalMap(items, fd, mmap); !nerr.Merge(err) {
+ return err
+ }
+ default:
// If field is a oneof, check if it has already been set.
- if od := fd.Oneof(); od != nil {
+ if od := fd.ContainingOneof(); od != nil {
idx := uint64(od.Index())
if seenOneofs.Has(idx) {
return errors.New("oneof %v is already set", od.FullName())
@@ -232,33 +254,6 @@
return nerr.E
}
-// unmarshalRepeated unmarshals given text.Value into a repeated field. Caller should only
-// call this for cardinality=repeated.
-func (o UnmarshalOptions) unmarshalRepeated(input text.Value, fd pref.FieldDescriptor, knownFields pref.KnownFields) error {
- var items []text.Value
- // If input is not a list, turn it into a list.
- if input.Type() != text.List {
- items = []text.Value{input}
- } else {
- items = input.List()
- }
-
- var nerr errors.NonFatal
- num := fd.Number()
- val := knownFields.Get(num)
- if !fd.IsMap() {
- if err := o.unmarshalList(items, fd, val.List()); !nerr.Merge(err) {
- return err
- }
- } else {
- if err := o.unmarshalMap(items, fd, val.Map()); !nerr.Merge(err) {
- return err
- }
- }
-
- return nerr.E
-}
-
// unmarshalScalar converts the given text.Value to a scalar/enum protoreflect.Value specified in
// the given FieldDescriptor. Caller should not pass in a FieldDescriptor for a message/group kind.
func unmarshalScalar(input text.Value, fd pref.FieldDescriptor) (pref.Value, error) {
@@ -358,14 +353,11 @@
// unmarshalMap unmarshals given []text.Value into given protoreflect.Map.
func (o UnmarshalOptions) unmarshalMap(input []text.Value, fd pref.FieldDescriptor, mmap pref.Map) error {
var nerr errors.NonFatal
- fields := fd.Message().Fields()
- keyDesc := fields.ByNumber(1)
- valDesc := fields.ByNumber(2)
// Determine ahead whether map entry is a scalar type or a message type in order to call the
// appropriate unmarshalMapValue func inside the for loop below.
unmarshalMapValue := unmarshalMapScalarValue
- switch valDesc.Kind() {
+ switch fd.MapValue().Kind() {
case pref.MessageKind, pref.GroupKind:
unmarshalMapValue = o.unmarshalMapMessageValue
}
@@ -378,11 +370,11 @@
if !nerr.Merge(err) {
return err
}
- pkey, err := unmarshalMapKey(tkey, keyDesc)
+ pkey, err := unmarshalMapKey(tkey, fd.MapKey())
if !nerr.Merge(err) {
return err
}
- err = unmarshalMapValue(tval, pkey, valDesc, mmap)
+ err = unmarshalMapValue(tval, pkey, fd.MapValue(), mmap)
if !nerr.Merge(err) {
return err
}
diff --git a/encoding/textpb/encode.go b/encoding/textpb/encode.go
index b8b2e71..9bd8279 100644
--- a/encoding/textpb/encode.go
+++ b/encoding/textpb/encode.go
@@ -132,28 +132,26 @@
func (o MarshalOptions) appendField(msgFields [][2]text.Value, name text.Value, pval pref.Value, fd pref.FieldDescriptor) ([][2]text.Value, error) {
var nerr errors.NonFatal
- if fd.Cardinality() == pref.Repeated {
- // Map or repeated fields.
- var items []text.Value
- var err error
- if fd.IsMap() {
- items, err = o.marshalMap(pval.Map(), fd)
- if !nerr.Merge(err) {
- return msgFields, err
- }
- } else {
- items, err = o.marshalList(pval.List(), fd)
- if !nerr.Merge(err) {
- return msgFields, err
- }
+ switch {
+ case fd.IsList():
+ items, err := o.marshalList(pval.List(), fd)
+ if !nerr.Merge(err) {
+ return msgFields, err
}
- // Add each item as key: value field.
for _, item := range items {
msgFields = append(msgFields, [2]text.Value{name, item})
}
- } else {
- // Required or optional fields.
+ case fd.IsMap():
+ items, err := o.marshalMap(pval.Map(), fd)
+ if !nerr.Merge(err) {
+ return msgFields, err
+ }
+
+ for _, item := range items {
+ msgFields = append(msgFields, [2]text.Value{name, item})
+ }
+ default:
tval, err := o.marshalSingular(pval, fd)
if !nerr.Merge(err) {
return msgFields, err
@@ -231,19 +229,16 @@
var nerr errors.NonFatal
// values is a list of messages.
values := make([]text.Value, 0, mmap.Len())
- msgFields := fd.Message().Fields()
- keyType := msgFields.ByNumber(1)
- valType := msgFields.ByNumber(2)
var err error
- mapsort.Range(mmap, keyType.Kind(), func(key pref.MapKey, val pref.Value) bool {
+ mapsort.Range(mmap, fd.MapKey().Kind(), func(key pref.MapKey, val pref.Value) bool {
var keyTxtVal text.Value
- keyTxtVal, err = o.marshalSingular(key.Value(), keyType)
+ keyTxtVal, err = o.marshalSingular(key.Value(), fd.MapKey())
if !nerr.Merge(err) {
return false
}
var valTxtVal text.Value
- valTxtVal, err = o.marshalSingular(val, valType)
+ valTxtVal, err = o.marshalSingular(val, fd.MapValue())
if !nerr.Merge(err) {
return false
}
@@ -314,7 +309,7 @@
if xd.FullName().Parent() != md.FullName() {
return false
}
- xmd, ok := xd.Extendee().(interface{ IsMessageSet() bool })
+ xmd, ok := xd.ContainingMessage().(interface{ IsMessageSet() bool })
return ok && xmd.IsMessageSet()
}