goprotobuf: Add Merge.
R=adg
CC=golang-dev
https://codereview.appspot.com/7227053
diff --git a/proto/all_test.go b/proto/all_test.go
index c3b2378..9c9cf0f 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -1427,7 +1427,7 @@
}
}
-func TestMerge(t *testing.T) {
+func TestMergeMessages(t *testing.T) {
pb := &MessageList{Message: []*MessageList_Message{{Name: String("x"), Count: Int32(1)}}}
data, err := Marshal(pb)
if err != nil {
diff --git a/proto/clone.go b/proto/clone.go
index a5ae9ce..6e8bbfb 100644
--- a/proto/clone.go
+++ b/proto/clone.go
@@ -48,22 +48,44 @@
}
out := reflect.New(in.Type().Elem())
- copyStruct(out.Elem(), in.Elem())
+ // out is empty so a merge is a deep copy.
+ mergeStruct(out.Elem(), in.Elem())
return out.Interface().(Message)
}
-func copyStruct(out, in reflect.Value) {
+// Merge merges src into dst, modifying the message dst points to in place.
+// Required and optional fields that are set in src will be set to that value in dst.
+// Elements of repeated fields will be appended. NOTE(review): mergeAny's slice branch also appends for []byte values, so a singular bytes field set in both messages appears to be concatenated rather than replaced — confirm.
+// Merge panics if src and dst are not the same type, or if dst is nil; a nil src of the same type as dst is a no-op.
+func Merge(dst, src Message) {
+ in := reflect.ValueOf(src)
+ out := reflect.ValueOf(dst)
+ if out.IsNil() {
+ panic("proto: nil destination")
+ }
+ if in.Type() != out.Type() {
+ // Compare the types before the nil check on src below, so that a nil src of the wrong type still panics.
+ panic("proto: type mismatch")
+ }
+ if in.IsNil() {
+ // Merging a nil src into a non-nil dst is a quiet no-op.
+ return
+ }
+ mergeStruct(out.Elem(), in.Elem()) // merge field by field into the struct that dst points to
+}
+
+func mergeStruct(out, in reflect.Value) {
for i := 0; i < in.NumField(); i++ {
f := in.Type().Field(i)
if strings.HasPrefix(f.Name, "XXX_") {
continue
}
- copyAny(out.Field(i), in.Field(i))
+ mergeAny(out.Field(i), in.Field(i))
}
if emIn, ok := in.Addr().Interface().(extendableProto); ok {
emOut := out.Addr().Interface().(extendableProto)
- copyExtension(emOut.ExtensionMap(), emIn.ExtensionMap())
+ mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap())
}
// Groups don't have XXX_unrecognized fields.
@@ -77,10 +99,14 @@
}
}
-func copyAny(out, in reflect.Value) {
+func mergeAny(out, in reflect.Value) {
if in.Type() == protoMessageType {
if !in.IsNil() {
- out.Set(reflect.ValueOf(Clone(in.Interface().(Message))))
+ if out.IsNil() {
+ out.Set(reflect.ValueOf(Clone(in.Interface().(Message))))
+ } else {
+ Merge(out.Interface().(Message), in.Interface().(Message))
+ }
}
return
}
@@ -92,37 +118,43 @@
if in.IsNil() {
return
}
- out.Set(reflect.New(in.Type().Elem()))
- copyAny(out.Elem(), in.Elem())
+ if out.IsNil() {
+ out.Set(reflect.New(in.Elem().Type()))
+ }
+ mergeAny(out.Elem(), in.Elem())
case reflect.Slice:
if in.IsNil() {
return
}
n := in.Len()
- out.Set(reflect.MakeSlice(in.Type(), n, n))
+ if out.IsNil() {
+ out.Set(reflect.MakeSlice(in.Type(), 0, n))
+ }
switch in.Type().Elem().Kind() {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
reflect.String, reflect.Uint32, reflect.Uint64, reflect.Uint8:
- reflect.Copy(out, in)
+ out.Set(reflect.AppendSlice(out, in))
default:
for i := 0; i < n; i++ {
- copyAny(out.Index(i), in.Index(i))
+ x := reflect.Indirect(reflect.New(in.Type().Elem()))
+ mergeAny(x, in.Index(i))
+ out.Set(reflect.Append(out, x))
}
}
case reflect.Struct:
- copyStruct(out, in)
+ mergeStruct(out, in)
default:
// unknown type, so not a protocol buffer
log.Printf("proto: don't know how to copy %v", in)
}
}
-func copyExtension(out, in map[int32]Extension) {
+func mergeExtension(out, in map[int32]Extension) {
for extNum, eIn := range in {
eOut := Extension{desc: eIn.desc}
if eIn.value != nil {
v := reflect.New(reflect.TypeOf(eIn.value)).Elem()
- copyAny(v, reflect.ValueOf(eIn.value))
+ mergeAny(v, reflect.ValueOf(eIn.value))
eOut.value = v.Interface()
}
if eIn.enc != nil {
diff --git a/proto/clone_test.go b/proto/clone_test.go
index 26bf0e7..2823c15 100644
--- a/proto/clone_test.go
+++ b/proto/clone_test.go
@@ -87,3 +87,94 @@
t.Errorf("Clone(%v) = %v", m, c)
}
}
+
+var mergeTests = []struct { // cases for TestMerge: merging src into a clone of dst must produce want
+ src, dst, want proto.Message
+}{
+ { // scalar fields: Count is set only in src, Name only in dst; want has both
+ src: &pb.MyMessage{
+ Count: proto.Int32(42),
+ },
+ dst: &pb.MyMessage{
+ Name: proto.String("Dave"),
+ },
+ want: &pb.MyMessage{
+ Count: proto.Int32(42),
+ Name: proto.String("Dave"),
+ },
+ },
+ { // nested message merged field by field; repeated string and repeated message fields appended after dst's elements
+ src: &pb.MyMessage{
+ Inner: &pb.InnerMessage{
+ Host: proto.String("hey"),
+ Connected: proto.Bool(true),
+ },
+ Pet: []string{"horsey"},
+ Others: []*pb.OtherMessage{
+ &pb.OtherMessage{
+ Value: []byte("some bytes"),
+ },
+ },
+ },
+ dst: &pb.MyMessage{
+ Inner: &pb.InnerMessage{
+ Host: proto.String("niles"),
+ Port: proto.Int32(9099),
+ },
+ Pet: []string{"bunny", "kitty"},
+ Others: []*pb.OtherMessage{
+ &pb.OtherMessage{
+ Key: proto.Int64(31415926535),
+ },
+ &pb.OtherMessage{
+ // Explicitly test a nil field; TestMerge clones dst before merging, so this nil is on the source side of that copy.
+ Inner: nil,
+ },
+ },
+ },
+ want: &pb.MyMessage{
+ Inner: &pb.InnerMessage{
+ Host: proto.String("hey"),
+ Connected: proto.Bool(true),
+ Port: proto.Int32(9099),
+ },
+ Pet: []string{"bunny", "kitty", "horsey"},
+ Others: []*pb.OtherMessage{
+ &pb.OtherMessage{
+ Key: proto.Int64(31415926535),
+ },
+ &pb.OtherMessage{},
+ &pb.OtherMessage{
+ Value: []byte("some bytes"),
+ },
+ },
+ },
+ },
+ { // repeated bytes: src's element is appended after dst's; the group set only in dst is kept
+ src: &pb.MyMessage{
+ RepBytes: [][]byte{[]byte("wow")},
+ },
+ dst: &pb.MyMessage{
+ Somegroup: &pb.MyMessage_SomeGroup{
+ GroupField: proto.Int32(6),
+ },
+ RepBytes: [][]byte{[]byte("sham")},
+ },
+ want: &pb.MyMessage{
+ Somegroup: &pb.MyMessage_SomeGroup{
+ GroupField: proto.Int32(6),
+ },
+ RepBytes: [][]byte{[]byte("sham"), []byte("wow")},
+ },
+ },
+}
+
+func TestMerge(t *testing.T) { // table-driven: runs every case in mergeTests
+ for _, m := range mergeTests {
+ got := proto.Clone(m.dst) // merge into a clone so Merge mutates the copy, not the m.dst fixture
+ proto.Merge(got, m.src)
+ if !proto.Equal(got, m.want) {
+ t.Errorf("Merge(%v, %v)\n got %v\nwant %v\n", m.dst, m.src, got, m.want)
+ }
+ }
+}