goprotobuf: Various improvements to extension handling.

R=r
CC=golang-dev
http://codereview.appspot.com/4917043
diff --git a/compiler/generator/generator.go b/compiler/generator/generator.go
index c572df3..760c5bf 100644
--- a/compiler/generator/generator.go
+++ b/compiler/generator/generator.go
@@ -271,7 +271,7 @@
 	if ms.hasExtensions {
 		g.P("func (*", ms.sym, ") ExtensionRangeArray() []", g.ProtoPkg, ".ExtensionRange ",
 			"{ return (*", remoteSym, ")(nil).ExtensionRangeArray() }")
-		g.P("func (this *", ms.sym, ") ExtensionMap() map[int32][]byte ",
+		g.P("func (this *", ms.sym, ") ExtensionMap() map[int32]", g.ProtoPkg, ".Extension ",
 			"{ return (*", remoteSym, ")(this).ExtensionMap() }")
 		if ms.isMessageSet {
 			g.P("func (this *", ms.sym, ") Marshal() ([]byte, os.Error) ",
@@ -1132,7 +1132,7 @@
 		g.RecordTypeUse(proto.GetString(field.TypeName))
 	}
 	if len(message.ExtensionRange) > 0 {
-		g.P("XXX_extensions\t\tmap[int32][]byte")
+		g.P("XXX_extensions\t\tmap[int32]", g.ProtoPkg, ".Extension")
 	}
 	g.P("XXX_unrecognized\t[]byte")
 	g.Out()
@@ -1170,7 +1170,7 @@
 		g.In()
 		for _, r := range message.ExtensionRange {
 			end := fmt.Sprint(*r.End - 1) // make range inclusive on both ends
-			g.P(g.ProtoPkg+".ExtensionRange{", r.Start, ", ", end, "},")
+			g.P("{", r.Start, ", ", end, "},")
 		}
 		g.Out()
 		g.P("}")
@@ -1179,11 +1179,11 @@
 		g.P("return extRange_", ccTypeName)
 		g.Out()
 		g.P("}")
-		g.P("func (this *", ccTypeName, ") ExtensionMap() map[int32][]byte {")
+		g.P("func (this *", ccTypeName, ") ExtensionMap() map[int32]", g.ProtoPkg, ".Extension {")
 		g.In()
 		g.P("if this.XXX_extensions == nil {")
 		g.In()
-		g.P("this.XXX_extensions = make(map[int32][]byte)")
+		g.P("this.XXX_extensions = make(map[int32]", g.ProtoPkg, ".Extension)")
 		g.Out()
 		g.P("}")
 		g.P("return this.XXX_extensions")
@@ -1264,7 +1264,7 @@
 	g.P("ExtendedType: (", extendedType, ")(nil),")
 	g.P("ExtensionType: (", fieldType, ")(nil),")
 	g.P("Field: ", field.Number, ",")
-	g.P(`Name: "`, g.packageName, ".", *field.Name, `",`)
+	g.P(`Name: "`, g.packageName, ".", strings.Join(ext.TypeName(), "."), `",`)
 	g.P("Tag: ", tag, ",")
 
 	g.Out()
diff --git a/compiler/testdata/test.pb.go.golden b/compiler/testdata/test.pb.go.golden
index 4d210b8..1295e07 100644
--- a/compiler/testdata/test.pb.go.golden
+++ b/compiler/testdata/test.pb.go.golden
@@ -145,7 +145,7 @@
 type Reply struct {
 	Found            []*Reply_Entry `protobuf:"bytes,1,rep,name=found" json:"found"`
 	CompactKeys      []int32        `protobuf:"varint,2,rep,packed,name=compact_keys" json:"compact_keys"`
-	XXX_extensions   map[int32][]byte
+	XXX_extensions   map[int32]proto.Extension
 	XXX_unrecognized []byte
 }
 
@@ -153,15 +153,15 @@
 func (this *Reply) String() string { return proto.CompactTextString(this) }
 
 var extRange_Reply = []proto.ExtensionRange{
-	proto.ExtensionRange{100, 536870911},
+	{100, 536870911},
 }
 
 func (*Reply) ExtensionRangeArray() []proto.ExtensionRange {
 	return extRange_Reply
 }
-func (this *Reply) ExtensionMap() map[int32][]byte {
+func (this *Reply) ExtensionMap() map[int32]proto.Extension {
 	if this.XXX_extensions == nil {
-		this.XXX_extensions = make(map[int32][]byte)
+		this.XXX_extensions = make(map[int32]proto.Extension)
 	}
 	return this.XXX_extensions
 }
@@ -189,12 +189,12 @@
 	ExtendedType:  (*Reply)(nil),
 	ExtensionType: (*float64)(nil),
 	Field:         101,
-	Name:          "my_test.time",
+	Name:          "my_test.ReplyExtensions.time",
 	Tag:           "fixed64,101,opt,name=time",
 }
 
 type OldReply struct {
-	XXX_extensions   map[int32][]byte
+	XXX_extensions   map[int32]proto.Extension
 	XXX_unrecognized []byte
 }
 
@@ -212,15 +212,15 @@
 var _ proto.Unmarshaler = (*OldReply)(nil)
 
 var extRange_OldReply = []proto.ExtensionRange{
-	proto.ExtensionRange{100, 536870911},
+	{100, 536870911},
 }
 
 func (*OldReply) ExtensionRangeArray() []proto.ExtensionRange {
 	return extRange_OldReply
 }
-func (this *OldReply) ExtensionMap() map[int32][]byte {
+func (this *OldReply) ExtensionMap() map[int32]proto.Extension {
 	if this.XXX_extensions == nil {
-		this.XXX_extensions = make(map[int32][]byte)
+		this.XXX_extensions = make(map[int32]proto.Extension)
 	}
 	return this.XXX_extensions
 }
diff --git a/proto/decode.go b/proto/decode.go
index 1c2bb3b..fe28821 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -358,7 +358,7 @@
 			iv := unsafe.Unreflect(t, unsafe.Pointer(&o.ptr))
 			if e, ok := iv.(extendableProto); ok && isExtensionField(e, int32(tag)) {
 				if err = o.skip(st, tag, wire); err == nil {
-					e.ExtensionMap()[int32(tag)] = append([]byte(nil), o.buf[oi:o.index]...)
+					e.ExtensionMap()[int32(tag)] = Extension{enc: append([]byte(nil), o.buf[oi:o.index]...)}
 				}
 				continue
 			}
diff --git a/proto/encode.go b/proto/encode.go
index 3952dc4..5736153 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -576,9 +576,12 @@
 
 // Encode an extension map.
 func (o *Buffer) enc_map(p *Properties, base uintptr) os.Error {
-	v := *(*map[int32][]byte)(unsafe.Pointer(base + p.offset))
-	for _, b := range v {
-		o.buf = append(o.buf, b...)
+	v := *(*map[int32]Extension)(unsafe.Pointer(base + p.offset))
+	if err := encodeExtensionMap(v); err != nil {
+		return err
+	}
+	for _, e := range v {
+		o.buf = append(o.buf, e.enc...)
 	}
 	return nil
 }
diff --git a/proto/extensions.go b/proto/extensions.go
index f44d741..e0ea0bc 100644
--- a/proto/extensions.go
+++ b/proto/extensions.go
@@ -52,7 +52,7 @@
 // extendableProto is an interface implemented by any protocol buffer that may be extended.
 type extendableProto interface {
 	ExtensionRangeArray() []ExtensionRange
-	ExtensionMap() map[int32][]byte
+	ExtensionMap() map[int32]Extension
 }
 
 // ExtensionDesc represents an extension specification.
@@ -65,6 +65,29 @@
 	Tag           string      // protobuf tag style
 }
 
+/*
+Extension represents an extension in a message.
+
+When an extension is stored in a message using SetExtension,
+only desc and value are set. When the message is marshaled,
+enc will be set to the encoded form of the extension.
+
+When a message is unmarshaled and contains extensions, each
+extension will have only enc set. When such an extension is
+accessed using GetExtension (or GetExtensions), desc and value
+will be set.
+*/
+type Extension struct {
+	desc  *ExtensionDesc // descriptor supplied to SetExtension or GetExtension
+	value interface{}    // decoded value; nil when only enc is populated
+	enc   []byte         // wire encoding, starting with the tag varint
+}
+
+// SetRawExtension stores b as the encoded form of extension id in base. It is for testing only.
+func SetRawExtension(base extendableProto, id int32, b []byte) {
+	base.ExtensionMap()[id] = Extension{enc: b} // b is stored as is, not copied
+}
+
 // isExtensionField returns true iff the given field number is in an extension range.
 func isExtensionField(pb extendableProto, field int32) bool {
 	for _, er := range pb.ExtensionRangeArray() {
@@ -88,6 +111,36 @@
 	return nil
 }
 
+// encodeExtensionMap sets enc for every extension in m that has both desc and value set.
+func encodeExtensionMap(m map[int32]Extension) os.Error {
+	for k, e := range m {
+		if e.value == nil || e.desc == nil {
+			// Nothing to encode from; any enc already present is left as is.
+			continue
+		}
+
+		// We don't skip extensions that have an encoded form set,
+		// because the extension value may have been mutated after
+		// the last time this function was called.
+
+		et := reflect.TypeOf(e.desc.ExtensionType)
+		props := new(Properties)
+		props.Init(et, "unknown_name", e.desc.Tag, 0)
+
+		p := NewBuffer(nil)
+		// The encoder must be passed a pointer to e.value.
+		// Allocate a copy of value so that we can use its address.
+		x := reflect.New(et)
+		x.Elem().Set(reflect.ValueOf(e.value))
+		if err := props.enc(p, props, x.Pointer()); err != nil {
+			return err
+		}
+		e.enc = p.buf
+		m[k] = e // e is a copy of the map entry; write it back
+	}
+	return nil
+}
+
 // HasExtension returns whether the given extension is present in pb.
 func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
 	// TODO: Check types, field numbers, etc.?
@@ -98,7 +151,7 @@
 // ClearExtension removes the given extension from pb.
 func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
 	// TODO: Check types, field numbers, etc.?
-	pb.ExtensionMap()[extension.Field] = nil, false
+	pb.ExtensionMap()[extension.Field] = Extension{}, false
 }
 
 // GetExtension parses and returns the given extension of pb.
@@ -108,14 +161,24 @@
 		return nil, err
 	}
 
-	b, ok := pb.ExtensionMap()[extension.Field]
+	e, ok := pb.ExtensionMap()[extension.Field]
 	if !ok {
 		return nil, nil // not an error
 	}
+	if e.value != nil {
+		// Already decoded. Check the descriptor, though.
+		if e.desc != extension {
+			// This shouldn't happen. If it does, it means that
+			// GetExtension was called twice with two different
+			// descriptors with the same field number.
+			return nil, os.NewError("proto: descriptor conflict")
+		}
+		return e.value, nil
+	}
 
 	// Discard wire type and field number varint. It isn't needed.
-	_, n := DecodeVarint(b)
-	o := NewBuffer(b[n:])
+	_, n := DecodeVarint(e.enc)
+	o := NewBuffer(e.enc[n:])
 
 	t := reflect.TypeOf(extension.ExtensionType)
 	props := &Properties{}
@@ -132,7 +195,15 @@
 	if err := props.dec(o, props, uintptr(base), sbase); err != nil {
 		return nil, err
 	}
-	return unsafe.Unreflect(t, base), nil
+	// Remember the decoded version and drop the encoded version.
+	// That way it is safe to mutate what we return.
+	e.value = unsafe.Unreflect(t, base)
+	e.desc = extension
+	e.enc = nil
+	// e is a copy of the map entry, so store it back; otherwise the
+	// cached value is lost and Marshal would still emit the old enc.
+	pb.ExtensionMap()[extension.Field] = e
+	return e.value, nil
 }
 
 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
@@ -167,18 +235,7 @@
 		return os.NewError("bad extension value type")
 	}
 
-	props := new(Properties)
-	props.Init(reflect.TypeOf(extension.ExtensionType), "unknown_name", extension.Tag, 0)
-
-	p := NewBuffer(nil)
-	// The encoder must be passed a pointer to value.
-	// Allocate a copy of value so that we can use its address.
-	x := reflect.New(typ)
-	x.Elem().Set(reflect.ValueOf(value))
-	if err := props.enc(p, props, x.Pointer()); err != nil {
-		return err
-	}
-	pb.ExtensionMap()[extension.Field] = p.buf
+	pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
 	return nil
 }
 
