goprotobuf: Various improvements to extension handling.
R=r
CC=golang-dev
http://codereview.appspot.com/4917043
diff --git a/proto/extensions.go b/proto/extensions.go
index f44d741..e0ea0bc 100644
--- a/proto/extensions.go
+++ b/proto/extensions.go
@@ -52,7 +52,7 @@
// extendableProto is an interface implemented by any protocol buffer that may be extended.
type extendableProto interface {
ExtensionRangeArray() []ExtensionRange
- ExtensionMap() map[int32][]byte
+ ExtensionMap() map[int32]Extension
}
// ExtensionDesc represents an extension specification.
@@ -65,6 +65,29 @@
Tag string // protobuf tag style
}
+/*
+Extension represents an extension in a message.
+
+When an extension is stored in a message using SetExtension,
+only desc and value are set. When the message is marshaled,
+enc will be set to the encoded form of the extension.
+
+When a message is unmarshaled and contains extensions, each
+extension will have only enc set. When such an extension is
+accessed using GetExtension (or GetExtensions) desc and value
+will be set.
+*/
+type Extension struct {
+ desc *ExtensionDesc
+ value interface{}
+ enc []byte
+}
+
+// SetRawExtension is for testing only.
+func SetRawExtension(base extendableProto, id int32, b []byte) {
+ base.ExtensionMap()[id] = Extension{enc: b}
+}
+
// isExtensionField returns true iff the given field number is in an extension range.
func isExtensionField(pb extendableProto, field int32) bool {
for _, er := range pb.ExtensionRangeArray() {
@@ -88,6 +111,36 @@
return nil
}
+// encodeExtensionMap encodes every extension in m that has a decoded value (desc and value set), storing the result in its enc field.
+func encodeExtensionMap(m map[int32]Extension) os.Error {
+ for k, e := range m {
+ if e.value == nil || e.desc == nil {
+ // Extension is only in its encoded form.
+ continue
+ }
+
+ // We don't skip extensions that have an encoded form set,
+ // because the extension value may have been mutated after
+ // the last time this function was called.
+
+ et := reflect.TypeOf(e.desc.ExtensionType)
+ props := new(Properties)
+ props.Init(et, "unknown_name", e.desc.Tag, 0)
+
+ p := NewBuffer(nil)
+ // The encoder must be passed a pointer to e.value.
+ // Allocate a copy of value so that we can use its address.
+ x := reflect.New(et)
+ x.Elem().Set(reflect.ValueOf(e.value))
+ if err := props.enc(p, props, x.Pointer()); err != nil {
+ return err
+ }
+ e.enc = p.buf
+ m[k] = e
+ }
+ return nil
+}
+
// HasExtension returns whether the given extension is present in pb.
func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
// TODO: Check types, field numbers, etc.?
@@ -98,7 +151,7 @@
// ClearExtension removes the given extension from pb.
func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
// TODO: Check types, field numbers, etc.?
- pb.ExtensionMap()[extension.Field] = nil, false
+ pb.ExtensionMap()[extension.Field] = Extension{}, false
}
// GetExtension parses and returns the given extension of pb.
@@ -108,14 +161,24 @@
return nil, err
}
- b, ok := pb.ExtensionMap()[extension.Field]
+ e, ok := pb.ExtensionMap()[extension.Field]
if !ok {
return nil, nil // not an error
}
+ if e.value != nil {
+ // Already decoded. Check the descriptor, though.
+ if e.desc != extension {
+ // This shouldn't happen. If it does, it means that
+ // GetExtension was called twice with two different
+ // descriptors with the same field number.
+ return nil, os.NewError("proto: descriptor conflict")
+ }
+ return e.value, nil
+ }
// Discard wire type and field number varint. It isn't needed.
- _, n := DecodeVarint(b)
- o := NewBuffer(b[n:])
+ _, n := DecodeVarint(e.enc)
+ o := NewBuffer(e.enc[n:])
t := reflect.TypeOf(extension.ExtensionType)
props := &Properties{}
@@ -132,7 +195,12 @@
if err := props.dec(o, props, uintptr(base), sbase); err != nil {
return nil, err
}
- return unsafe.Unreflect(t, base), nil
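+ // Remember the decoded version and drop the encoded version, so that it is safe to mutate what we return.
+ // NOTE(review): e is a copy of the map entry; nothing visible here writes it back to pb.ExtensionMap() — confirm.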
+ e.value = unsafe.Unreflect(t, base)
+ e.desc = extension
+ e.enc = nil
+ return e.value, nil
}
// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
@@ -167,18 +235,7 @@
return os.NewError("bad extension value type")
}
- props := new(Properties)
- props.Init(reflect.TypeOf(extension.ExtensionType), "unknown_name", extension.Tag, 0)
-
- p := NewBuffer(nil)
- // The encoder must be passed a pointer to value.
- // Allocate a copy of value so that we can use its address.
- x := reflect.New(typ)
- x.Elem().Set(reflect.ValueOf(value))
- if err := props.enc(p, props, x.Pointer()); err != nil {
- return err
- }
- pb.ExtensionMap()[extension.Field] = p.buf
+ pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
return nil
}