proto: fix DiscardUnknown

UnmarshalOptions.DiscardUnknown was simply not working. Oops. Fix it.
Add a test.

Change-Id: I76888eae1221d99a007f0e9cdb711d292e6856b1
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/216762
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 4b1bc6d..74fd821 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -176,7 +176,7 @@
 			if n < 0 {
 				return out, wire.ParseError(n)
 			}
-			if mi.unknownOffset.IsValid() {
+			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
 				u := p.Apply(mi.unknownOffset).Bytes()
 				*u = wire.AppendTag(*u, num, wtyp)
 				*u = append(*u, b[:n]...)
diff --git a/proto/decode.go b/proto/decode.go
index 83942ea..9a6b2f7 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -154,7 +154,9 @@
 			if valLen < 0 {
 				return wire.ParseError(valLen)
 			}
-			m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
+			if !o.DiscardUnknown {
+				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
+			}
 		}
 		b = b[tagLen+valLen:]
 	}
diff --git a/proto/decode_test.go b/proto/decode_test.go
index 02f07d4..5ccb816 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -25,9 +25,8 @@
 		}
 		for _, want := range test.decodeTo {
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
-				opts := proto.UnmarshalOptions{
-					AllowPartial: test.partial,
-				}
+				opts := test.unmarshalOptions
+				opts.AllowPartial = test.partial
 				wire := append(([]byte)(nil), test.wire...)
 				got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
 				if err := opts.Unmarshal(wire, got); err != nil {
@@ -55,6 +54,8 @@
 		}
 		for _, m := range test.decodeTo {
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
+				opts := test.unmarshalOptions
+				opts.AllowPartial = false
 				got := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
 				if err := proto.Unmarshal(test.wire, got); err == nil {
 					t.Fatalf("Unmarshal succeeded (want error)\nMessage:\n%v", marshalText(got))
@@ -71,9 +72,8 @@
 		}
 		for _, want := range test.decodeTo {
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
-				opts := proto.UnmarshalOptions{
-					AllowPartial: test.partial,
-				}
+				opts := test.unmarshalOptions
+				opts.AllowPartial = test.partial
 				got := want.ProtoReflect().New().Interface()
 				if err := opts.Unmarshal(test.wire, got); err == nil {
 					t.Errorf("Unmarshal unexpectedly succeeded\ninput bytes: [%x]\nMessage:\n%v", test.wire, marshalText(got))
diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go
index 8a7cc29..6f66380 100644
--- a/proto/testmessages_test.go
+++ b/proto/testmessages_test.go
@@ -5,10 +5,12 @@
 package proto_test
 
 import (
+	"google.golang.org/protobuf/encoding/prototext"
 	"google.golang.org/protobuf/internal/encoding/pack"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/reflect/protoregistry"
 
 	legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
 	legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2_20160225_2fc053c5"
@@ -24,6 +26,7 @@
 	partial          bool
 	noEncode         bool
 	checkFastInit    bool
+	unmarshalOptions proto.UnmarshalOptions
 	validationStatus impl.ValidationStatus
 }
 
@@ -1118,6 +1121,19 @@
 		}.Marshal(),
 	},
 	{
+		desc: "discarded unknown fields",
+		unmarshalOptions: proto.UnmarshalOptions{
+			DiscardUnknown: true,
+		},
+		decodeTo: []proto.Message{
+			&testpb.TestAllTypes{},
+			&test3pb.TestAllTypes{},
+		},
+		wire: pack.Message{
+			pack.Tag{100000, pack.VarintType}, pack.Varint(1),
+		}.Marshal(),
+	},
+	{
 		desc: "field type mismatch",
 		decodeTo: []proto.Message{build(
 			&testpb.TestAllTypes{},
@@ -1615,6 +1631,46 @@
 			pack.Tag{pack.LastReservedNumber, pack.VarintType}, pack.Varint(1005),
 		}.Marshal(),
 	},
+	{
+		desc: "nested unknown extension",
+		unmarshalOptions: proto.UnmarshalOptions{
+			DiscardUnknown: true,
+			Resolver: func() protoregistry.ExtensionTypeResolver {
+				types := &protoregistry.Types{}
+				types.RegisterExtension(testpb.E_OptionalNestedMessageExtension)
+				types.RegisterExtension(testpb.E_OptionalInt32Extension)
+				return types
+			}(),
+		},
+		decodeTo: []proto.Message{func() proto.Message {
+			m := &testpb.TestAllExtensions{}
+			if err := prototext.Unmarshal([]byte(`
+				[goproto.proto.test.optional_nested_message_extension]: {
+					corecursive: {
+						[goproto.proto.test.optional_nested_message_extension]: {
+							corecursive: {
+								[goproto.proto.test.optional_int32_extension]: 42
+							}
+						}
+					}
+				}`), m); err != nil {
+				panic(err)
+			}
+			return m
+		}()},
+		wire: pack.Message{
+			pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{
+						pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+							pack.Tag{1, pack.VarintType}, pack.Varint(42),
+							pack.Tag{2, pack.VarintType}, pack.Varint(43),
+						}),
+					}),
+				}),
+			}),
+		}.Marshal(),
+	},
 }
 
 var testInvalidMessages = []testProto{