proto: implement Merge

Change-Id: Ibb579bf5ad8548359dfd9805fd3022bcd53a6379
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/183679
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/proto/merge.go b/proto/merge.go
new file mode 100644
index 0000000..7cea77d
--- /dev/null
+++ b/proto/merge.go
@@ -0,0 +1,99 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto
+
+import "google.golang.org/protobuf/reflect/protoreflect"
+
+// Merge merges src into dst, which must be messages with the same descriptor.
+//
+// Populated scalar fields in src are copied to dst, while populated
+// singular messages in src are merged into dst by recursively calling Merge.
+// The elements of every list field in src are appended to the corresponding
+// list field in dst. The entries of every map field in src are copied into
+// the corresponding map field in dst, possibly replacing existing entries.
+// The unknown fields of src are appended to the unknown fields of dst.
+//
+// Bytes and message values are copied rather than shared, so mutating src
+// after the merge is not observable through dst. Merge panics if dst and src
+// do not have the same descriptor.
+func Merge(dst, src Message) {
+	mergeMessage(dst.ProtoReflect(), src.ProtoReflect())
+}
+
+// mergeMessage merges the populated fields and unknown fields of src into dst.
+// It panics if the two messages do not have the same descriptor.
+func mergeMessage(dst, src protoreflect.Message) {
+	if dst.Descriptor() != src.Descriptor() {
+		panic("descriptor mismatch")
+	}
+
+	src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
+		// The list and map cases must precede the message case: repeated
+		// message fields and map fields also have a non-nil Message().
+		switch {
+		case fd.IsList():
+			mergeList(dst.Mutable(fd).List(), v.List(), fd)
+		case fd.IsMap():
+			mergeMap(dst.Mutable(fd).Map(), v.Map(), fd.MapValue())
+		case fd.Message() != nil:
+			// Singular messages (including oneof members) are merged, not replaced.
+			mergeMessage(dst.Mutable(fd).Message(), v.Message())
+		case fd.Kind() == protoreflect.BytesKind:
+			dst.Set(fd, cloneBytes(v))
+		default:
+			dst.Set(fd, v)
+		}
+		return true
+	})
+
+	// Leave dst's unknown fields untouched unless src has some to append.
+	if unknown := src.GetUnknown(); len(unknown) > 0 {
+		dst.SetUnknown(append(dst.GetUnknown(), unknown...))
+	}
+}
+
+// mergeList appends a copy of every element of src to the end of dst.
+// fd is the list field's descriptor and selects how elements are copied.
+func mergeList(dst, src protoreflect.List, fd protoreflect.FieldDescriptor) {
+	// Read src.Len() once: if dst and src are the same list, each Append would
+	// otherwise raise the loop bound and the loop would never terminate.
+	for i, n := 0, src.Len(); i < n; i++ {
+		switch v := src.Get(i); {
+		case fd.Message() != nil:
+			m := dst.NewMessage()
+			mergeMessage(m, v.Message())
+			dst.Append(protoreflect.ValueOf(m))
+		case fd.Kind() == protoreflect.BytesKind:
+			dst.Append(cloneBytes(v))
+		default:
+			dst.Append(v)
+		}
+	}
+}
+
+// mergeMap copies every entry of src into dst, replacing any existing entry
+// with the same key. fd is the descriptor of the map's value field.
+func mergeMap(dst, src protoreflect.Map, fd protoreflect.FieldDescriptor) {
+	src.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
+		switch {
+		case fd.Message() != nil:
+			// Message values replace any existing entry; they are not merged.
+			m := dst.NewMessage()
+			mergeMessage(m, v.Message())
+			dst.Set(k, protoreflect.ValueOf(m))
+		case fd.Kind() == protoreflect.BytesKind:
+			dst.Set(k, cloneBytes(v))
+		default:
+			dst.Set(k, v)
+		}
+		return true
+	})
+}
+
+// cloneBytes returns a Value holding a fresh copy of v's bytes. The copy is a
+// non-nil slice even when v holds no bytes.
+func cloneBytes(v protoreflect.Value) protoreflect.Value {
+	return protoreflect.ValueOf(append([]byte{}, v.Bytes()...))
+}
diff --git a/proto/merge_test.go b/proto/merge_test.go
new file mode 100644
index 0000000..109a413
--- /dev/null
+++ b/proto/merge_test.go
@@ -0,0 +1,398 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto_test
+
+import (
+	"testing"
+
+	"google.golang.org/protobuf/internal/encoding/pack"
+	"google.golang.org/protobuf/internal/scalar"
+	"google.golang.org/protobuf/proto"
+
+	testpb "google.golang.org/protobuf/internal/testprotos/test"
+)
+
+// TestMerge checks that proto.Merge agrees with unmarshaling the concatenated
+// wire encodings of dst and src, and that mutating src after the merge is not
+// observable through dst.
+func TestMerge(t *testing.T) {
+	tests := []struct {
+		desc    string
+		dst     proto.Message
+		src     proto.Message
+		want    proto.Message
+		mutator func(proto.Message) // if provided, is run on src after merging
+
+		skipMarshalUnmarshal bool // TODO: Remove this when proto.Unmarshal is fixed for messages in oneofs
+	}{{
+		desc: "merge from nil message",
+		dst:  new(testpb.TestAllTypes),
+		src:  (*testpb.TestAllTypes)(nil),
+		want: new(testpb.TestAllTypes),
+	}, {
+		desc: "clone a large message",
+		dst:  new(testpb.TestAllTypes),
+		src: &testpb.TestAllTypes{
+			OptionalInt64:      scalar.Int64(0),
+			OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
+			OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
+				A: scalar.Int32(100),
+			},
+			RepeatedSfixed32: []int32{1, 2, 3},
+			RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
+				{A: scalar.Int32(200)},
+				{A: scalar.Int32(300)},
+			},
+			MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
+				"fizz": 400,
+				"buzz": 500,
+			},
+			MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
+				"foo": {A: scalar.Int32(600)},
+				"bar": {A: scalar.Int32(700)},
+			},
+			OneofField: &testpb.TestAllTypes_OneofNestedMessage{
+				OneofNestedMessage: &testpb.TestAllTypes_NestedMessage{
+					A: scalar.Int32(800),
+				},
+			},
+		},
+		want: &testpb.TestAllTypes{
+			OptionalInt64:      scalar.Int64(0),
+			OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(1).Enum(),
+			OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
+				A: scalar.Int32(100),
+			},
+			RepeatedSfixed32: []int32{1, 2, 3},
+			RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
+				{A: scalar.Int32(200)},
+				{A: scalar.Int32(300)},
+			},
+			MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
+				"fizz": 400,
+				"buzz": 500,
+			},
+			MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
+				"foo": {A: scalar.Int32(600)},
+				"bar": {A: scalar.Int32(700)},
+			},
+			OneofField: &testpb.TestAllTypes_OneofNestedMessage{
+				OneofNestedMessage: &testpb.TestAllTypes_NestedMessage{
+					A: scalar.Int32(800),
+				},
+			},
+		},
+		mutator: func(mi proto.Message) {
+			m := mi.(*testpb.TestAllTypes)
+			*m.OptionalInt64++
+			*m.OptionalNestedEnum++
+			*m.OptionalNestedMessage.A++
+			m.RepeatedSfixed32[0]++
+			*m.RepeatedNestedMessage[0].A++
+			delete(m.MapStringNestedEnum, "fizz")
+			*m.MapStringNestedMessage["foo"].A++
+			*m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.A++
+		},
+	}, {
+		desc: "merge bytes",
+		dst: &testpb.TestAllTypes{
+			OptionalBytes:  []byte{1, 2, 3},
+			RepeatedBytes:  [][]byte{{1, 2}, {3, 4}},
+			MapStringBytes: map[string][]byte{"alpha": {1, 2, 3}},
+		},
+		src: &testpb.TestAllTypes{
+			OptionalBytes:  []byte{4, 5, 6},
+			RepeatedBytes:  [][]byte{{5, 6}, {7, 8}},
+			MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
+		},
+		want: &testpb.TestAllTypes{
+			OptionalBytes:  []byte{4, 5, 6},
+			RepeatedBytes:  [][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}},
+			MapStringBytes: map[string][]byte{"alpha": {4, 5, 6}, "bravo": {1, 2, 3}},
+		},
+		mutator: func(mi proto.Message) {
+			m := mi.(*testpb.TestAllTypes)
+			m.OptionalBytes[0]++
+			m.RepeatedBytes[0][0]++
+			m.MapStringBytes["alpha"][0]++
+		},
+	}, {
+		desc: "merge singular fields",
+		dst: &testpb.TestAllTypes{
+			OptionalInt32:      scalar.Int32(1),
+			OptionalInt64:      scalar.Int64(1),
+			OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(10).Enum(),
+			OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
+				A: scalar.Int32(100),
+				Corecursive: &testpb.TestAllTypes{
+					OptionalInt64: scalar.Int64(1000),
+				},
+			},
+		},
+		src: &testpb.TestAllTypes{
+			OptionalInt64:      scalar.Int64(2),
+			OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
+			OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
+				A: scalar.Int32(200),
+			},
+		},
+		want: &testpb.TestAllTypes{
+			OptionalInt32:      scalar.Int32(1),
+			OptionalInt64:      scalar.Int64(2),
+			OptionalNestedEnum: testpb.TestAllTypes_NestedEnum(20).Enum(),
+			OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
+				A: scalar.Int32(200),
+				Corecursive: &testpb.TestAllTypes{
+					OptionalInt64: scalar.Int64(1000),
+				},
+			},
+		},
+		mutator: func(mi proto.Message) {
+			m := mi.(*testpb.TestAllTypes)
+			*m.OptionalInt64++
+			*m.OptionalNestedEnum++
+			*m.OptionalNestedMessage.A++
+		},
+	}, {
+		desc: "merge list fields",
+		dst: &testpb.TestAllTypes{
+			RepeatedSfixed32: []int32{1, 2, 3},
+			RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
+				{A: scalar.Int32(100)},
+				{A: scalar.Int32(200)},
+			},
+		},
+		src: &testpb.TestAllTypes{
+			RepeatedSfixed32: []int32{4, 5, 6},
+			RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
+				{A: scalar.Int32(300)},
+				{A: scalar.Int32(400)},
+			},
+		},
+		want: &testpb.TestAllTypes{
+			RepeatedSfixed32: []int32{1, 2, 3, 4, 5, 6},
+			RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
+				{A: scalar.Int32(100)},
+				{A: scalar.Int32(200)},
+				{A: scalar.Int32(300)},
+				{A: scalar.Int32(400)},
+			},
+		},
+		mutator: func(mi proto.Message) {
+			m := mi.(*testpb.TestAllTypes)
+			m.RepeatedSfixed32[0]++
+			*m.RepeatedNestedMessage[0].A++
+		},
+	}, {
+		desc: "merge map fields",
+		dst: &testpb.TestAllTypes{
+			MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
+				"fizz": 100,
+				"buzz": 200,
+				"guzz": 300,
+			},
+			MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
+				"foo": {A: scalar.Int32(400)},
+			},
+		},
+		src: &testpb.TestAllTypes{
+			MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
+				"fizz": 1000,
+				"buzz": 2000,
+			},
+			MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
+				"foo": {A: scalar.Int32(3000)},
+				"bar": {},
+			},
+		},
+		want: &testpb.TestAllTypes{
+			MapStringNestedEnum: map[string]testpb.TestAllTypes_NestedEnum{
+				"fizz": 1000,
+				"buzz": 2000,
+				"guzz": 300,
+			},
+			MapStringNestedMessage: map[string]*testpb.TestAllTypes_NestedMessage{
+				"foo": {A: scalar.Int32(3000)},
+				"bar": {},
+			},
+		},
+		mutator: func(mi proto.Message) {
+			m := mi.(*testpb.TestAllTypes)
+			delete(m.MapStringNestedEnum, "fizz")
+			m.MapStringNestedMessage["bar"].A = scalar.Int32(1)
+		},
+	}, {
+		desc: "merge oneof message fields",
+		dst: &testpb.TestAllTypes{
+			OneofField: &testpb.TestAllTypes_OneofNestedMessage{
+				OneofNestedMessage: &testpb.TestAllTypes_NestedMessage{
+					A: scalar.Int32(100),
+				},
+			},
+		},
+		src: &testpb.TestAllTypes{
+			OneofField: &testpb.TestAllTypes_OneofNestedMessage{
+				OneofNestedMessage: &testpb.TestAllTypes_NestedMessage{
+					Corecursive: &testpb.TestAllTypes{
+						OptionalInt64: scalar.Int64(1000),
+					},
+				},
+			},
+		},
+		want: &testpb.TestAllTypes{
+			OneofField: &testpb.TestAllTypes_OneofNestedMessage{
+				OneofNestedMessage: &testpb.TestAllTypes_NestedMessage{
+					A: scalar.Int32(100),
+					Corecursive: &testpb.TestAllTypes{
+						OptionalInt64: scalar.Int64(1000),
+					},
+				},
+			},
+		},
+		mutator: func(mi proto.Message) {
+			m := mi.(*testpb.TestAllTypes)
+			*m.OneofField.(*testpb.TestAllTypes_OneofNestedMessage).OneofNestedMessage.Corecursive.OptionalInt64++
+		},
+		skipMarshalUnmarshal: true,
+	}, {
+		desc: "merge oneof scalar fields",
+		dst: &testpb.TestAllTypes{
+			OneofField: &testpb.TestAllTypes_OneofUint32{OneofUint32: 100},
+		},
+		src: &testpb.TestAllTypes{
+			OneofField: &testpb.TestAllTypes_OneofFloat{OneofFloat: 3.14152},
+		},
+		want: &testpb.TestAllTypes{
+			OneofField: &testpb.TestAllTypes_OneofFloat{OneofFloat: 3.14152},
+		},
+		mutator: func(mi proto.Message) {
+			m := mi.(*testpb.TestAllTypes)
+			m.OneofField.(*testpb.TestAllTypes_OneofFloat).OneofFloat++
+		},
+	}, {
+		desc: "merge extension fields",
+		dst: func() proto.Message {
+			m := new(testpb.TestAllExtensions)
+			m.ProtoReflect().Set(
+				testpb.E_OptionalInt32Extension.Type,
+				testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)),
+			)
+			m.ProtoReflect().Set(
+				testpb.E_OptionalNestedMessageExtension.Type,
+				testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
+					A: scalar.Int32(50),
+				}),
+			)
+			m.ProtoReflect().Set(
+				testpb.E_RepeatedFixed32Extension.Type,
+				testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3}),
+			)
+			return m
+		}(),
+		src: func() proto.Message {
+			m := new(testpb.TestAllExtensions)
+			m.ProtoReflect().Set(
+				testpb.E_OptionalInt64Extension.Type,
+				testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)),
+			)
+			m.ProtoReflect().Set(
+				testpb.E_OptionalNestedMessageExtension.Type,
+				testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
+					Corecursive: &testpb.TestAllTypes{
+						OptionalInt64: scalar.Int64(1000),
+					},
+				}),
+			)
+			m.ProtoReflect().Set(
+				testpb.E_RepeatedFixed32Extension.Type,
+				testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{4, 5, 6}),
+			)
+			return m
+		}(),
+		want: func() proto.Message {
+			m := new(testpb.TestAllExtensions)
+			m.ProtoReflect().Set(
+				testpb.E_OptionalInt32Extension.Type,
+				testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)),
+			)
+			m.ProtoReflect().Set(
+				testpb.E_OptionalInt64Extension.Type,
+				testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)),
+			)
+			m.ProtoReflect().Set(
+				testpb.E_OptionalNestedMessageExtension.Type,
+				testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{
+					A: scalar.Int32(50),
+					Corecursive: &testpb.TestAllTypes{
+						OptionalInt64: scalar.Int64(1000),
+					},
+				}),
+			)
+			m.ProtoReflect().Set(
+				testpb.E_RepeatedFixed32Extension.Type,
+				testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3, 4, 5, 6}),
+			)
+			return m
+		}(),
+	}, {
+		desc: "merge unknown fields",
+		dst: func() proto.Message {
+			m := new(testpb.TestAllTypes)
+			m.ProtoReflect().SetUnknown(pack.Message{
+				pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
+			}.Marshal())
+			return m
+		}(),
+		src: func() proto.Message {
+			m := new(testpb.TestAllTypes)
+			m.ProtoReflect().SetUnknown(pack.Message{
+				pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
+			}.Marshal())
+			return m
+		}(),
+		want: func() proto.Message {
+			m := new(testpb.TestAllTypes)
+			m.ProtoReflect().SetUnknown(pack.Message{
+				pack.Tag{Number: 50000, Type: pack.VarintType}, pack.Svarint(-5),
+				pack.Tag{Number: 500000, Type: pack.VarintType}, pack.Svarint(-50),
+			}.Marshal())
+			return m
+		}(),
+	}}
+
+	for _, tt := range tests {
+		t.Run(tt.desc, func(t *testing.T) {
+			// Merge should be semantically equivalent to unmarshaling the
+			// encoded form of src into the current dst.
+			b1, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.dst)
+			if err != nil {
+				t.Fatalf("Marshal(dst) error: %v", err)
+			}
+			b2, err := proto.MarshalOptions{AllowPartial: true}.Marshal(tt.src)
+			if err != nil {
+				t.Fatalf("Marshal(src) error: %v", err)
+			}
+			dst := tt.dst.ProtoReflect().New().Interface()
+			err = proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(append(b1, b2...), dst)
+			if err != nil {
+				t.Fatalf("Unmarshal() error: %v", err)
+			}
+			// Skipped cases hit a known proto.Unmarshal bug (see skipMarshalUnmarshal).
+			if !tt.skipMarshalUnmarshal && !proto.Equal(dst, tt.want) {
+				t.Fatalf("Unmarshal(Marshal(dst)+Marshal(src)) mismatch: got %v, want %v", dst, tt.want)
+			}
+
+			proto.Merge(tt.dst, tt.src)
+			// Mutating src after the merge must not be observable through dst;
+			// this verifies that Merge copies src's data instead of aliasing it.
+			if tt.mutator != nil {
+				tt.mutator(tt.src)
+			}
+			if !proto.Equal(tt.dst, tt.want) {
+				t.Fatalf("Merge() mismatch: got %v, want %v", tt.dst, tt.want)
+			}
+		})
+	}
+}