goprotobuf: Keep extensions decoded in memory and encode them lazily when marshaling.

R=r
CC=golang-dev
http://codereview.appspot.com/4917043
diff --git a/proto/extensions.go b/proto/extensions.go
index f44d741..e0ea0bc 100644
--- a/proto/extensions.go
+++ b/proto/extensions.go
@@ -52,7 +52,7 @@
 // extendableProto is an interface implemented by any protocol buffer that may be extended.
 type extendableProto interface {
 	ExtensionRangeArray() []ExtensionRange
-	ExtensionMap() map[int32][]byte
+	ExtensionMap() map[int32]Extension // keyed by extension field number
 }
 
 // ExtensionDesc represents an extension specification.
@@ -65,6 +65,29 @@
 	Tag           string      // protobuf tag style
 }
 
+/*
+Extension represents an extension in a message.
+
+When an extension is stored in a message using SetExtension,
+only desc and value are set. When the message is marshaled,
+enc will be set to the encoded form of the extension (tag included).
+
+When a message is unmarshaled and contains extensions, each
+extension will have only enc set. When such an extension is
+accessed using GetExtension (or GetExtensions), desc and value
+will be set.
+*/
+type Extension struct {
+	desc  *ExtensionDesc
+	value interface{}
+	enc   []byte
+}
+
+// SetRawExtension is for testing only. It stores b as the encoded form of extension id in base.
+func SetRawExtension(base extendableProto, id int32, b []byte) {
+	base.ExtensionMap()[id] = Extension{enc: b}
+}
+
 // isExtensionField returns true iff the given field number is in an extension range.
 func isExtensionField(pb extendableProto, field int32) bool {
 	for _, er := range pb.ExtensionRangeArray() {
@@ -88,6 +111,36 @@
 	return nil
 }
 
+// encodeExtensionMap fills in enc for every extension in m that has a decoded value and descriptor.
+func encodeExtensionMap(m map[int32]Extension) os.Error {
+	for k, e := range m {
+		if e.value == nil || e.desc == nil {
+			// Extension is only in its encoded form; enc is already up to date.
+			continue
+		}
+
+		// We don't skip extensions that have an encoded form set,
+		// because the extension value may have been mutated after
+		// the last time this function was called.
+
+		et := reflect.TypeOf(e.desc.ExtensionType)
+		props := new(Properties)
+		props.Init(et, "unknown_name", e.desc.Tag, 0)
+
+		p := NewBuffer(nil)
+		// The encoder must be passed a pointer to e.value.
+		// Allocate a copy of e.value so that we can take its address.
+		x := reflect.New(et)
+		x.Elem().Set(reflect.ValueOf(e.value))
+		if err := props.enc(p, props, x.Pointer()); err != nil {
+			return err
+		}
+		e.enc = p.buf
+		m[k] = e // e is a copy of the map entry; store the update back.
+	}
+	return nil
+}
+
 // HasExtension returns whether the given extension is present in pb.
 func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
 	// TODO: Check types, field numbers, etc.?
@@ -98,7 +151,7 @@
 // ClearExtension removes the given extension from pb.
 func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
 	// TODO: Check types, field numbers, etc.?
-	pb.ExtensionMap()[extension.Field] = nil, false
+	pb.ExtensionMap()[extension.Field] = Extension{}, false // deletes the entry; the value is ignored
 }
 
 // GetExtension parses and returns the given extension of pb.
@@ -108,14 +161,24 @@
 		return nil, err
 	}
 
-	b, ok := pb.ExtensionMap()[extension.Field]
+	e, ok := pb.ExtensionMap()[extension.Field]
 	if !ok {
 		return nil, nil // not an error
 	}
+	if e.value != nil {
+		// Already decoded (e.g. stored by SetExtension). Check the descriptor, though.
+		if e.desc != extension {
+			// This shouldn't happen. If it does, it means that two
+			// different descriptors with the same field number were
+			// used to set or get this extension.
+			return nil, os.NewError("proto: descriptor conflict")
+		}
+		return e.value, nil
+	}
 
 	// Discard wire type and field number varint. It isn't needed.
-	_, n := DecodeVarint(b)
-	o := NewBuffer(b[n:])
+	_, n := DecodeVarint(e.enc)
+	o := NewBuffer(e.enc[n:])
 
 	t := reflect.TypeOf(extension.ExtensionType)
 	props := &Properties{}
@@ -132,7 +195,14 @@
 	if err := props.dec(o, props, uintptr(base), sbase); err != nil {
 		return nil, err
 	}
-	return unsafe.Unreflect(t, base), nil
+	// Remember the decoded version and drop the encoded version.
+	// That way it is safe to mutate what we return.
+	// e is a copy of the map entry, so it must be written back.
+	e.value = unsafe.Unreflect(t, base)
+	e.desc = extension
+	e.enc = nil
+	pb.ExtensionMap()[extension.Field] = e
+	return e.value, nil
 }
 
 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
@@ -167,18 +235,7 @@
 		return os.NewError("bad extension value type")
 	}
 
-	props := new(Properties)
-	props.Init(reflect.TypeOf(extension.ExtensionType), "unknown_name", extension.Tag, 0)
-
-	p := NewBuffer(nil)
-	// The encoder must be passed a pointer to value.
-	// Allocate a copy of value so that we can use its address.
-	x := reflect.New(typ)
-	x.Elem().Set(reflect.ValueOf(value))
-	if err := props.enc(p, props, x.Pointer()); err != nil {
-		return err
-	}
-	pb.ExtensionMap()[extension.Field] = p.buf
+	pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
 	return nil
 }