goprotobuf: Implement SetDefaults function.

R=r
CC=golang-dev
http://codereview.appspot.com/4972046
diff --git a/proto/all_test.go b/proto/all_test.go
index 256ab4c..0930d08 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -35,6 +35,7 @@
 import (
 	"bytes"
 	"fmt"
+	"math"
 	"os"
 	"reflect"
 	"strings"
@@ -1221,6 +1222,68 @@
 	}
 }
 
+// TestAllSetDefaults exercises SetDefaults with all scalar field types.
+func TestAllSetDefaults(t *testing.T) {
+	want := &Defaults{
+		F_Bool:    Bool(true),
+		F_Int32:   Int32(32),
+		F_Int64:   Int64(64),
+		F_Fixed32: Uint32(320),
+		F_Fixed64: Uint64(640),
+		F_Uint32:  Uint32(3200),
+		F_Uint64:  Uint64(6400),
+		F_Float:   Float32(314159),
+		F_Double:  Float64(271828),
+		F_String:  String(`hello, "world!"` + "\n"),
+		F_Bytes:   []byte("Bignose"),
+		F_Sint32:  Int32(-32),
+		F_Sint64:  Int64(-64),
+		F_Enum:    NewDefaults_Color(Defaults_GREEN),
+		F_Pinf:    Float32(float32(math.Inf(1))),
+		F_Ninf:    Float32(float32(math.Inf(-1))),
+		F_Nan:     Float32(1.7),
+	}
+	// NaN != NaN, so F_Nan is set up front rather than left to its NaN default.
+	got := &Defaults{
+		F_Nan: Float32(1.7),
+	}
+	SetDefaults(got)
+	if !Equal(got, want) {
+		t.Errorf(" got %v\nwant %v", got, want)
+	}
+}
+
+// TestSetDefaultsWithSetField verifies SetDefaults leaves an already-set field alone.
+func TestSetDefaultsWithSetField(t *testing.T) {
+	m := &Defaults{F_Int32: Int32(12)}
+	SetDefaults(m)
+	// F_Int32 declares a default of 32; the explicit 12 must survive.
+	got := GetInt32(m.F_Int32)
+	if got != 12 {
+		t.Errorf("m.FInt32 = %v, want 12", got)
+	}
+}
+
+// TestSetDefaultsWithSubMessage checks that SetDefaults descends into a set sub-message.
+func TestSetDefaultsWithSubMessage(t *testing.T) {
+	got := &OtherMessage{
+		Key:   Int64(123),
+		Inner: &InnerMessage{Host: String("gopher")},
+	}
+	// Only Inner.Port is expected to be filled in by SetDefaults.
+	want := &OtherMessage{
+		Key: Int64(123),
+		Inner: &InnerMessage{
+			Host: String("gopher"),
+			Port: Int32(4000),
+		},
+	}
+	SetDefaults(got)
+	if !Equal(got, want) {
+		t.Errorf(" got %v\nwant %v", got, want)
+	}
+}
+
 func BenchmarkMarshal(b *testing.B) {
 	b.StopTimer()
 
diff --git a/proto/equal.go b/proto/equal.go
index bbcca82..3a12a7f 100644
--- a/proto/equal.go
+++ b/proto/equal.go
@@ -51,6 +51,8 @@
     corresponding fields are equal, unknown field sets
     are equal, and extensions sets are equal.
   - Two set scalar fields are equal iff their values are equal.
+    If the fields are of a floating-point type, remember that
+    NaN != x for all x, including NaN.
   - Two repeated fields are equal iff their lengths are the same,
     and their corresponding elements are equal.
   - Two unset fields are equal.
diff --git a/proto/lib.go b/proto/lib.go
index f794f53..b816e91 100644
--- a/proto/lib.go
+++ b/proto/lib.go
@@ -166,7 +166,10 @@
 
 import (
 	"fmt"
+	"log"
+	"reflect"
 	"strconv"
+	"sync"
 )
 
 // Stats records allocation details about the protocol buffer encoders
@@ -535,3 +538,233 @@
 	o.buf = obuf
 	o.index = index
 }
