encoding/textpb: marshal extensions

Change-Id: Ic4a0c5909fb6eca76d22053b143be58c60b67b34
Reviewed-on: https://go-review.googlesource.com/c/154657
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/encoding/textpb/encode_test.go b/encoding/textpb/encode_test.go
index 55f3e77..8b7974e 100644
--- a/encoding/textpb/encode_test.go
+++ b/encoding/textpb/encode_test.go
@@ -9,9 +9,12 @@
 	"strings"
 	"testing"
 
+	"github.com/golang/protobuf/protoapi"
 	"github.com/golang/protobuf/v2/encoding/textpb"
 	"github.com/golang/protobuf/v2/internal/detrand"
 	"github.com/golang/protobuf/v2/internal/encoding/pack"
+	"github.com/golang/protobuf/v2/internal/encoding/wire"
+	"github.com/golang/protobuf/v2/internal/legacy"
 	"github.com/golang/protobuf/v2/internal/scalar"
 	"github.com/golang/protobuf/v2/proto"
 	"github.com/google/go-cmp/cmp"
@@ -47,6 +50,18 @@
 	return p
 }
 
+func setExtension(m proto.Message, xd *protoapi.ExtensionDesc, val interface{}) {
+	// Derive the extension type from the legacy descriptor and register it on m.
+	extType := legacy.Export{}.ExtensionTypeFromDesc(xd)
+	fields := m.ProtoReflect().KnownFields()
+	fields.ExtensionTypes().Register(extType)
+	// A nil val only registers the extension; the field itself is left unset.
+	if val != nil {
+		pval := extType.ValueOf(val)
+		fields.Set(wire.Number(xd.Field), pval)
+	}
+}
+
 func TestMarshal(t *testing.T) {
 	tests := []struct {
 		desc    string
@@ -794,6 +809,161 @@
 102: "hello"
 102: "世界"
 `,
+	}, {
+		desc: "extensions of non-repeated fields",
+		input: func() proto.Message {
+			m := &pb2.Extensions{
+				OptString: scalar.String("non-extension field"),
+				OptBool:   scalar.Bool(true),
+				OptInt32:  scalar.Int32(42),
+			}
+			setExtension(m, pb2.E_OptExtBool, true)
+			setExtension(m, pb2.E_OptExtString, "extension field")
+			setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TENTH)
+			setExtension(m, pb2.E_OptExtNested, &pb2.Nested{
+				OptString: scalar.String("nested in an extension"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("another nested in an extension"),
+				},
+			})
+			return m
+		}(),
+		want: `opt_string: "non-extension field"
+opt_bool: true
+opt_int32: 42
+[pb2.opt_ext_bool]: true
+[pb2.opt_ext_enum]: TENTH
+[pb2.opt_ext_nested]: {
+  opt_string: "nested in an extension"
+  opt_nested: {
+    opt_string: "another nested in an extension"
+  }
+}
+[pb2.opt_ext_string]: "extension field"
+`,
+	}, {
+		desc: "registered extension but not set",
+		input: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_OptExtNested, nil)
+			return m
+		}(),
+		want: "\n",
+	}, {
+		desc: "extensions of repeated fields",
+		input: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TENTH, 101, pb2.Enum_FIRST})
+			setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47})
+			setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{
+				&pb2.Nested{OptString: scalar.String("one")},
+				&pb2.Nested{OptString: scalar.String("two")},
+				&pb2.Nested{OptString: scalar.String("three")},
+			})
+			return m
+		}(),
+		want: `[pb2.rpt_ext_enum]: TENTH
+[pb2.rpt_ext_enum]: 101
+[pb2.rpt_ext_enum]: FIRST
+[pb2.rpt_ext_fixed32]: 42
+[pb2.rpt_ext_fixed32]: 47
+[pb2.rpt_ext_nested]: {
+  opt_string: "one"
+}
+[pb2.rpt_ext_nested]: {
+  opt_string: "two"
+}
+[pb2.rpt_ext_nested]: {
+  opt_string: "three"
+}
+`,
+	}, {
+		desc: "extensions of non-repeated fields in another message",
+		input: func() proto.Message {
+			m := &pb2.Extensions{}
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true)
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field")
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TENTH)
+			setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{
+				OptString: scalar.String("nested in an extension"),
+				OptNested: &pb2.Nested{
+					OptString: scalar.String("another nested in an extension"),
+				},
+			})
+			return m
+		}(),
+		want: `[pb2.ExtensionsContainer.opt_ext_bool]: true
+[pb2.ExtensionsContainer.opt_ext_enum]: TENTH
+[pb2.ExtensionsContainer.opt_ext_nested]: {
+  opt_string: "nested in an extension"
+  opt_nested: {
+    opt_string: "another nested in an extension"
+  }
+}
+[pb2.ExtensionsContainer.opt_ext_string]: "extension field"
+`,
+	}, {
+		desc: "extensions of repeated fields in another message",
+		input: func() proto.Message {
+			m := &pb2.Extensions{
+				OptString: scalar.String("non-extension field"),
+				OptBool:   scalar.Bool(true),
+				OptInt32:  scalar.Int32(42),
+			}
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TENTH, 101, pb2.Enum_FIRST})
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"})
+			setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{
+				&pb2.Nested{OptString: scalar.String("one")},
+				&pb2.Nested{OptString: scalar.String("two")},
+				&pb2.Nested{OptString: scalar.String("three")},
+			})
+			return m
+		}(),
+		want: `opt_string: "non-extension field"
+opt_bool: true
+opt_int32: 42
+[pb2.ExtensionsContainer.rpt_ext_enum]: TENTH
+[pb2.ExtensionsContainer.rpt_ext_enum]: 101
+[pb2.ExtensionsContainer.rpt_ext_enum]: FIRST
+[pb2.ExtensionsContainer.rpt_ext_nested]: {
+  opt_string: "one"
+}
+[pb2.ExtensionsContainer.rpt_ext_nested]: {
+  opt_string: "two"
+}
+[pb2.ExtensionsContainer.rpt_ext_nested]: {
+  opt_string: "three"
+}
+[pb2.ExtensionsContainer.rpt_ext_string]: "hello"
+[pb2.ExtensionsContainer.rpt_ext_string]: "world"
+`,
+		/* TODO: test for MessageSet
+		   	}, {
+		   		desc: "MessageSet",
+		   		input: func() proto.Message {
+		   			m := &pb2.MessageSet{}
+		   			setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{
+		   				OptString: scalar.String("a messageset extension"),
+		   			})
+		   			setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{
+		   				OptString: scalar.String("not a messageset extension"),
+		   			})
+		   			setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{
+		   				OptString: scalar.String("just a regular extension"),
+		   			})
+		   			return m
+		   		}(),
+		   		want: `[pb2.MessageSetExtension]: {
+		     opt_string: "a messageset extension"
+		   }
+		   [pb2.MessageSetExtension.ext_nested]: {
+		     opt_string: "just a regular extension"
+		   }
+		   [pb2.MessageSetExtension.not_message_set_extension]: {
+		     opt_string: "not a messageset extension"
+		   }
+		   `,
+		*/
 	}}
 
 	for _, tt := range tests {