diff --git a/proto/message_set.go b/proto/message_set.go
index 036a8aa..05e18a8 100644
--- a/proto/message_set.go
+++ b/proto/message_set.go
@@ -38,6 +38,7 @@
 import (
 	"bytes"
 	"os"
+	"reflect"
 )
 
 // ErrNoMessageTypeId occurs when a protocol buffer does not have a message type ID.
@@ -145,16 +146,20 @@
 
 // MarshalMessageSet encodes the extension map represented by m in the message set wire format.
 // It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
-func MarshalMessageSet(m map[int32][]byte) ([]byte, os.Error) {
+func MarshalMessageSet(m map[int32]Extension) ([]byte, os.Error) {
+	if err := encodeExtensionMap(m); err != nil {
+		return nil, err
+	}
+
 	ms := &MessageSet{Item: make([]*_MessageSet_Item, len(m))}
 	i := 0
-	for k, v := range m {
+	for k, e := range m {
 		// Remove the wire type and field number varint, as well as the length varint.
-		v = skipVarint(skipVarint(v))
+		msg := skipVarint(skipVarint(e.enc))
 
 		ms.Item[i] = &_MessageSet_Item{
 			TypeId:  Int32(k),
-			Message: v,
+			Message: msg,
 		}
 		i++
 	}
@@ -163,7 +168,7 @@
 
 // UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
 // It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
-func UnmarshalMessageSet(buf []byte, m map[int32][]byte) os.Error {
+func UnmarshalMessageSet(buf []byte, m map[int32]Extension) os.Error {
 	ms := new(MessageSet)
 	if err := Unmarshal(buf, ms); err != nil {
 		return err
@@ -174,7 +179,24 @@
 		b = append(b, EncodeVarint(uint64(len(item.Message)))...)
 		b = append(b, item.Message...)
 
-		m[*item.TypeId] = b
+		m[*item.TypeId] = Extension{enc: b}
 	}
 	return nil
 }
+
+// A global registry of types that can be used in a MessageSet.
+
+var messageSetMap = make(map[int32]messageSetDesc) // keyed by message type ID
+
+type messageSetDesc struct {
+	t    reflect.Type // pointer to struct
+	name string       // printed as "[name]" by writeMessageSet
+}
+
+// RegisterMessageSetType registers i's type and name under its message type ID; it is called from the generated code.
+func RegisterMessageSetType(i messageTypeIder, name string) {
+	messageSetMap[i.MessageTypeId()] = messageSetDesc{
+		t:    reflect.TypeOf(i),
+		name: name,
+	}
+}
diff --git a/proto/text.go b/proto/text.go
index 54a89f7..e1a08a9 100644
--- a/proto/text.go
+++ b/proto/text.go
@@ -32,7 +32,6 @@
 package proto
 
 // Functions for writing the text protocol buffer format.
-// TODO: message sets.
 
 import (
 	"bytes"
@@ -41,6 +40,7 @@
 	"log"
 	"os"
 	"reflect"
+	"sort"
 	"strconv"
 	"strings"
 )
@@ -105,17 +105,33 @@
 	}
 }
 
-var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
+var (
+	messageSetType      = reflect.TypeOf((*MessageSet)(nil)).Elem()
+	extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
+)
 
 func writeStruct(w *textWriter, sv reflect.Value) {
+	if sv.Type() == messageSetType {
+		writeMessageSet(w, sv.Addr().Interface().(*MessageSet))
+		return
+	}
+
 	st := sv.Type()
 	sprops := GetProperties(st)
 	for i := 0; i < sv.NumField(); i++ {
-		if strings.HasPrefix(st.Field(i).Name, "XXX_") {
+		fv := sv.Field(i)
+		if name := st.Field(i).Name; strings.HasPrefix(name, "XXX_") {
+			// There are only two XXX_ fields:
+			//   XXX_unrecognized []byte
+			//   XXX_extensions   map[int32]proto.Extension
+			// The first is handled here;
+			// the second is handled at the bottom of this function.
+			if name == "XXX_unrecognized" && !fv.IsNil() {
+				writeUnknownStruct(w, fv.Interface().([]byte))
+			}
 			continue
 		}
 		props := sprops.Prop[i]
-		fv := sv.Field(i)
 		if fv.Kind() == reflect.Ptr && fv.IsNil() {
 			// Field not filled in. This could be an optional field or
 			// a required field that wasn't filled in. Either way, there
@@ -152,7 +168,7 @@
 		w.WriteByte('\n')
 	}
 
-	// Extensions.
+	// Extensions (the XXX_extensions field).
 	pv := sv.Addr()
 	if pv.Type().Implements(extendableProtoType) {
 		writeExtensions(w, pv)
@@ -211,19 +227,121 @@
 	}
 }
 
+func writeMessageSet(w *textWriter, ms *MessageSet) {
+	for _, item := range ms.Item {
+		id := *item.TypeId
+		if msd, ok := messageSetMap[id]; ok {
+			// Known message set type: decode the item and print its fields.
+			fmt.Fprintf(w, "[%s]: <\n", msd.name)
+			w.indent()
+
+			pb := reflect.New(msd.t.Elem()) // msd.t is a pointer type; allocate what it points to
+			if err := Unmarshal(item.Message, pb.Interface()); err != nil {
+				fmt.Fprintf(w, "/* bad message: %v */\n", err)
+			} else {
+				writeStruct(w, pb.Elem())
+			}
+		} else {
+			// Unknown type: print the numeric ID and the raw fields.
+			fmt.Fprintf(w, "[%d]: <\n", id)
+			w.indent()
+			writeUnknownStruct(w, item.Message)
+		}
+		w.unindent() // matches the indent() done in either branch
+		w.Write([]byte(">\n"))
+	}
+}
+
+func writeUnknownStruct(w *textWriter, data []byte) {
+	if !w.compact { // the byte-count comment is omitted in compact mode
+		fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data))
+	}
+	b := NewBuffer(data)
+	for b.index < len(b.buf) {
+		x, err := b.DecodeVarint()
+		if err != nil {
+			fmt.Fprintf(w, "/* %v */\n", err)
+			return
+		}
+		wire, tag := x&7, x>>3 // low three bits: wire type; remaining bits: field number
+		if wire == WireEndGroup { // closes the "{" written for WireStartGroup
+			w.unindent()
+			w.Write([]byte("}\n"))
+			continue
+		}
+		fmt.Fprintf(w, "tag%d", tag)
+		if wire != WireStartGroup { // groups are written as "tagN {", without a colon
+			w.WriteByte(':')
+		}
+		if !w.compact || wire == WireStartGroup {
+			w.WriteByte(' ')
+		}
+		switch wire {
+		case WireBytes:
+			buf, err := b.DecodeRawBytes(false)
+			if err == nil {
+				fmt.Fprintf(w, "%q", buf)
+			} else {
+				fmt.Fprintf(w, "/* %v */", err)
+			}
+		case WireFixed32:
+			x, err := b.DecodeFixed32()
+			writeUnknownInt(w, x, err)
+		case WireFixed64:
+			x, err := b.DecodeFixed64()
+			writeUnknownInt(w, x, err)
+		case WireStartGroup:
+			fmt.Fprint(w, "{")
+			w.indent()
+		case WireVarint:
+			x, err := b.DecodeVarint()
+			writeUnknownInt(w, x, err)
+		default:
+			fmt.Fprintf(w, "/* unknown wire type %d */", wire)
+		}
+		w.WriteByte('\n')
+	}
+}
+
+func writeUnknownInt(w *textWriter, x uint64, err os.Error) {
+	if err != nil {
+		fmt.Fprintf(w, "/* %v */", err)
+		return
+	}
+	fmt.Fprint(w, x)
+}
+
+type int32Slice []int32 // implements sort.Interface, ordering by increasing value
+
+func (s int32Slice) Len() int           { return len(s) }
+func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] }
+func (s int32Slice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
+
 // writeExtensions writes all the extensions in pv.
 // pv is assumed to be a pointer to a protocol message struct that is extendable.
 func writeExtensions(w *textWriter, pv reflect.Value) {
 	emap := extensionMaps[pv.Type().Elem()]
 	ep := pv.Interface().(extendableProto)
-	for extNum := range ep.ExtensionMap() {
+
+	// Order the extensions by ID.
+	// This isn't strictly necessary, but it will give us
+	// canonical output, which will also make testing easier.
+	m := ep.ExtensionMap()
+	ids := make([]int32, 0, len(m))
+	for id := range m {
+		ids = append(ids, id)
+	}
+	sort.Sort(int32Slice(ids))
+
+	for _, extNum := range ids {
+		ext := m[extNum]
 		var desc *ExtensionDesc
 		if emap != nil {
 			desc = emap[extNum]
 		}
 		if desc == nil {
-			// TODO: Handle printing unknown extensions.
-			fmt.Fprintln(os.Stderr, "proto: unknown extension: ", extNum)
+			// Unknown extension.
+			writeUnknownStruct(w, ext.enc)
 			continue
 		}
 
diff --git a/proto/text_test.go b/proto/text_test.go
index 6c3f546..00f17f8 100644
--- a/proto/text_test.go
+++ b/proto/text_test.go
@@ -33,6 +33,7 @@
 
 import (
 	"bytes"
+	"strings"
 	"testing"
 
 	"goprotobuf.googlecode.com/hg/proto"
@@ -68,6 +69,9 @@
 		Somegroup: &pb.MyMessage_SomeGroup{
 			GroupField: proto.Int32(8),
 		},
+		// One normally wouldn't do this.
+		// This is an undeclared tag 13, as a varint (wire type 0) with value 4.
+		XXX_unrecognized: []byte{13<<3 | 0, 4},
 	}
 	ext := &pb.Ext{
 		Data: proto.String("Big gobs for big rats"),
@@ -75,6 +79,19 @@
 	if err := proto.SetExtension(msg, pb.E_Ext_More, ext); err != nil {
 		panic(err)
 	}
+
+	// Add an unknown extension. We marshal a pb.Ext, and fake the ID.
+	b, err := proto.Marshal(&pb.Ext{Data: proto.String("3G skiing")})
+	if err != nil {
+		panic(err)
+	}
+	b = append(proto.EncodeVarint(104<<3|proto.WireBytes), b...)
+	proto.SetRawExtension(msg, 104, b)
+
+	// Extensions can be plain fields, too, so let's test that.
+	b = append(proto.EncodeVarint(105<<3|proto.WireVarint), 19)
+	proto.SetRawExtension(msg, 105, b)
+
 	return msg
 }
 
@@ -104,9 +121,15 @@
 SomeGroup {
   group_field: 8
 }
-[test_proto.more]: <
+/* 2 unknown bytes */
+tag13: 4
+[test_proto.Ext.more]: <
   data: "Big gobs for big rats"
 >
+/* 13 unknown bytes */
+tag104: "\t3G skiing"
+/* 3 unknown bytes */
+tag105: 19
 `
 
 func TestMarshalTextFull(t *testing.T) {
@@ -121,9 +144,22 @@
 func compact(src string) string {
 	// s/[ \n]+/ /g; s/ $//;
 	dst := make([]byte, len(src))
-	space := false
+	space, comment := false, false
 	j := 0
 	for i := 0; i < len(src); i++ {
+		if strings.HasPrefix(src[i:], "/*") {
+			comment = true
+			i++
+			continue
+		}
+		if comment && strings.HasPrefix(src[i:], "*/") {
+			comment = false
+			i++
+			continue
+		}
+		if comment {
+			continue
+		}
 		c := src[i]
 		if c == ' ' || c == '\n' {
 			space = true
@@ -158,3 +194,5 @@
 		t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v===\n", s, compactText)
 	}
 }
+
+