encoding/jsonpb: add support for marshaling of extensions and messagesets

Change-Id: I839660146760a66c5cbf25d24f81f0ba5096d9e1
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/167395
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/encoding/jsonpb/encode.go b/encoding/jsonpb/encode.go
index 4152a74..3830f31 100644
--- a/encoding/jsonpb/encode.go
+++ b/encoding/jsonpb/encode.go
@@ -13,6 +13,8 @@
 	"github.com/golang/protobuf/v2/internal/pragma"
 	"github.com/golang/protobuf/v2/proto"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+
+	descpb "github.com/golang/protobuf/v2/types/descriptor"
 )
 
 // Marshal writes the given proto.Message in JSON format using default options.
@@ -70,6 +72,7 @@
 	fieldDescs := m.Type().Fields()
 	knownFields := m.KnownFields()
 
+	// Marshal out known fields.
 	for i := 0; i < fieldDescs.Len(); i++ {
 		fd := fieldDescs.Get(i)
 		num := fd.Number()
@@ -92,6 +95,11 @@
 			return err
 		}
 	}
+
+	// Marshal out extensions.
+	if err := e.marshalExtensions(knownFields); !nerr.Merge(err) {
+		return err
+	}
 	return nerr.E
 }
 
@@ -222,7 +230,6 @@
 		if err := e.WriteName(entry.key.String()); !nerr.Merge(err) {
 			return err
 		}
-
 		if err := e.marshalSingular(entry.value, valType); !nerr.Merge(err) {
 			return err
 		}
@@ -230,22 +237,94 @@
 	return nerr.E
 }
 
-// sortMap orders list based on value of key field for deterministic output.
+// sortMap orders list based on value of key field for deterministic ordering.
 func sortMap(keyKind pref.Kind, values []mapEntry) {
-	less := func(i, j int) bool {
-		return values[i].key.String() < values[j].key.String()
-	}
-	switch keyKind {
-	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind,
-		pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
-		less = func(i, j int) bool {
+	sort.Slice(values, func(i, j int) bool {
+		switch keyKind {
+		case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind,
+			pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
 			return values[i].key.Int() < values[j].key.Int()
-		}
-	case pref.Uint32Kind, pref.Fixed32Kind,
-		pref.Uint64Kind, pref.Fixed64Kind:
-		less = func(i, j int) bool {
+
+		case pref.Uint32Kind, pref.Fixed32Kind,
+			pref.Uint64Kind, pref.Fixed64Kind:
 			return values[i].key.Uint() < values[j].key.Uint()
 		}
+		return values[i].key.String() < values[j].key.String()
+	})
+}
+
+// marshalExtensions writes the populated extension fields of knownFields as "[name]" members, sorted by name.
+func (e encoder) marshalExtensions(knownFields pref.KnownFields) error {
+	type xtEntry struct {
+		key    string
+		value  pref.Value
+		xtType pref.ExtensionType
 	}
-	sort.Slice(values, less)
+
+	xtTypes := knownFields.ExtensionTypes()
+
+	// Collect the populated extensions keyed by name; the list is sorted below.
+	entries := make([]xtEntry, 0, xtTypes.Len())
+	xtTypes.Range(func(xt pref.ExtensionType) bool {
+		name := xt.FullName()
+		// For a MessageSet extension, use the extension's message type name as the key instead.
+		if isMessageSetExtension(xt) {
+			name = xt.MessageType().FullName()
+		}
+
+		num := xt.Number()
+		if knownFields.Has(num) {
+			// Store the bare name; the enclosing [] is added when the name is written.
+			pval := knownFields.Get(num)
+			entries = append(entries, xtEntry{
+				key:    string(name),
+				value:  pval,
+				xtType: xt,
+			})
+		}
+		return true
+	})
+
+	// Sort extensions lexicographically.
+	sort.Slice(entries, func(i, j int) bool {
+		return entries[i].key < entries[j].key
+	})
+
+	// Write out sorted list.
+	var nerr errors.NonFatal
+	for _, entry := range entries {
+		// JSON field name is the proto field name enclosed in [], similar to
+		// textproto. This is consistent with Go v1 lib. C++ lib v3.7.0 does not
+		// marshal out extension fields.
+		if err := e.WriteName("[" + entry.key + "]"); !nerr.Merge(err) {
+			return err
+		}
+		if err := e.marshalValue(entry.value, entry.xtType); !nerr.Merge(err) {
+			return err
+		}
+	}
+	return nerr.E
+}
+
+// isMessageSetExtension reports whether xt is a "message_set_extension" field, scoped inside its own message type, that extends a field-less message whose options report MessageSetWireFormat.
+func isMessageSetExtension(xt pref.ExtensionType) bool {
+	if xt.Name() != "message_set_extension" {
+		return false
+	}
+	mt := xt.MessageType()
+	if mt == nil {
+		return false
+	}
+	if xt.FullName().Parent() != mt.FullName() { // extension must be declared within its own message type
+		return false
+	}
+	xmt := xt.ExtendedType()
+	if xmt.Fields().Len() != 0 { // the extended message must declare no fields
+		return false
+	}
+	opt := xmt.Options().(*descpb.MessageOptions)
+	if opt == nil {
+		return false
+	}
+	return opt.GetMessageSetWireFormat()
+}