goprotobuf: Implement SetDefaults function.
R=r
CC=golang-dev
http://codereview.appspot.com/4972046
diff --git a/proto/all_test.go b/proto/all_test.go
index 256ab4c..0930d08 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -35,6 +35,7 @@
import (
"bytes"
"fmt"
+ "math"
"os"
"reflect"
"strings"
@@ -1221,6 +1222,68 @@
}
}
+func TestAllSetDefaults(t *testing.T) {
+	// Exercise SetDefaults with all scalar field types.
+	m := &Defaults{
+		// NaN != NaN, so override that here.
+		F_Nan: Float32(1.7),
+	}
+	expected := &Defaults{
+		F_Bool:    Bool(true),
+		F_Int32:   Int32(32),
+		F_Int64:   Int64(64),
+		F_Fixed32: Uint32(320),
+		F_Fixed64: Uint64(640),
+		F_Uint32:  Uint32(3200),
+		F_Uint64:  Uint64(6400),
+		F_Float:   Float32(314159),
+		F_Double:  Float64(271828),
+		F_String:  String(`hello, "world!"` + "\n"),
+		F_Bytes:   []byte("Bignose"),
+		F_Sint32:  Int32(-32),
+		F_Sint64:  Int64(-64),
+		F_Enum:    NewDefaults_Color(Defaults_GREEN),
+		F_Pinf:    Float32(float32(math.Inf(1))),
+		F_Ninf:    Float32(float32(math.Inf(-1))),
+		F_Nan:     Float32(1.7),
+	}
+	SetDefaults(m)
+	if !Equal(m, expected) {
+		t.Errorf(" got %v\nwant %v", m, expected)
+	}
+
+	// Equal cannot compare NaNs, so the F_Nan default is checked on its own.
+	m = &Defaults{}
+	SetDefaults(m)
+	if m.F_Nan == nil {
+		t.Errorf("m.F_Nan unset, want NaN")
+	} else if v := *m.F_Nan; !math.IsNaN(float64(v)) {
+		t.Errorf("m.F_Nan = %v, want NaN", v)
+	}
+}
+
+func TestSetDefaultsWithSetField(t *testing.T) {
+	// Check that a set value is not overridden.
+	m := &Defaults{
+		F_Int32: Int32(12),
+	}
+	SetDefaults(m)
+	if v := GetInt32(m.F_Int32); v != 12 {
+		t.Errorf("m.F_Int32 = %v, want 12", v)
+	}
+	// Fields left unset in the same message must still receive their defaults.
+	if m.F_Int64 == nil {
+		t.Errorf("m.F_Int64 unset, want 64")
+	} else if v := *m.F_Int64; v != 64 {
+		t.Errorf("m.F_Int64 = %v, want 64", v)
+	}
+}
+
+func TestSetDefaultsWithSubMessage(t *testing.T) {
+	// A present sub-message gets its own defaults (Port) filled in, while
+	// fields already set at either level are kept as they are.
+	got := &OtherMessage{Key: Int64(123), Inner: &InnerMessage{Host: String("gopher")}}
+	want := &OtherMessage{Key: Int64(123), Inner: &InnerMessage{Host: String("gopher"), Port: Int32(4000)}}
+
+	SetDefaults(got)
+
+	if Equal(got, want) {
+		return
+	}
+	t.Errorf(" got %v\nwant %v", got, want)
+}
+
func BenchmarkMarshal(b *testing.B) {
b.StopTimer()
diff --git a/proto/equal.go b/proto/equal.go
index bbcca82..3a12a7f 100644
--- a/proto/equal.go
+++ b/proto/equal.go
@@ -51,6 +51,8 @@
corresponding fields are equal, unknown field sets
are equal, and extensions sets are equal.
- Two set scalar fields are equal iff their values are equal.
+ If the fields are of a floating-point type, remember that
+ NaN != x for all x, including NaN.
- Two repeated fields are equal iff their lengths are the same,
and their corresponding elements are equal.
- Two unset fields are equal.
diff --git a/proto/lib.go b/proto/lib.go
index f794f53..b816e91 100644
--- a/proto/lib.go
+++ b/proto/lib.go
@@ -166,7 +166,10 @@
import (
"fmt"
+ "log"
+ "reflect"
"strconv"
+ "sync"
)
// Stats records allocation details about the protocol buffer encoders
@@ -535,3 +538,233 @@
o.buf = obuf
o.index = index
}
+
+// SetDefaults sets unset protocol buffer fields to their default values.
+// It only modifies fields that are both unset and have defined defaults.
+// It recursively sets default values in any non-nil sub-messages.
+//
+// pb must be a non-nil pointer to a struct. Any other value is logged
+// and otherwise ignored instead of causing a reflection panic.
+func SetDefaults(pb interface{}) {
+	v := reflect.ValueOf(pb)
+	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
+		log.Printf("proto: hit non-pointer-to-struct %v", v)
+		// setDefaults dereferences v unconditionally, so stop here.
+		return
+	}
+	setDefaults(v, true, false)
+}
+
+// setDefaults fills in the unset fields of the struct that v points to.
+// v must be a non-nil pointer to a struct.
+//
+// A field counts as unset when its pointer (or []byte slice) is nil.
+// Unset fields with a proto-declared default receive that default. If zeros
+// is true, unset fields without a declared default receive the zero value of
+// their type; otherwise they stay unset. If recur is true, the same treatment
+// is applied to every non-nil nested message; nil nested messages stay nil.
+func setDefaults(v reflect.Value, recur, zeros bool) {
+	v = v.Elem()
+
+	// Fetch the per-type default information, building it on first use.
+	// Two goroutines may both miss and build the entry concurrently; since
+	// buildDefaultMessage depends only on the type, the duplicate is harmless.
+	defaultMu.Lock()
+	dm, ok := defaults[v.Type()]
+	defaultMu.Unlock()
+	if !ok {
+		dm = buildDefaultMessage(v.Type())
+		defaultMu.Lock()
+		defaults[v.Type()] = dm
+		defaultMu.Unlock()
+	}
+
+	for _, sf := range dm.scalars {
+		f := v.Field(sf.index)
+		if !f.IsNil() {
+			// field already set
+			continue
+		}
+		dv := sf.value
+		if dv == nil && !zeros {
+			// no explicit default, and don't want to set zeros
+			continue
+		}
+		fptr := f.Addr().Interface() // **T, or *[]byte for bytes fields
+		// TODO: Consider batching the allocations we do here.
+		switch sf.kind {
+		case reflect.Bool:
+			b := new(bool)
+			if dv != nil {
+				*b = dv.(bool)
+			}
+			*(fptr.(**bool)) = b
+		case reflect.Float32:
+			x := new(float32)
+			if dv != nil {
+				*x = dv.(float32)
+			}
+			*(fptr.(**float32)) = x
+		case reflect.Float64:
+			x := new(float64)
+			if dv != nil {
+				*x = dv.(float64)
+			}
+			*(fptr.(**float64)) = x
+		case reflect.Int32:
+			// Fields of kind Int32 whose type is not *int32 are treated as
+			// enums (presumably pointers to named int32 types), which cannot
+			// be asserted to **int32 and are set through reflection instead.
+			if ft := f.Type(); ft != int32PtrType {
+				// enum
+				f.Set(reflect.New(ft.Elem()))
+				if dv != nil {
+					f.Elem().SetInt(int64(dv.(int32)))
+				}
+			} else {
+				// int32 field
+				i := new(int32)
+				if dv != nil {
+					*i = dv.(int32)
+				}
+				*(fptr.(**int32)) = i
+			}
+		case reflect.Int64:
+			i := new(int64)
+			if dv != nil {
+				*i = dv.(int64)
+			}
+			*(fptr.(**int64)) = i
+		case reflect.String:
+			s := new(string)
+			if dv != nil {
+				*s = dv.(string)
+			}
+			*(fptr.(**string)) = s
+		case reflect.Uint8:
+			// exceptional case: []byte
+			// The default lives in the shared cache, so hand out a copy
+			// that the caller may modify freely.
+			var b []byte
+			if dv != nil {
+				db := dv.([]byte)
+				b = make([]byte, len(db))
+				copy(b, db)
+			} else {
+				b = []byte{}
+			}
+			*(fptr.(*[]byte)) = b
+		case reflect.Uint32:
+			u := new(uint32)
+			if dv != nil {
+				*u = dv.(uint32)
+			}
+			*(fptr.(**uint32)) = u
+		case reflect.Uint64:
+			u := new(uint64)
+			if dv != nil {
+				*u = dv.(uint64)
+			}
+			*(fptr.(**uint64)) = u
+		default:
+			log.Printf("proto: can't set default for field %v (sf.kind=%v)", f, sf.kind)
+		}
+	}
+
+	if !recur {
+		return
+	}
+	for _, ni := range dm.nested {
+		// A nil sub-message has no struct to fill in; recursing into it
+		// would make v.Elem() invalid and panic in reflect, so leave it nil.
+		if f := v.Field(ni); !f.IsNil() {
+			setDefaults(f, recur, zeros)
+		}
+	}
+}
+
+var (
+	// defaults caches, per protocol buffer struct type, the defaultMessage
+	// that buildDefaultMessage computed for it. defaultMu guards it.
+	defaultMu sync.Mutex
+	defaults  = make(map[reflect.Type]defaultMessage)
+
+	// int32PtrType is *int32; setDefaults compares against it to tell plain
+	// int32 fields apart from other fields whose element kind is Int32.
+	int32PtrType = reflect.TypeOf((*int32)(nil))
+)
+
+// defaultMessage and scalarField together describe, for one message struct
+// type, everything needed to fill in its default values.
+type (
+	// defaultMessage lists a message's default-relevant fields.
+	defaultMessage struct {
+		scalars []scalarField // non-message fields
+		nested  []int         // struct field index of nested messages
+	}
+
+	// scalarField describes one non-message field of a message struct.
+	scalarField struct {
+		index int          // struct field index
+		kind  reflect.Kind // element type (the T in *T or []T)
+		value interface{}  // the proto-declared default value, or nil
+	}
+)
+
+// buildDefaultMessage computes the defaultMessage for the struct type t.
+//
+// For each field described by GetProperties(t): a field whose type is a
+// pointer to a struct is recorded in nested. Every other field (including
+// repeated fields) is recorded in scalars, with kind set to its element
+// kind. If the field's properties declare a default, the parsed value is
+// stored in scalarField.value; otherwise value stays nil. A default that
+// fails to parse, or does not fit the field's type, is logged and that
+// field is left out of scalars entirely.
+func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
+	sprop := GetProperties(t)
+	for _, prop := range sprop.Prop {
+		fi := sprop.tags[prop.Tag]
+		ft := t.Field(fi).Type
+
+		// nested messages
+		if ft.Kind() == reflect.Ptr && ft.Elem().Kind() == reflect.Struct {
+			dm.nested = append(dm.nested, fi)
+			continue
+		}
+
+		kind := ft.Elem().Kind()
+		sf := scalarField{
+			index: fi,
+			kind:  kind,
+		}
+
+		// scalar fields without defaults
+		if prop.Default == "" {
+			dm.scalars = append(dm.scalars, sf)
+			continue
+		}
+
+		// a scalar field: either *T or []byte
+		switch kind {
+		case reflect.Bool:
+			x, err := strconv.Atob(prop.Default)
+			if err != nil {
+				log.Printf("proto: bad default bool %q: %v", prop.Default, err)
+				continue
+			}
+			sf.value = x
+		case reflect.Float32:
+			x, err := strconv.Atof32(prop.Default)
+			if err != nil {
+				log.Printf("proto: bad default float32 %q: %v", prop.Default, err)
+				continue
+			}
+			sf.value = x
+		case reflect.Float64:
+			x, err := strconv.Atof64(prop.Default)
+			if err != nil {
+				log.Printf("proto: bad default float64 %q: %v", prop.Default, err)
+				continue
+			}
+			sf.value = x
+		case reflect.Int32:
+			x, err := strconv.Atoi64(prop.Default)
+			if err != nil {
+				log.Printf("proto: bad default int32 %q: %v", prop.Default, err)
+				continue
+			}
+			// Parsed as 64 bits; reject values that would wrap when narrowed.
+			if int64(int32(x)) != x {
+				log.Printf("proto: bad default int32 %q: out of range", prop.Default)
+				continue
+			}
+			sf.value = int32(x)
+		case reflect.Int64:
+			x, err := strconv.Atoi64(prop.Default)
+			if err != nil {
+				log.Printf("proto: bad default int64 %q: %v", prop.Default, err)
+				continue
+			}
+			sf.value = x
+		case reflect.String:
+			sf.value = prop.Default
+		case reflect.Uint8:
+			// []byte (not *uint8)
+			sf.value = []byte(prop.Default)
+		case reflect.Uint32:
+			x, err := strconv.Atoui64(prop.Default)
+			if err != nil {
+				log.Printf("proto: bad default uint32 %q: %v", prop.Default, err)
+				continue
+			}
+			// Parsed as 64 bits; reject values that would wrap when narrowed.
+			if uint64(uint32(x)) != x {
+				log.Printf("proto: bad default uint32 %q: out of range", prop.Default)
+				continue
+			}
+			sf.value = uint32(x)
+		case reflect.Uint64:
+			x, err := strconv.Atoui64(prop.Default)
+			if err != nil {
+				log.Printf("proto: bad default uint64 %q: %v", prop.Default, err)
+				continue
+			}
+			sf.value = x
+		default:
+			log.Printf("proto: unhandled def kind %v", kind)
+			continue
+		}
+
+		dm.scalars = append(dm.scalars, sf)
+	}
+
+	return dm
+}
diff --git a/proto/testdata/test.proto b/proto/testdata/test.proto
index 70ff890..12f2dec 100644
--- a/proto/testdata/test.proto
+++ b/proto/testdata/test.proto
@@ -242,5 +242,35 @@
repeated group Message = 1 {
required string name = 2;
required int32 count = 3;
- };
-};
+ }
+}
+
+// Defaults declares an explicit default on every field. all_test.go uses it
+// to check that SetDefaults fills in each scalar type (and an enum) correctly.
+message Defaults {
+ enum Color {
+ RED = 0;
+ GREEN = 1;
+ BLUE = 2;
+ }
+
+ // Default-valued fields of all basic types.
+ // Same as GoTest, but copied here to make testing easier.
+ optional bool F_Bool = 1 [default=true];
+ optional int32 F_Int32 = 2 [default=32];
+ optional int64 F_Int64 = 3 [default=64];
+ optional fixed32 F_Fixed32 = 4 [default=320];
+ optional fixed64 F_Fixed64 = 5 [default=640];
+ optional uint32 F_Uint32 = 6 [default=3200];
+ optional uint64 F_Uint64 = 7 [default=6400];
+ optional float F_Float = 8 [default=314159.];
+ optional double F_Double = 9 [default=271828.];
+ optional string F_String = 10 [default="hello, \"world!\"\n"];
+ optional bytes F_Bytes = 11 [default="Bignose"];
+ optional sint32 F_Sint32 = 12 [default=-32];
+ optional sint64 F_Sint64 = 13 [default=-64];
+ optional Color F_Enum = 14 [default=GREEN];
+
+ // More fields with crazy defaults.
+ // Note: NaN never compares equal to itself, so checks on F_Nan cannot
+ // rely on equality with an expected message.
+ optional float F_Pinf = 15 [default=inf];
+ optional float F_Ninf = 16 [default=-inf];
+ optional float F_Nan = 17 [default=nan];
+}