encoding/jsonpb: add support for marshaling of extensions and messagesets

Change-Id: I839660146760a66c5cbf25d24f81f0ba5096d9e1
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/167395
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/encoding/jsonpb/encode_test.go b/encoding/jsonpb/encode_test.go
index 43e7529..280e8db 100644
--- a/encoding/jsonpb/encode_test.go
+++ b/encoding/jsonpb/encode_test.go
@@ -9,13 +9,18 @@
 	"strings"
 	"testing"
 
+	"github.com/golang/protobuf/protoapi"
 	"github.com/golang/protobuf/v2/encoding/jsonpb"
 	"github.com/golang/protobuf/v2/internal/encoding/pack"
+	"github.com/golang/protobuf/v2/internal/encoding/wire"
 	"github.com/golang/protobuf/v2/internal/scalar"
 	"github.com/golang/protobuf/v2/proto"
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
 
+	// This legacy package is still needed when importing legacy messages.
+	_ "github.com/golang/protobuf/v2/internal/legacy"
+
 	"github.com/golang/protobuf/v2/encoding/testprotos/pb2"
 	"github.com/golang/protobuf/v2/encoding/testprotos/pb3"
 )
@@ -37,6 +42,17 @@
 	return p
 }
 
+// setExtension registers xd.Type with m's extension types and, when val is
+// non-nil, sets known field number xd.Field to xd.Type.ValueOf(val).
+func setExtension(m proto.Message, xd *protoapi.ExtensionDesc, val interface{}) {
+	fields := m.ProtoReflect().KnownFields()
+	fields.ExtensionTypes().Register(xd.Type)
+	if val != nil {
+		num := wire.Number(xd.Field)
+		fields.Set(num, xd.Type.ValueOf(val))
+	}
+}
+
 func TestMarshal(t *testing.T) {
 	tests := []struct {
 		desc  string
@@ -701,12 +717,209 @@
 		want: `{
   "foo_bar": "json_name"
 }`,
+	}, {
+		desc: "extensions of non-repeated fields",
+		input: func() proto.Message {
+			m := &pb2.Extensions{
+				OptString: scalar.String("non-extension field"),
+				OptBool:   scalar.Bool(true),
+				OptInt32:  scalar.Int32(42),
+			}
+			setExtension(m, pb2.E_OptExtBool, true)
+			setExtension(m, pb2.E_OptExtString, "extension field")
+			setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN)
+			setExtension(m, pb2.E_OptExtNested, &pb2.Nested{
+				OptString: scalar.String("nested in an extension"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("another nested in an extension"),
+				},
+			})
+			return m
+		}(),
+		want: `{
+  "optString": "non-extension field",
+  "optBool": true,
+  "optInt32": 42,
+  "[pb2.opt_ext_bool]": true,
+  "[pb2.opt_ext_enum]": "TEN",
+  "[pb2.opt_ext_nested]": {
+    "optString": "nested in an extension",
+    "optNested": {
+      "optString": "another nested in an extension"
+    }
+  },
+  "[pb2.opt_ext_string]": "extension field"
+}`,
+	}, {
+		desc: "extension message field set to nil",
+		input: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_OptExtNested, nil)
+			return m
+		}(),
+		want: "{}",
+	}, {
+		desc: "extensions of repeated fields",
+		input: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
+			setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+				&pb2.Nested{OptString: scalar.String("one")},
+				&pb2.Nested{OptString: scalar.String("two")},
+				&pb2.Nested{OptString: scalar.String("three")},
+			})
+			return m
+		}(),
+		want: `{
+  "[pb2.rpt_ext_enum]": [
+    "TEN",
+    101,
+    "ONE"
+  ],
+  "[pb2.rpt_ext_fixed32]": [
+    42,
+    47
+  ],
+  "[pb2.rpt_ext_nested]": [
+    {
+      "optString": "one"
+    },
+    {
+      "optString": "two"
+    },
+    {
+      "optString": "three"
+    }
+  ]
+}`,
+	}, {
+		desc: "extensions of non-repeated fields in another message",
+		input: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN)
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
+				OptString: scalar.String("nested in an extension"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("another nested in an extension"),
+				},
+			})
+			return m
+		}(),
+		want: `{
+  "[pb2.ExtensionsContainer.opt_ext_bool]": true,
+  "[pb2.ExtensionsContainer.opt_ext_enum]": "TEN",
+  "[pb2.ExtensionsContainer.opt_ext_nested]": {
+    "optString": "nested in an extension",
+    "optNested": {
+      "optString": "another nested in an extension"
+    }
+  },
+  "[pb2.ExtensionsContainer.opt_ext_string]": "extension field"
+}`,
+	}, {
+		desc: "extensions of repeated fields in another message",
+		input: func() proto.Message {
+			m := &pb2.Extensions{
+				OptString: scalar.String("non-extension field"),
+				OptBool:   scalar.Bool(true),
+				OptInt32:  scalar.Int32(42),
+			}
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE})
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+				&pb2.Nested{OptString: scalar.String("one")},
+				&pb2.Nested{OptString: scalar.String("two")},
+				&pb2.Nested{OptString: scalar.String("three")},
+			})
+			return m
+		}(),
+		want: `{
+  "optString": "non-extension field",
+  "optBool": true,
+  "optInt32": 42,
+  "[pb2.ExtensionsContainer.rpt_ext_enum]": [
+    "TEN",
+    101,
+    "ONE"
+  ],
+  "[pb2.ExtensionsContainer.rpt_ext_nested]": [
+    {
+      "optString": "one"
+    },
+    {
+      "optString": "two"
+    },
+    {
+      "optString": "three"
+    }
+  ],
+  "[pb2.ExtensionsContainer.rpt_ext_string]": [
+    "hello",
+    "world"
+  ]
+}`,
+	}, {
+		desc: "MessageSet",
+		input: func() proto.Message {
+			m := &pb2.MessageSet{}
+			setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
+				OptString: scalar.String("a messageset extension"),
+			})
+			setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
+				OptString: scalar.String("not a messageset extension"),
+			})
+			setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
+				OptString: scalar.String("just a regular extension"),
+			})
+			return m
+		}(),
+		want: `{
+  "[pb2.MessageSetExtension]": {
+    "optString": "a messageset extension"
+  },
+  "[pb2.MessageSetExtension.ext_nested]": {
+    "optString": "just a regular extension"
+  },
+  "[pb2.MessageSetExtension.not_message_set_extension]": {
+    "optString": "not a messageset extension"
+  }
+}`,
+	}, {
+		desc: "not real MessageSet 1",
+		input: func() proto.Message {
+			m := &pb2.FakeMessageSet{}
+			setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{
+				OptString: scalar.String("not a messageset extension"),
+			})
+			return m
+		}(),
+		want: `{
+  "[pb2.FakeMessageSetExtension.message_set_extension]": {
+    "optString": "not a messageset extension"
+  }
+}`,
+	}, {
+		desc: "not real MessageSet 2",
+		input: func() proto.Message {
+			m := &pb2.MessageSet{}
+			setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{
+				OptString: scalar.String("another not a messageset extension"),
+			})
+			return m
+		}(),
+		want: `{
+  "[pb2.message_set_extension]": {
+    "optString": "another not a messageset extension"
+  }
+}`,
 	}}
 
 	for _, tt := range tests {
 		tt := tt
 		t.Run(tt.desc, func(t *testing.T) {
-			t.Parallel()
 			b, err := tt.mo.Marshal(tt.input)
 			if err != nil {
 				t.Errorf("Marshal() returned error: %v\n", err)