encoding/textpb: marshal extensions

Change-Id: Ic4a0c5909fb6eca76d22053b143be58c60b67b34
Reviewed-on: https://go-review.googlesource.com/c/154657
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/encoding/textpb/encode.go b/encoding/textpb/encode.go
index 8d4b9cb..c0b04b6 100644
--- a/encoding/textpb/encode.go
+++ b/encoding/textpb/encode.go
@@ -79,49 +79,66 @@
 
 		tname := text.ValueOf(fd.Name())
 		pval := knownFields.Get(num)
-
-		if fd.Cardinality() == pref.Repeated {
-			// Map or repeated fields.
-			var items []text.Value
-			var err error
-			if fd.IsMap() {
-				items, err = o.marshalMap(pval.Map(), fd)
-				if !nerr.Merge(err) {
-					return text.Value{}, err
-				}
-			} else {
-				items, err = o.marshalList(pval.List(), fd)
-				if !nerr.Merge(err) {
-					return text.Value{}, err
-				}
-			}
-
-			// Add each item as key: value field.
-			for _, item := range items {
-				msgFields = append(msgFields, [2]text.Value{tname, item})
-			}
-		} else {
-			// Required or optional fields.
-			tval, err := o.marshalSingular(pval, fd)
-			if !nerr.Merge(err) {
-				return text.Value{}, err
-			}
-			msgFields = append(msgFields, [2]text.Value{tname, tval})
+		var err error
+		msgFields, err = o.appendField(msgFields, tname, pval, fd)
+		if !nerr.Merge(err) {
+			return text.Value{}, err
 		}
 	}
 
-	// Marshal out unknown fields.
+	// Handle extensions.
+	var err error
+	msgFields, err = o.appendExtensions(msgFields, knownFields)
+	if !nerr.Merge(err) {
+		return text.Value{}, err
+	}
+
+	// Handle unknown fields.
 	// TODO: Provide option to exclude or include unknown fields.
 	m.UnknownFields().Range(func(_ pref.FieldNumber, raw pref.RawFields) bool {
 		msgFields = appendUnknown(msgFields, raw)
 		return true
 	})
 
-	// TODO: Handle extensions and Any expansion.
-
 	return text.ValueOf(msgFields), nerr.E
 }
 
+// appendField converts pval, the value of field fd, to text and appends it to
+// msgFields under the key tname, returning the extended slice. Map and list
+// fields contribute one [tname, item] pair per element; singular fields one pair.
+// Non-fatal errors are accumulated and returned alongside the result; on a
+// fatal error msgFields is returned without this field's entries.
+func (o MarshalOptions) appendField(msgFields [][2]text.Value, tname text.Value, pval pref.Value, fd pref.FieldDescriptor) ([][2]text.Value, error) {
+	var nerr errors.NonFatal
+
+	// Required or optional fields: exactly one entry.
+	if fd.Cardinality() != pref.Repeated {
+		tval, err := o.marshalSingular(pval, fd)
+		if !nerr.Merge(err) {
+			return msgFields, err
+		}
+		return append(msgFields, [2]text.Value{tname, tval}), nerr.E
+	}
+
+	// Map or repeated fields: marshal every element first.
+	var items []text.Value
+	var err error
+	if fd.IsMap() {
+		items, err = o.marshalMap(pval.Map(), fd)
+	} else {
+		items, err = o.marshalList(pval.List(), fd)
+	}
+	if !nerr.Merge(err) {
+		return msgFields, err
+	}
+
+	// Emit each element as its own tname: item field.
+	for _, item := range items {
+		msgFields = append(msgFields, [2]text.Value{tname, item})
+	}
+	return msgFields, nerr.E
+}
+
 // marshalSingular converts a non-repeated field value to text.Value.
 // This includes all scalar types, enums, messages, and groups.
 func (o MarshalOptions) marshalSingular(val pref.Value, fd pref.FieldDescriptor) (text.Value, error) {
@@ -252,6 +269,41 @@
 	sort.Slice(values, less)
 }
 
+// appendExtensions marshals the populated extension fields, sorted by full name,
+// and appends them to msgFields. Non-fatal errors do not stop the iteration.
+func (o MarshalOptions) appendExtensions(msgFields [][2]text.Value, knownFields pref.KnownFields) ([][2]text.Value, error) {
+	var nerr errors.NonFatal
+	xtTypes := knownFields.ExtensionTypes()
+	xtFields := make([][2]text.Value, 0, xtTypes.Len())
+	var err error // only ever left holding a fatal error, which stops Range
+	xtTypes.Range(func(xt pref.ExtensionType) bool {
+		// TODO: Handle MessageSet.  Field name should be message_set_extension
+		// of message type without any fields and has message option
+		// message_set_wire_format=true.
+		num := xt.Number()
+		if knownFields.Has(num) {
+			// Use string type to produce [name] format.
+			tname := text.ValueOf(string(xt.FullName()))
+			xtFields, err = o.appendField(xtFields, tname, knownFields.Get(num), xt)
+			if !nerr.Merge(err) {
+				return false
+			}
+			// Non-fatal: already recorded in nerr, so keep marshaling the rest.
+			err = nil
+		}
+		return true
+	})
+	if err != nil {
+		return msgFields, err
+	}
+
+	// Sort extensions lexicographically and append to output.
+	sort.SliceStable(xtFields, func(i, j int) bool {
+		return xtFields[i][0].String() < xtFields[j][0].String()
+	})
+	return append(msgFields, xtFields...), nerr.E
+}
+
 // appendUnknown parses the given []byte and appends field(s) into the given fields slice.
 // This function assumes proper encoding in the given []byte.
 func appendUnknown(fields [][2]text.Value, b []byte) [][2]text.Value {