internal/encoding/defval: unify logic for handling default values

The logic for parsing and serializing default values in textual form is
duplicated in multiple places with only trivial differences. Centralize it.

Change-Id: I4408ddfeef2c0dfa5c7468e01a4d4df5654ae57f
Reviewed-on: https://go-review.googlesource.com/c/153022
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/internal/encoding/defval/default.go b/internal/encoding/defval/default.go
new file mode 100644
index 0000000..e2c62a9
--- /dev/null
+++ b/internal/encoding/defval/default.go
@@ -0,0 +1,209 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package defval marshals and unmarshals textual forms of default values.
+//
+// This package handles both the form historically used in Go struct field tags
+// and also the form used by google.protobuf.FieldDescriptorProto.default_value
+// since they differ in superficial ways.
+package defval
+
+import (
+	"fmt"
+	"math"
+	"strconv"
+
+	ptext "github.com/golang/protobuf/v2/internal/encoding/text"
+	errors "github.com/golang/protobuf/v2/internal/errors"
+	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// Format is the serialization format used to represent the default value.
+type Format int
+
+const (
+	_ Format = iota // zero is left unnamed: Descriptor == 1, GoTag == 2
+
+	// Descriptor uses the serialization format that protoc uses with the
+	// google.protobuf.FieldDescriptorProto.default_value field.
+	Descriptor
+
+	// GoTag uses the historical serialization format in Go struct field tags.
+	GoTag
+)
+
+// Unmarshal deserializes the default string s according to the given kind k,
+// returning an error if s is not accepted for k. Any Format other than GoTag is
+// handled as Descriptor. For an enum kind in the Descriptor format, a Value
+// holding the enum identifier as a string is returned; the caller must verify it.
+func Unmarshal(s string, k pref.Kind, f Format) (pref.Value, error) {
+	switch k {
+	case pref.BoolKind:
+		if f == GoTag { // Go tags spell booleans as "1" and "0"
+			switch s {
+			case "1":
+				return pref.ValueOf(true), nil
+			case "0":
+				return pref.ValueOf(false), nil
+			}
+		} else {
+			switch s {
+			case "true":
+				return pref.ValueOf(true), nil
+			case "false":
+				return pref.ValueOf(false), nil
+			}
+		}
+	case pref.EnumKind:
+		if f == GoTag {
+			// Go tags used the numeric form of the enum value.
+			if n, err := strconv.ParseInt(s, 10, 32); err == nil {
+				return pref.ValueOf(pref.EnumNumber(n)), nil
+			}
+		} else {
+			// Descriptor default_value used the enum identifier.
+			if pref.Name(s).IsValid() {
+				return pref.ValueOf(s), nil
+			}
+		}
+	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
+		if v, err := strconv.ParseInt(s, 10, 32); err == nil { // bitSize 32: out-of-range input is an error
+			return pref.ValueOf(int32(v)), nil
+		}
+	case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
+		if v, err := strconv.ParseInt(s, 10, 64); err == nil {
+			return pref.ValueOf(int64(v)), nil
+		}
+	case pref.Uint32Kind, pref.Fixed32Kind:
+		if v, err := strconv.ParseUint(s, 10, 32); err == nil {
+			return pref.ValueOf(uint32(v)), nil
+		}
+	case pref.Uint64Kind, pref.Fixed64Kind:
+		if v, err := strconv.ParseUint(s, 10, 64); err == nil {
+			return pref.ValueOf(uint64(v)), nil
+		}
+	case pref.FloatKind, pref.DoubleKind:
+		var v float64
+		var err error
+		switch s { // the exact lower-case special spellings are mapped without calling strconv
+		case "-inf":
+			v = math.Inf(-1)
+		case "inf":
+			v = math.Inf(+1)
+		case "nan":
+			v = math.NaN()
+		default:
+			v, err = strconv.ParseFloat(s, 64) // always parsed at 64-bit precision, even for FloatKind
+		}
+		if err == nil {
+			if k == pref.FloatKind {
+				return pref.ValueOf(float32(v)), nil // narrowed to float32 only after parsing
+			} else {
+				return pref.ValueOf(float64(v)), nil
+			}
+		}
+	case pref.StringKind:
+		// String values are already unescaped and can be used as is.
+		return pref.ValueOf(s), nil
+	case pref.BytesKind:
+		if b, ok := unmarshalBytes(s); ok {
+			return pref.ValueOf(b), nil
+		}
+	}
+	return pref.Value{}, errors.New("invalid default value for %v: %q", k, s) // reached whenever no case above returned, including kinds without a case
+}
+
+// Marshal serializes v as the default string according to the given kind k.
+// Booleans are "1" or "0" for GoTag and "true" or "false" for any other format.
+// Enums are serialized in numeric form regardless of format chosen, which is
+// not the identifier form that Unmarshal accepts for the Descriptor format.
+// An error is returned for kinds not handled below; v is assumed to match k.
+func Marshal(v pref.Value, k pref.Kind, f Format) (string, error) {
+	switch k {
+	case pref.BoolKind:
+		if f == GoTag {
+			if v.Bool() {
+				return "1", nil
+			}
+			return "0", nil
+		}
+		if v.Bool() {
+			return "true", nil
+		}
+		return "false", nil
+	case pref.EnumKind:
+		return strconv.FormatInt(int64(v.Enum()), 10), nil
+	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind, pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
+		return strconv.FormatInt(v.Int(), 10), nil
+	case pref.Uint32Kind, pref.Fixed32Kind, pref.Uint64Kind, pref.Fixed64Kind:
+		return strconv.FormatUint(v.Uint(), 10), nil
+	case pref.FloatKind, pref.DoubleKind:
+		x := v.Float() // named x so that the Format parameter f is not shadowed
+		switch {
+		case math.IsInf(x, -1):
+			return "-inf", nil
+		case math.IsInf(x, +1):
+			return "inf", nil
+		case math.IsNaN(x):
+			return "nan", nil
+		default:
+			// Precision -1 emits the fewest digits that identify the value uniquely at the given bit size.
+			if k == pref.FloatKind {
+				return strconv.FormatFloat(x, 'g', -1, 32), nil
+			}
+			return strconv.FormatFloat(x, 'g', -1, 64), nil
+		}
+	case pref.StringKind:
+		// String values are serialized as is without any escaping.
+		return v.String(), nil
+	case pref.BytesKind:
+		if s, ok := marshalBytes(v.Bytes()); ok {
+			return s, nil
+		}
+	}
+	return "", errors.New("invalid default value for %v: %v", k, v)
+}
+
+// unmarshalBytes deserializes bytes by applying C unescaping; it reports false if s does not parse.
+func unmarshalBytes(s string) ([]byte, bool) {
+	// Bytes values use the same escaping as the text format but lack the
+	// surrounding double quotes, so s is wrapped in quotes before parsing.
+	// TODO: Export unmarshalString in the text package to avoid this hack.
+	v, err := ptext.Unmarshal([]byte(`["` + s + `"]:0`)) // NOTE(review): s is spliced in verbatim; a raw double quote in s alters the text being parsed
+	if err == nil && len(v.Message()) == 1 {
+		s := v.Message()[0][0].String() // presumably the unescaped key of the sole entry; confirm against the text package
+		return []byte(s), true
+	}
+	return nil, false
+}
+
+// marshalBytes serializes bytes by using C escaping; it always reports true.
+// To match the exact output of protoc, this is identical to the
+// CEscape function in strutil.cc of the protoc source code.
+func marshalBytes(b []byte) (string, bool) {
+	var s []byte
+	for _, c := range b {
+		switch c {
+		case '\n':
+			s = append(s, `\n`...)
+		case '\r':
+			s = append(s, `\r`...)
+		case '\t':
+			s = append(s, `\t`...)
+		case '"':
+			s = append(s, `\"`...)
+		case '\'':
+			s = append(s, `\'`...)
+		case '\\':
+			s = append(s, `\\`...)
+		default:
+			if printableASCII := c >= 0x20 && c <= 0x7e; printableASCII {
+				s = append(s, c) // printable bytes are copied through unchanged
+			} else {
+				s = append(s, fmt.Sprintf(`\%03o`, c)...) // every other byte becomes a three-digit octal escape
+			}
+		}
+	}
+	return string(s), true
+}
diff --git a/internal/encoding/defval/default_test.go b/internal/encoding/defval/default_test.go
new file mode 100644
index 0000000..e91eb7e
--- /dev/null
+++ b/internal/encoding/defval/default_test.go
@@ -0,0 +1,54 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package defval
+
+import (
+	"math"
+	"reflect"
+	"testing"
+
+	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// Test round-trips each table entry through Marshal and Unmarshal in both formats.
+func Test(t *testing.T) {
+	V := pref.ValueOf
+	tests := []struct {
+		val   pref.Value
+		kind  pref.Kind
+		strPB string // expected Descriptor form
+		strGo string // expected GoTag form
+	}{
+		{V(bool(true)), pref.BoolKind, "true", "1"},
+		{V(int32(-0x1234)), pref.Int32Kind, "-4660", "-4660"},
+		{V(float32(math.Pi)), pref.FloatKind, "3.1415927", "3.1415927"},
+		{V(float64(math.Pi)), pref.DoubleKind, "3.141592653589793", "3.141592653589793"},
+		{V(string("hello, \xde\xad\xbe\xef\n")), pref.StringKind, "hello, \xde\xad\xbe\xef\n", "hello, \xde\xad\xbe\xef\n"},
+		{V([]byte("hello, \xde\xad\xbe\xef\n")), pref.BytesKind, "hello, \\336\\255\\276\\357\\n", "hello, \\336\\255\\276\\357\\n"},
+	}
+	for _, tt := range tests {
+		t.Run("", func(t *testing.T) {
+			gotStrPB, err := Marshal(tt.val, tt.kind, Descriptor)
+			if err != nil || gotStrPB != tt.strPB {
+				t.Errorf("Marshal(%v, %v, Descriptor) = (%q, %v), want (%q, <nil>)", tt.val, tt.kind, gotStrPB, err, tt.strPB)
+			}
+
+			gotStrGo, err := Marshal(tt.val, tt.kind, GoTag)
+			if err != nil || gotStrGo != tt.strGo {
+				t.Errorf("Marshal(%v, %v, GoTag) = (%q, %v), want (%q, <nil>)", tt.val, tt.kind, gotStrGo, err, tt.strGo)
+			}
+
+			gotValPB, err := Unmarshal(tt.strPB, tt.kind, Descriptor)
+			if err != nil || !reflect.DeepEqual(gotValPB.Interface(), tt.val.Interface()) {
+				t.Errorf("Unmarshal(%q, %v, Descriptor) = (%q, %v), want (%q, <nil>)", tt.strPB, tt.kind, gotValPB, err, tt.val)
+			}
+
+			gotValGo, err := Unmarshal(tt.strGo, tt.kind, GoTag)
+			if err != nil || !reflect.DeepEqual(gotValGo.Interface(), tt.val.Interface()) {
+				t.Errorf("Unmarshal(%q, %v, GoTag) = (%q, %v), want (%q, <nil>)", tt.strGo, tt.kind, gotValGo, err, tt.val)
+			}
+		})
+	}
+}
diff --git a/internal/encoding/tag/tag.go b/internal/encoding/tag/tag.go
index 17594d2..3afedb8 100644
--- a/internal/encoding/tag/tag.go
+++ b/internal/encoding/tag/tag.go
@@ -7,13 +7,11 @@
 package tag
 
 import (
-	"fmt"
-	"math"
 	"reflect"
 	"strconv"
 	"strings"
 
-	ptext "github.com/golang/protobuf/v2/internal/encoding/text"
+	defval "github.com/golang/protobuf/v2/internal/encoding/defval"
 	scalar "github.com/golang/protobuf/v2/internal/scalar"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
 	ptype "github.com/golang/protobuf/v2/reflect/prototype"
@@ -115,57 +113,7 @@
 			// The default tag is special in that everything afterwards is the
 			// default regardless of the presence of commas.
 			s, i = tag[len("def="):], len(tag)
-
-			// Defaults are parsed last, so Kind is populated.
-			switch f.Kind {
-			case pref.BoolKind:
-				switch s {
-				case "1":
-					f.Default = pref.ValueOf(true)
-				case "0":
-					f.Default = pref.ValueOf(false)
-				}
-			case pref.EnumKind:
-				n, _ := strconv.ParseInt(s, 10, 32)
-				f.Default = pref.ValueOf(pref.EnumNumber(n))
-			case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
-				n, _ := strconv.ParseInt(s, 10, 32)
-				f.Default = pref.ValueOf(int32(n))
-			case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
-				n, _ := strconv.ParseInt(s, 10, 64)
-				f.Default = pref.ValueOf(int64(n))
-			case pref.Uint32Kind, pref.Fixed32Kind:
-				n, _ := strconv.ParseUint(s, 10, 32)
-				f.Default = pref.ValueOf(uint32(n))
-			case pref.Uint64Kind, pref.Fixed64Kind:
-				n, _ := strconv.ParseUint(s, 10, 64)
-				f.Default = pref.ValueOf(uint64(n))
-			case pref.FloatKind, pref.DoubleKind:
-				n, _ := strconv.ParseFloat(s, 64)
-				switch s {
-				case "nan":
-					n = math.NaN()
-				case "inf":
-					n = math.Inf(+1)
-				case "-inf":
-					n = math.Inf(-1)
-				}
-				if f.Kind == pref.FloatKind {
-					f.Default = pref.ValueOf(float32(n))
-				} else {
-					f.Default = pref.ValueOf(float64(n))
-				}
-			case pref.StringKind:
-				f.Default = pref.ValueOf(string(s))
-			case pref.BytesKind:
-				// The default value is in escaped form (C-style).
-				// TODO: Export unmarshalString in the text package to avoid this hack.
-				v, err := ptext.Unmarshal([]byte(`["` + s + `"]:0`))
-				if err == nil && len(v.Message()) == 1 {
-					s := v.Message()[0][0].String()
-					f.Default = pref.ValueOf([]byte(s))
-				}
-			}
+			f.Default, _ = defval.Unmarshal(s, f.Kind, defval.GoTag)
 		}
 		tag = strings.TrimPrefix(tag[i:], ",")
 	}
@@ -242,60 +190,7 @@
 	}
 	// This must appear last in the tag, since commas in strings aren't escaped.
 	if fd.HasDefault() {
-		var def string
-		switch fd.Kind() {
-		case pref.BoolKind:
-			if fd.Default().Bool() {
-				def = "1"
-			} else {
-				def = "0"
-			}
-		case pref.BytesKind:
-			// Preserve protoc-gen-go's historical output of escaped bytes.
-			// This behavior is buggy, but fixing it makes it impossible to
-			// distinguish between the escaped and unescaped forms.
-			//
-			// To match the exact output of protoc, this is identical to the
-			// CEscape function in strutil.cc of the protoc source code.
-			var b []byte
-			for _, c := range fd.Default().Bytes() {
-				switch c {
-				case '\n':
-					b = append(b, `\n`...)
-				case '\r':
-					b = append(b, `\r`...)
-				case '\t':
-					b = append(b, `\t`...)
-				case '"':
-					b = append(b, `\"`...)
-				case '\'':
-					b = append(b, `\'`...)
-				case '\\':
-					b = append(b, `\\`...)
-				default:
-					if c >= 0x20 && c <= 0x7e { // printable ASCII
-						b = append(b, c)
-					} else {
-						b = append(b, fmt.Sprintf(`\%03o`, c)...)
-					}
-				}
-			}
-			def = string(b)
-		case pref.FloatKind, pref.DoubleKind:
-			f := fd.Default().Float()
-			switch {
-			case math.IsInf(f, -1):
-				def = "-inf"
-			case math.IsInf(f, 1):
-				def = "inf"
-			case math.IsNaN(f):
-				def = "nan"
-			default:
-				def = fmt.Sprint(fd.Default().Interface())
-			}
-		default:
-			def = fmt.Sprint(fd.Default().Interface())
-		}
+		def, _ := defval.Marshal(fd.Default(), fd.Kind(), defval.GoTag)
 		tag = append(tag, "def="+def)
 	}
 	return strings.Join(tag, ",")