internal/legacy: use v2 proto.Unmarshal
The v2 decoder isn't 100% complete yet, but it is complete enough to replace the v1 implementation here.
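Delete the vendored copy of the v1 Unmarshal implementation.
Move proto/decode_test.go into the external proto_test package: the test imports internal/legacy, which now depends on proto, so keeping it in package proto would create an import cycle.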
Change-Id: Ibeabbb2e9109a1ec3df57e71f98b7aa4a583fc5b
Reviewed-on: https://go-review.googlesource.com/c/154577
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/proto/decode_test.go b/proto/decode_test.go
index 9560ab8..a922c5e 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style.
// license that can be found in the LICENSE file.
-package proto
+package proto_test
import (
"fmt"
@@ -14,12 +14,13 @@
_ "github.com/golang/protobuf/v2/internal/legacy"
"github.com/golang/protobuf/v2/internal/scalar"
testpb "github.com/golang/protobuf/v2/internal/testprotos/test"
+ "github.com/golang/protobuf/v2/proto"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
)
type testProto struct {
desc string
- decodeTo []Message
+ decodeTo []proto.Message
wire []byte
}
@@ -28,8 +29,8 @@
for _, want := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
wire := append(([]byte)(nil), test.wire...)
- got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(Message)
- if err := Unmarshal(wire, got); err != nil {
+ got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
+ if err := proto.Unmarshal(wire, got); err != nil {
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
return
}
@@ -51,7 +52,7 @@
var testProtos = []testProto{
{
desc: "basic scalar types",
- decodeTo: []Message{&testpb.TestAllTypes{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
OptionalInt32: scalar.Int32(1001),
OptionalInt64: scalar.Int64(1002),
OptionalUint32: scalar.Uint32(1003),
@@ -108,7 +109,7 @@
},
{
desc: "groups",
- decodeTo: []Message{&testpb.TestAllTypes{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
Optionalgroup: &testpb.TestAllTypes_OptionalGroup{
A: scalar.Int32(1017),
},
@@ -126,7 +127,7 @@
},
{
desc: "groups (field overridden)",
- decodeTo: []Message{&testpb.TestAllTypes{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
Optionalgroup: &testpb.TestAllTypes_OptionalGroup{
A: scalar.Int32(2),
},
@@ -147,7 +148,7 @@
},
{
desc: "messages",
- decodeTo: []Message{&testpb.TestAllTypes{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
A: scalar.Int32(42),
Corecursive: &testpb.TestAllTypes{
@@ -174,7 +175,7 @@
},
{
desc: "messages (split across multiple tags)",
- decodeTo: []Message{&testpb.TestAllTypes{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
A: scalar.Int32(42),
Corecursive: &testpb.TestAllTypes{
@@ -203,7 +204,7 @@
},
{
desc: "messages (field overridden)",
- decodeTo: []Message{&testpb.TestAllTypes{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{
A: scalar.Int32(2),
},
@@ -224,7 +225,7 @@
},
{
desc: "basic repeated types",
- decodeTo: []Message{&testpb.TestAllTypes{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
RepeatedInt32: []int32{1001, 2001},
RepeatedInt64: []int64{1002, 2002},
RepeatedUint32: []uint32{1003, 2003},
@@ -303,7 +304,7 @@
},
{
desc: "basic repeated types (packed encoding)",
- decodeTo: []Message{&testpb.TestAllTypes{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
RepeatedInt32: []int32{1001, 2001},
RepeatedInt64: []int64{1002, 2002},
RepeatedUint32: []uint32{1003, 2003},
@@ -389,7 +390,7 @@
},
{
desc: "repeated messages",
- decodeTo: []Message{&testpb.TestAllTypes{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
RepeatedNestedMessage: []*testpb.TestAllTypes_NestedMessage{
{A: scalar.Int32(1)},
{A: scalar.Int32(2)},
@@ -412,7 +413,7 @@
},
{
desc: "repeated groups",
- decodeTo: []Message{&testpb.TestAllTypes{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
Repeatedgroup: []*testpb.TestAllTypes_RepeatedGroup{
{A: scalar.Int32(1017)},
{A: scalar.Int32(2017)},
@@ -435,7 +436,7 @@
},
{
desc: "maps",
- decodeTo: []Message{&testpb.TestAllTypes{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{
MapInt32Int32: map[int32]int32{1056: 1156, 2056: 2156},
MapInt64Int64: map[int64]int64{1057: 1157, 2057: 2157},
MapUint32Uint32: map[uint32]uint32{1058: 1158, 2058: 2158},
@@ -605,12 +606,12 @@
},
{
desc: "oneof (uint32)",
- decodeTo: []Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofUint32{1111}}},
+ decodeTo: []proto.Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofUint32{1111}}},
wire: pack.Message{pack.Tag{111, pack.VarintType}, pack.Varint(1111)}.Marshal(),
},
{
desc: "oneof (message)",
- decodeTo: []Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofNestedMessage{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofNestedMessage{
&testpb.TestAllTypes_NestedMessage{A: scalar.Int32(1112)},
}}},
wire: pack.Message{pack.Tag{112, pack.BytesType}, pack.LengthPrefix(pack.Message{
@@ -619,7 +620,7 @@
},
{
desc: "oneof (overridden message)",
- decodeTo: []Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofNestedMessage{
+ decodeTo: []proto.Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofNestedMessage{
&testpb.TestAllTypes_NestedMessage{
Corecursive: &testpb.TestAllTypes{
OptionalInt32: scalar.Int32(43),
@@ -639,42 +640,42 @@
},
{
desc: "oneof (string)",
- decodeTo: []Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofString{"1113"}}},
+ decodeTo: []proto.Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofString{"1113"}}},
wire: pack.Message{pack.Tag{113, pack.BytesType}, pack.String("1113")}.Marshal(),
},
{
desc: "oneof (bytes)",
- decodeTo: []Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofBytes{[]byte("1114")}}},
+ decodeTo: []proto.Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofBytes{[]byte("1114")}}},
wire: pack.Message{pack.Tag{114, pack.BytesType}, pack.String("1114")}.Marshal(),
},
{
desc: "oneof (bool)",
- decodeTo: []Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofBool{true}}},
+ decodeTo: []proto.Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofBool{true}}},
wire: pack.Message{pack.Tag{115, pack.VarintType}, pack.Bool(true)}.Marshal(),
},
{
desc: "oneof (uint64)",
- decodeTo: []Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofUint64{116}}},
+ decodeTo: []proto.Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofUint64{116}}},
wire: pack.Message{pack.Tag{116, pack.VarintType}, pack.Varint(116)}.Marshal(),
},
{
desc: "oneof (float)",
- decodeTo: []Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofFloat{117.5}}},
+ decodeTo: []proto.Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofFloat{117.5}}},
wire: pack.Message{pack.Tag{117, pack.Fixed32Type}, pack.Float32(117.5)}.Marshal(),
},
{
desc: "oneof (double)",
- decodeTo: []Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofDouble{118.5}}},
+ decodeTo: []proto.Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofDouble{118.5}}},
wire: pack.Message{pack.Tag{118, pack.Fixed64Type}, pack.Float64(118.5)}.Marshal(),
},
{
desc: "oneof (enum)",
- decodeTo: []Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofEnum{testpb.TestAllTypes_BAR}}},
+ decodeTo: []proto.Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofEnum{testpb.TestAllTypes_BAR}}},
wire: pack.Message{pack.Tag{119, pack.VarintType}, pack.Varint(int(testpb.TestAllTypes_BAR))}.Marshal(),
},
{
desc: "oneof (overridden value)",
- decodeTo: []Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofUint64{2}}},
+ decodeTo: []proto.Message{&testpb.TestAllTypes{OneofField: &testpb.TestAllTypes_OneofUint64{2}}},
wire: pack.Message{
pack.Tag{111, pack.VarintType}, pack.Varint(1),
pack.Tag{116, pack.VarintType}, pack.Varint(2),
@@ -687,7 +688,7 @@
// that's a problem or not.
{
desc: "unknown fields",
- decodeTo: []Message{build(
+ decodeTo: []proto.Message{build(
&testpb.TestAllTypes{},
unknown(100000, pack.Message{
pack.Tag{100000, pack.VarintType}, pack.Varint(1),
@@ -699,7 +700,7 @@
},
{
desc: "field type mismatch",
- decodeTo: []Message{build(
+ decodeTo: []proto.Message{build(
&testpb.TestAllTypes{},
unknown(1, pack.Message{
pack.Tag{1, pack.BytesType}, pack.String("string"),
@@ -711,7 +712,7 @@
},
{
desc: "map field element mismatch",
- decodeTo: []Message{
+ decodeTo: []proto.Message{
&testpb.TestAllTypes{
MapInt32Int32: map[int32]int32{1: 0},
},
@@ -725,23 +726,23 @@
},
}
-func build(m Message, opts ...buildOpt) Message {
+func build(m proto.Message, opts ...buildOpt) proto.Message {
for _, opt := range opts {
opt(m)
}
return m
}
-type buildOpt func(Message)
+type buildOpt func(proto.Message)
func unknown(num pref.FieldNumber, raw pref.RawFields) buildOpt {
- return func(m Message) {
+ return func(m proto.Message) {
m.ProtoReflect().UnknownFields().Set(num, raw)
}
}
func extend(desc *protoV1.ExtensionDesc, value interface{}) buildOpt {
- return func(m Message) {
+ return func(m proto.Message) {
if err := protoV1.SetExtension(m.(protoV1.Message), desc, value); err != nil {
panic(err)
}