encoding/textpb: clean up code and fix error handling inside closures
In marshalMap, replace Map.Range and sortMap with
internal/mapsort.Range. Also, make sure a fatal error set inside the
closure is captured properly and returned.
In appendExtensions, capture the fatal error properly for return. Add a
test case that shows this was a bug.
In unmarshalAny, remove unnecessary checks and use internal/fieldnum for
Any's field numbers.
Change-Id: Id8574d5b4eb820ad961f6ad5e886f8ae2e9b90f0
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170627
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/encoding/textpb/encode.go b/encoding/textpb/encode.go
index d94f771..4fd5d7d 100644
--- a/encoding/textpb/encode.go
+++ b/encoding/textpb/encode.go
@@ -11,6 +11,8 @@
"github.com/golang/protobuf/v2/internal/encoding/text"
"github.com/golang/protobuf/v2/internal/encoding/wire"
"github.com/golang/protobuf/v2/internal/errors"
+ "github.com/golang/protobuf/v2/internal/fieldnum"
+ "github.com/golang/protobuf/v2/internal/mapsort"
"github.com/golang/protobuf/v2/internal/pragma"
"github.com/golang/protobuf/v2/proto"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
@@ -188,7 +190,7 @@
return o.marshalMessage(val.Message())
}
- return text.Value{}, errors.New("%v has unknown kind: %v", fd.FullName(), kind)
+ panic(fmt.Sprintf("%v has unknown kind: %v", fd.FullName(), kind))
}
// marshalList converts a protoreflect.List to []text.Value.
@@ -224,12 +226,15 @@
keyType := msgFields.ByNumber(1)
valType := msgFields.ByNumber(2)
- mmap.Range(func(key pref.MapKey, val pref.Value) bool {
- keyTxtVal, err := o.marshalSingular(key.Value(), keyType)
+ var err error
+ mapsort.Range(mmap, keyType.Kind(), func(key pref.MapKey, val pref.Value) bool {
+ var keyTxtVal text.Value
+ keyTxtVal, err = o.marshalSingular(key.Value(), keyType)
if !nerr.Merge(err) {
return false
}
- valTxtVal, err := o.marshalSingular(val, valType)
+ var valTxtVal text.Value
+ valTxtVal, err = o.marshalSingular(val, valType)
if !nerr.Merge(err) {
return false
}
@@ -239,65 +244,22 @@
{mapValueName, valTxtVal},
})
values = append(values, msg)
+ err = nil
return true
})
+ if err != nil {
+ return nil, err
+ }
- sortMap(keyType.Kind(), values)
return values, nerr.E
}
-// sortMap orders list based on value of key field for deterministic output.
-// TODO: Improve sort comparison of text.Value for map keys.
-func sortMap(keyKind pref.Kind, values []text.Value) {
- less := func(i, j int) bool {
- mi := values[i].Message()
- mj := values[j].Message()
- return mi[0][1].String() < mj[0][1].String()
- }
- switch keyKind {
- case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
- less = func(i, j int) bool {
- mi := values[i].Message()
- mj := values[j].Message()
- ni, _ := mi[0][1].Int(false)
- nj, _ := mj[0][1].Int(false)
- return ni < nj
- }
- case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
- less = func(i, j int) bool {
- mi := values[i].Message()
- mj := values[j].Message()
- ni, _ := mi[0][1].Int(true)
- nj, _ := mj[0][1].Int(true)
- return ni < nj
- }
-
- case pref.Uint32Kind, pref.Fixed32Kind:
- less = func(i, j int) bool {
- mi := values[i].Message()
- mj := values[j].Message()
- ni, _ := mi[0][1].Uint(false)
- nj, _ := mj[0][1].Uint(false)
- return ni < nj
- }
- case pref.Uint64Kind, pref.Fixed64Kind:
- less = func(i, j int) bool {
- mi := values[i].Message()
- mj := values[j].Message()
- ni, _ := mi[0][1].Uint(true)
- nj, _ := mj[0][1].Uint(true)
- return ni < nj
- }
- }
- sort.Slice(values, less)
-}
-
// appendExtensions marshals extension fields and appends them to the given [][2]text.Value.
func (o MarshalOptions) appendExtensions(msgFields [][2]text.Value, knownFields pref.KnownFields) ([][2]text.Value, error) {
- var nerr errors.NonFatal
xtTypes := knownFields.ExtensionTypes()
xtFields := make([][2]text.Value, 0, xtTypes.Len())
+ var nerr errors.NonFatal
var err error
xtTypes.Range(func(xt pref.ExtensionType) bool {
name := xt.FullName()
@@ -312,13 +274,14 @@
tname := text.ValueOf(string(name))
pval := knownFields.Get(num)
xtFields, err = o.appendField(xtFields, tname, pval, xt)
- if err != nil {
+ if !nerr.Merge(err) {
return false
}
+ err = nil
}
return true
})
- if !nerr.Merge(err) {
+ if err != nil {
return msgFields, err
}
@@ -380,20 +343,9 @@
// marshalAny converts a google.protobuf.Any protoreflect.Message to a text.Value.
func (o MarshalOptions) marshalAny(m pref.Message) (text.Value, error) {
var nerr errors.NonFatal
-
- fds := m.Type().Fields()
- tfd := fds.ByName("type_url")
- if tfd == nil || tfd.Kind() != pref.StringKind {
- return text.Value{}, errors.New("invalid google.protobuf.Any message")
- }
- vfd := fds.ByName("value")
- if vfd == nil || vfd.Kind() != pref.BytesKind {
- return text.Value{}, errors.New("invalid google.protobuf.Any message")
- }
-
knownFields := m.KnownFields()
- typeURL := knownFields.Get(tfd.Number()).String()
- value := knownFields.Get(vfd.Number())
+ typeURL := knownFields.Get(fieldnum.Any_TypeUrl).String()
+ value := knownFields.Get(fieldnum.Any_Value)
emt, err := o.Resolver.FindMessageByURL(typeURL)
if !nerr.Merge(err) {