internal/encoding/defval: unify logic for handling default values
Logic for serializing the default value in textual form exists in
multiple places in trivially similar forms. Centralize that logic.
Change-Id: I4408ddfeef2c0dfa5c7468e01a4d4df5654ae57f
Reviewed-on: https://go-review.googlesource.com/c/153022
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/internal/encoding/defval/default.go b/internal/encoding/defval/default.go
new file mode 100644
index 0000000..e2c62a9
--- /dev/null
+++ b/internal/encoding/defval/default.go
@@ -0,0 +1,209 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package defval marshals and unmarshals textual forms of default values.
+//
+// This package handles both the form historically used in Go struct field tags
+// and the form used by google.protobuf.FieldDescriptorProto.default_value,
+// which differ only in superficial ways.
+package defval
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+
+ ptext "github.com/golang/protobuf/v2/internal/encoding/text"
+ errors "github.com/golang/protobuf/v2/internal/errors"
+ pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// Format selects the textual encoding of a default value. The zero value is
+// not assigned to any named format.
+type Format int
+
+const (
+	// Descriptor is the encoding that protoc writes into the
+	// google.protobuf.FieldDescriptorProto.default_value field.
+	Descriptor Format = 1
+
+	// GoTag is the historical encoding used for defaults in Go struct
+	// field tags.
+	GoTag Format = 2
+)
+
+// Unmarshal deserializes the default string s of kind k encoded in format f,
+// reporting an error if s is not a valid default. For an enum kind in the
+// Descriptor format, the returned Value is a string holding the enum
+// identifier; it is the caller's responsibility to verify that it is valid.
+func Unmarshal(s string, k pref.Kind, f Format) (pref.Value, error) {
+	switch k {
+	case pref.BoolKind:
+		// GoTag spells booleans as "1"/"0"; other formats use "true"/"false".
+		if f == GoTag {
+			switch s {
+			case "1":
+				return pref.ValueOf(true), nil
+			case "0":
+				return pref.ValueOf(false), nil
+			}
+		} else {
+			switch s {
+			case "true":
+				return pref.ValueOf(true), nil
+			case "false":
+				return pref.ValueOf(false), nil
+			}
+		}
+	case pref.EnumKind:
+		if f == GoTag {
+			// Go tags use the numeric form of the enum value.
+			if n, err := strconv.ParseInt(s, 10, 32); err == nil {
+				return pref.ValueOf(pref.EnumNumber(n)), nil
+			}
+		} else {
+			// Descriptor default_value uses the enum identifier.
+			if pref.Name(s).IsValid() {
+				return pref.ValueOf(s), nil
+			}
+		}
+	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
+		if v, err := strconv.ParseInt(s, 10, 32); err == nil {
+			return pref.ValueOf(int32(v)), nil
+		}
+	case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
+		if v, err := strconv.ParseInt(s, 10, 64); err == nil {
+			return pref.ValueOf(v), nil
+		}
+	case pref.Uint32Kind, pref.Fixed32Kind:
+		if v, err := strconv.ParseUint(s, 10, 32); err == nil {
+			return pref.ValueOf(uint32(v)), nil
+		}
+	case pref.Uint64Kind, pref.Fixed64Kind:
+		if v, err := strconv.ParseUint(s, 10, 64); err == nil {
+			return pref.ValueOf(v), nil
+		}
+	case pref.FloatKind, pref.DoubleKind:
+		var v float64
+		var err error
+		switch s {
+		case "-inf":
+			v = math.Inf(-1)
+		case "inf":
+			v = math.Inf(+1)
+		case "nan":
+			v = math.NaN()
+		default:
+			v, err = strconv.ParseFloat(s, 64)
+		}
+		if err == nil {
+			if k == pref.FloatKind {
+				return pref.ValueOf(float32(v)), nil
+			}
+			return pref.ValueOf(v), nil
+		}
+	case pref.StringKind:
+		// String values are already unescaped and can be used as is.
+		return pref.ValueOf(s), nil
+	case pref.BytesKind:
+		if b, ok := unmarshalBytes(s); ok {
+			return pref.ValueOf(b), nil
+		}
+	}
+	return pref.Value{}, errors.New("invalid default value for %v: %q", k, s)
+}
+
+// Marshal serializes v as the default string according to the given kind k
+// and format f. Enums are serialized in numeric form regardless of format,
+// whereas Unmarshal expects the enum identifier for the Descriptor format,
+// so such callers must map the number to its identifier themselves.
+func Marshal(v pref.Value, k pref.Kind, f Format) (string, error) {
+	switch k {
+	case pref.BoolKind:
+		// GoTag spells booleans as "1"/"0"; other formats use "true"/"false".
+		if f == GoTag {
+			if v.Bool() {
+				return "1", nil
+			}
+			return "0", nil
+		}
+		if v.Bool() {
+			return "true", nil
+		}
+		return "false", nil
+	case pref.EnumKind:
+		return strconv.FormatInt(int64(v.Enum()), 10), nil
+	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind, pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
+		return strconv.FormatInt(v.Int(), 10), nil
+	case pref.Uint32Kind, pref.Fixed32Kind, pref.Uint64Kind, pref.Fixed64Kind:
+		return strconv.FormatUint(v.Uint(), 10), nil
+	case pref.FloatKind, pref.DoubleKind:
+		// Named fv so that the Format parameter f is not shadowed.
+		fv := v.Float()
+		switch {
+		case math.IsInf(fv, -1):
+			return "-inf", nil
+		case math.IsInf(fv, +1):
+			return "inf", nil
+		case math.IsNaN(fv):
+			return "nan", nil
+		case k == pref.FloatKind:
+			// Shortest digits that parse back to the same float32.
+			return strconv.FormatFloat(fv, 'g', -1, 32), nil
+		default:
+			return strconv.FormatFloat(fv, 'g', -1, 64), nil
+		}
+	case pref.StringKind:
+		// String values are serialized as is without any escaping.
+		return v.String(), nil
+	case pref.BytesKind:
+		if s, ok := marshalBytes(v.Bytes()); ok {
+			return s, nil
+		}
+	}
+	return "", errors.New("invalid default value for %v: %v", k, v)
+}
+
+// unmarshalBytes reverses the C-style escaping of s, reporting false if the
+// escaped text cannot be parsed back.
+func unmarshalBytes(s string) ([]byte, bool) {
+	// The escaping matches a text-format string literal without its enclosing
+	// double quotes, so wrap s into a minimal text message and parse that.
+	// TODO: Export unmarshalString in the text package to avoid this hack.
+	val, err := ptext.Unmarshal([]byte(`["` + s + `"]:0`))
+	if err != nil || len(val.Message()) != 1 {
+		return nil, false
+	}
+	return []byte(val.Message()[0][0].String()), true
+}
+
+// marshalBytes serializes bytes by using C escaping.
+// To match the exact output of protoc, this is identical to the
+// CEscape function in strutil.cc of the protoc source code.
+func marshalBytes(b []byte) (string, bool) {
+ var s []byte
+ for _, c := range b {
+ switch c {
+ case '\n':
+ s = append(s, `\n`...)
+ case '\r':
+ s = append(s, `\r`...)
+ case '\t':
+ s = append(s, `\t`...)
+ case '"':
+ s = append(s, `\"`...)
+ case '\'':
+ s = append(s, `\'`...)
+ case '\\':
+ s = append(s, `\\`...)
+ default:
+ if printableASCII := c >= 0x20 && c <= 0x7e; printableASCII {
+ s = append(s, c)
+ } else {
+ s = append(s, fmt.Sprintf(`\%03o`, c)...)
+ }
+ }
+ }
+ return string(s), true
+}
diff --git a/internal/encoding/defval/default_test.go b/internal/encoding/defval/default_test.go
new file mode 100644
index 0000000..e91eb7e
--- /dev/null
+++ b/internal/encoding/defval/default_test.go
@@ -0,0 +1,54 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package defval
+
+import (
+ "math"
+ "reflect"
+ "testing"
+
+ pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// Test checks that Marshal and Unmarshal round-trip each value in both formats.
+func Test(t *testing.T) {
+	V := pref.ValueOf
+	tests := []struct {
+		val   pref.Value
+		kind  pref.Kind
+		strPB string
+		strGo string
+	}{
+		{V(bool(true)), pref.BoolKind, "true", "1"},
+		{V(bool(false)), pref.BoolKind, "false", "0"},
+		{V(int32(-0x1234)), pref.Int32Kind, "-4660", "-4660"},
+		{V(uint64(math.MaxUint64)), pref.Uint64Kind, "18446744073709551615", "18446744073709551615"},
+		{V(float32(math.Pi)), pref.FloatKind, "3.1415927", "3.1415927"},
+		{V(float64(math.Pi)), pref.DoubleKind, "3.141592653589793", "3.141592653589793"},
+		{V(math.Inf(-1)), pref.DoubleKind, "-inf", "-inf"},
+		{V(string("hello, \xde\xad\xbe\xef\n")), pref.StringKind, "hello, \xde\xad\xbe\xef\n", "hello, \xde\xad\xbe\xef\n"},
+		{V([]byte("hello, \xde\xad\xbe\xef\n")), pref.BytesKind, "hello, \\336\\255\\276\\357\\n", "hello, \\336\\255\\276\\357\\n"},
+	}
+	for _, tt := range tests {
+		for _, fc := range []struct {
+			format Format
+			name   string
+			str    string
+		}{{Descriptor, "Descriptor", tt.strPB}, {GoTag, "GoTag", tt.strGo}} {
+			gotStr, err := Marshal(tt.val, tt.kind, fc.format)
+			if err != nil || gotStr != fc.str {
+				t.Errorf("Marshal(%v, %v, %s) = (%q, %v), want (%q, nil)", tt.val, tt.kind, fc.name, gotStr, err, fc.str)
+			}
+			gotVal, err := Unmarshal(fc.str, tt.kind, fc.format)
+			if err != nil {
+				t.Errorf("Unmarshal(%q, %v, %s) error: %v", fc.str, tt.kind, fc.name, err)
+				continue
+			}
+			if !reflect.DeepEqual(gotVal.Interface(), tt.val.Interface()) {
+				t.Errorf("Unmarshal(%q, %v, %s) = %v, want %v", fc.str, tt.kind, fc.name, gotVal, tt.val)
+			}
+		}
+	}
+}
diff --git a/internal/encoding/tag/tag.go b/internal/encoding/tag/tag.go
index 17594d2..3afedb8 100644
--- a/internal/encoding/tag/tag.go
+++ b/internal/encoding/tag/tag.go
@@ -7,13 +7,11 @@
package tag
import (
- "fmt"
- "math"
"reflect"
"strconv"
"strings"
- ptext "github.com/golang/protobuf/v2/internal/encoding/text"
+ defval "github.com/golang/protobuf/v2/internal/encoding/defval"
scalar "github.com/golang/protobuf/v2/internal/scalar"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
ptype "github.com/golang/protobuf/v2/reflect/prototype"
@@ -115,57 +113,7 @@
// The default tag is special in that everything afterwards is the
// default regardless of the presence of commas.
s, i = tag[len("def="):], len(tag)
-
- // Defaults are parsed last, so Kind is populated.
- switch f.Kind {
- case pref.BoolKind:
- switch s {
- case "1":
- f.Default = pref.ValueOf(true)
- case "0":
- f.Default = pref.ValueOf(false)
- }
- case pref.EnumKind:
- n, _ := strconv.ParseInt(s, 10, 32)
- f.Default = pref.ValueOf(pref.EnumNumber(n))
- case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
- n, _ := strconv.ParseInt(s, 10, 32)
- f.Default = pref.ValueOf(int32(n))
- case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
- n, _ := strconv.ParseInt(s, 10, 64)
- f.Default = pref.ValueOf(int64(n))
- case pref.Uint32Kind, pref.Fixed32Kind:
- n, _ := strconv.ParseUint(s, 10, 32)
- f.Default = pref.ValueOf(uint32(n))
- case pref.Uint64Kind, pref.Fixed64Kind:
- n, _ := strconv.ParseUint(s, 10, 64)
- f.Default = pref.ValueOf(uint64(n))
- case pref.FloatKind, pref.DoubleKind:
- n, _ := strconv.ParseFloat(s, 64)
- switch s {
- case "nan":
- n = math.NaN()
- case "inf":
- n = math.Inf(+1)
- case "-inf":
- n = math.Inf(-1)
- }
- if f.Kind == pref.FloatKind {
- f.Default = pref.ValueOf(float32(n))
- } else {
- f.Default = pref.ValueOf(float64(n))
- }
- case pref.StringKind:
- f.Default = pref.ValueOf(string(s))
- case pref.BytesKind:
- // The default value is in escaped form (C-style).
- // TODO: Export unmarshalString in the text package to avoid this hack.
- v, err := ptext.Unmarshal([]byte(`["` + s + `"]:0`))
- if err == nil && len(v.Message()) == 1 {
- s := v.Message()[0][0].String()
- f.Default = pref.ValueOf([]byte(s))
- }
- }
+ f.Default, _ = defval.Unmarshal(s, f.Kind, defval.GoTag)
}
tag = strings.TrimPrefix(tag[i:], ",")
}
@@ -242,60 +190,7 @@
}
// This must appear last in the tag, since commas in strings aren't escaped.
if fd.HasDefault() {
- var def string
- switch fd.Kind() {
- case pref.BoolKind:
- if fd.Default().Bool() {
- def = "1"
- } else {
- def = "0"
- }
- case pref.BytesKind:
- // Preserve protoc-gen-go's historical output of escaped bytes.
- // This behavior is buggy, but fixing it makes it impossible to
- // distinguish between the escaped and unescaped forms.
- //
- // To match the exact output of protoc, this is identical to the
- // CEscape function in strutil.cc of the protoc source code.
- var b []byte
- for _, c := range fd.Default().Bytes() {
- switch c {
- case '\n':
- b = append(b, `\n`...)
- case '\r':
- b = append(b, `\r`...)
- case '\t':
- b = append(b, `\t`...)
- case '"':
- b = append(b, `\"`...)
- case '\'':
- b = append(b, `\'`...)
- case '\\':
- b = append(b, `\\`...)
- default:
- if c >= 0x20 && c <= 0x7e { // printable ASCII
- b = append(b, c)
- } else {
- b = append(b, fmt.Sprintf(`\%03o`, c)...)
- }
- }
- }
- def = string(b)
- case pref.FloatKind, pref.DoubleKind:
- f := fd.Default().Float()
- switch {
- case math.IsInf(f, -1):
- def = "-inf"
- case math.IsInf(f, 1):
- def = "inf"
- case math.IsNaN(f):
- def = "nan"
- default:
- def = fmt.Sprint(fd.Default().Interface())
- }
- default:
- def = fmt.Sprint(fd.Default().Interface())
- }
+ def, _ := defval.Marshal(fd.Default(), fd.Kind(), defval.GoTag)
tag = append(tag, "def="+def)
}
return strings.Join(tag, ",")