+
+// SetDefaults sets unset protocol buffer fields that have defined defaults to
+// those defaults, recursing into any non-nil sub-messages. pb must be a non-nil
+// pointer to a struct; any other value is logged and otherwise ignored.
+func SetDefaults(pb interface{}) {
+	if v := reflect.ValueOf(pb); v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct {
+		setDefaults(v, true, false)
+	} else {
+		log.Printf("proto: hit non-pointer-to-struct %v", v)
+	}
+}
+
+// setDefaults sets the unset scalar fields of the struct v points to, then does
+// the same for each non-nil nested message. If zeros is true, fields with no
+// declared default are set too, to zero (or empty, for []byte). recur is only
+// passed along to the nested calls. v is a non-nil pointer to a struct.
+func setDefaults(v reflect.Value, recur, zeros bool) {
+	v = v.Elem()
+	defaultMu.Lock()
+	dm, ok := defaults[v.Type()]
+	defaultMu.Unlock()
+	if !ok {
+		dm = buildDefaultMessage(v.Type())
+		defaultMu.Lock()
+		defaults[v.Type()] = dm
+		defaultMu.Unlock()
+	}
+
+	for _, sf := range dm.scalars {
+		f := v.Field(sf.index)
+		if !f.IsNil() {
+			// field already set
+			continue
+		}
+		dv := sf.value
+		if dv == nil && !zeros {
+			// no explicit default, and don't want to set zeros
+			continue
+		}
+		fptr := f.Addr().Interface() // **T
+		// TODO: Consider batching the allocations we do here.
+		switch sf.kind {
+		case reflect.Bool:
+			b := new(bool)
+			if dv != nil {
+				*b = dv.(bool)
+			}
+			*(fptr.(**bool)) = b
+		case reflect.Float32:
+			x := new(float32)
+			if dv != nil {
+				*x = dv.(float32)
+			}
+			*(fptr.(**float32)) = x
+		case reflect.Float64:
+			x := new(float64)
+			if dv != nil {
+				*x = dv.(float64)
+			}
+			*(fptr.(**float64)) = x
+		case reflect.Int32:
+			// might be an enum
+			if ft := f.Type(); ft != int32PtrType {
+				// enum
+				f.Set(reflect.New(ft.Elem()))
+				if dv != nil {
+					f.Elem().SetInt(int64(dv.(int32)))
+				}
+			} else {
+				// int32 field
+				i := new(int32)
+				if dv != nil {
+					*i = dv.(int32)
+				}
+				*(fptr.(**int32)) = i
+			}
+		case reflect.Int64:
+			i := new(int64)
+			if dv != nil {
+				*i = dv.(int64)
+			}
+			*(fptr.(**int64)) = i
+		case reflect.String:
+			s := new(string)
+			if dv != nil {
+				*s = dv.(string)
+			}
+			*(fptr.(**string)) = s
+		case reflect.Uint8:
+			// exceptional case: []byte; copied so the cached default is never aliased
+			b := []byte{}
+			if dv != nil {
+				b = append(b, dv.([]byte)...)
+			}
+			*(fptr.(*[]byte)) = b
+		case reflect.Uint32:
+			u := new(uint32)
+			if dv != nil {
+				*u = dv.(uint32)
+			}
+			*(fptr.(**uint32)) = u
+		case reflect.Uint64:
+			u := new(uint64)
+			if dv != nil {
+				*u = dv.(uint64)
+			}
+			*(fptr.(**uint64)) = u
+		default:
+			log.Printf("proto: can't set default for field %v (sf.kind=%v)", f, sf.kind)
+		}
+	}
+
+	for _, ni := range dm.nested {
+		if sub := v.Field(ni); !sub.IsNil() { // a nil sub-message stays nil
+			setDefaults(sub, recur, zeros)
+		}
+	}
+}
+
+var (
+	// defaultMu guards defaults, which caches, per protocol buffer struct type,
+	// the defaultMessage that buildDefaultMessage computed for it.
+	defaultMu sync.Mutex
+	defaults  = make(map[reflect.Type]defaultMessage)
+
+	int32PtrType = reflect.TypeOf((*int32)(nil)) // tells plain *int32 fields apart from enum pointers
+)
+
+// defaultMessage records, for one message struct type, its scalar fields and nested-message fields.
+type defaultMessage struct {
+	scalars []scalarField
+	nested  []int // struct field index of nested messages
+}
+
+type scalarField struct {
+	index int          // struct field index
+	kind  reflect.Kind // kind of the element type (the T in *T or []T)
+	value interface{}  // parsed proto-declared default ([]byte for kind Uint8), or nil if none declared
+}
+
+// buildDefaultMessage inspects the fields of struct type t and returns the indexes
+// of its nested messages together with a scalarField for each remaining field,
+// carrying the field's parsed default when the proto declares one.
+func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
+	props := GetProperties(t)
+	for _, p := range props.Prop {
+		idx := props.tags[p.Tag]
+		typ := t.Field(idx).Type
+
+		// A pointer to a struct is a nested message: only its index is kept.
+		if typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct {
+			dm.nested = append(dm.nested, idx)
+			continue
+		}
+
+		// Everything else is handled as a scalar field: either *T or []byte.
+		elem := typ.Elem().Kind()
+		sf := scalarField{index: idx, kind: elem}
+
+		// With no declared default the value stays nil. A default that fails to
+		// parse is logged and the field is dropped from dm.scalars entirely.
+		if def := p.Default; def != "" {
+			switch elem {
+			case reflect.Bool:
+				b, err := strconv.Atob(def)
+				if err != nil {
+					log.Printf("proto: bad default bool %q: %v", def, err)
+					continue
+				}
+				sf.value = b
+			case reflect.Float32:
+				f, err := strconv.Atof32(def)
+				if err != nil {
+					log.Printf("proto: bad default float32 %q: %v", def, err)
+					continue
+				}
+				sf.value = f
+			case reflect.Float64:
+				f, err := strconv.Atof64(def)
+				if err != nil {
+					log.Printf("proto: bad default float64 %q: %v", def, err)
+					continue
+				}
+				sf.value = f
+			case reflect.Int32:
+				i, err := strconv.Atoi64(def)
+				if err != nil {
+					log.Printf("proto: bad default int32 %q: %v", def, err)
+					continue
+				}
+				sf.value = int32(i)
+			case reflect.Int64:
+				i, err := strconv.Atoi64(def)
+				if err != nil {
+					log.Printf("proto: bad default int64 %q: %v", def, err)
+					continue
+				}
+				sf.value = i
+			case reflect.String:
+				// string defaults are stored verbatim
+				sf.value = def
+			case reflect.Uint8:
+				// []byte (not *uint8)
+				sf.value = []byte(def)
+			case reflect.Uint32:
+				u, err := strconv.Atoui64(def)
+				if err != nil {
+					log.Printf("proto: bad default uint32 %q: %v", def, err)
+					continue
+				}
+				sf.value = uint32(u)
+			case reflect.Uint64:
+				u, err := strconv.Atoui64(def)
+				if err != nil {
+					log.Printf("proto: bad default uint64 %q: %v", def, err)
+					continue
+				}
+				sf.value = u
+			default:
+				// no parser for this kind; the field is skipped
+				log.Printf("proto: unhandled def kind %v", elem)
+				continue
+			}
+		}
+
+		dm.scalars = append(dm.scalars, sf)
+	}
+
+	return dm
+}
diff --git a/proto/testdata/test.proto b/proto/testdata/test.proto
index 70ff890..12f2dec 100644
--- a/proto/testdata/test.proto
+++ b/proto/testdata/test.proto
@@ -242,5 +242,35 @@
   repeated group Message = 1 {
     required string name = 2;
     required int32 count = 3;
-  };
-};
+  }
+}
+
+message Defaults {
+  enum Color {
+    RED = 0;
+    GREEN = 1;
+    BLUE = 2;
+  }
+
+  // One field with an explicit default for each basic type.
+  // Same as GoTest, but copied here to make testing easier.
+  optional bool F_Bool = 1 [default=true];
+  optional int32 F_Int32 = 2 [default=32];
+  optional int64 F_Int64 = 3 [default=64];
+  optional fixed32 F_Fixed32 = 4 [default=320];
+  optional fixed64 F_Fixed64 = 5 [default=640];
+  optional uint32 F_Uint32 = 6 [default=3200];
+  optional uint64 F_Uint64 = 7 [default=6400];
+  optional float F_Float = 8 [default=314159.];
+  optional double F_Double = 9 [default=271828.];
+  optional string F_String = 10 [default="hello, \"world!\"\n"];
+  optional bytes F_Bytes = 11 [default="Bignose"];
+  optional sint32 F_Sint32 = 12 [default=-32];
+  optional sint64 F_Sint64 = 13 [default=-64];
+  optional Color F_Enum = 14 [default=GREEN];
+
+  // Float fields whose defaults are not finite: +inf, -inf and NaN.
+  optional float F_Pinf = 15 [default=inf];
+  optional float F_Ninf = 16 [default=-inf];
+  optional float F_Nan = 17 [default=nan];
+}