goprotobuf: Various improvements to extension handling.
R=r
CC=golang-dev
http://codereview.appspot.com/4917043
diff --git a/compiler/generator/generator.go b/compiler/generator/generator.go
index c572df3..760c5bf 100644
--- a/compiler/generator/generator.go
+++ b/compiler/generator/generator.go
@@ -271,7 +271,7 @@
if ms.hasExtensions {
g.P("func (*", ms.sym, ") ExtensionRangeArray() []", g.ProtoPkg, ".ExtensionRange ",
"{ return (*", remoteSym, ")(nil).ExtensionRangeArray() }")
- g.P("func (this *", ms.sym, ") ExtensionMap() map[int32][]byte ",
+ g.P("func (this *", ms.sym, ") ExtensionMap() map[int32]", g.ProtoPkg, ".Extension ",
"{ return (*", remoteSym, ")(this).ExtensionMap() }")
if ms.isMessageSet {
g.P("func (this *", ms.sym, ") Marshal() ([]byte, os.Error) ",
@@ -1132,7 +1132,7 @@
g.RecordTypeUse(proto.GetString(field.TypeName))
}
if len(message.ExtensionRange) > 0 {
- g.P("XXX_extensions\t\tmap[int32][]byte")
+ g.P("XXX_extensions\t\tmap[int32]", g.ProtoPkg, ".Extension")
}
g.P("XXX_unrecognized\t[]byte")
g.Out()
@@ -1170,7 +1170,7 @@
g.In()
for _, r := range message.ExtensionRange {
end := fmt.Sprint(*r.End - 1) // make range inclusive on both ends
- g.P(g.ProtoPkg+".ExtensionRange{", r.Start, ", ", end, "},")
+ g.P("{", r.Start, ", ", end, "},")
}
g.Out()
g.P("}")
@@ -1179,11 +1179,11 @@
g.P("return extRange_", ccTypeName)
g.Out()
g.P("}")
- g.P("func (this *", ccTypeName, ") ExtensionMap() map[int32][]byte {")
+ g.P("func (this *", ccTypeName, ") ExtensionMap() map[int32]", g.ProtoPkg, ".Extension {")
g.In()
g.P("if this.XXX_extensions == nil {")
g.In()
- g.P("this.XXX_extensions = make(map[int32][]byte)")
+ g.P("this.XXX_extensions = make(map[int32]", g.ProtoPkg, ".Extension)")
g.Out()
g.P("}")
g.P("return this.XXX_extensions")
@@ -1264,7 +1264,7 @@
g.P("ExtendedType: (", extendedType, ")(nil),")
g.P("ExtensionType: (", fieldType, ")(nil),")
g.P("Field: ", field.Number, ",")
- g.P(`Name: "`, g.packageName, ".", *field.Name, `",`)
+ g.P(`Name: "`, g.packageName, ".", strings.Join(ext.TypeName(), "."), `",`)
g.P("Tag: ", tag, ",")
g.Out()
diff --git a/compiler/testdata/test.pb.go.golden b/compiler/testdata/test.pb.go.golden
index 4d210b8..1295e07 100644
--- a/compiler/testdata/test.pb.go.golden
+++ b/compiler/testdata/test.pb.go.golden
@@ -145,7 +145,7 @@
type Reply struct {
Found []*Reply_Entry `protobuf:"bytes,1,rep,name=found" json:"found"`
CompactKeys []int32 `protobuf:"varint,2,rep,packed,name=compact_keys" json:"compact_keys"`
- XXX_extensions map[int32][]byte
+ XXX_extensions map[int32]proto.Extension
XXX_unrecognized []byte
}
@@ -153,15 +153,15 @@
func (this *Reply) String() string { return proto.CompactTextString(this) }
var extRange_Reply = []proto.ExtensionRange{
- proto.ExtensionRange{100, 536870911},
+ {100, 536870911},
}
func (*Reply) ExtensionRangeArray() []proto.ExtensionRange {
return extRange_Reply
}
-func (this *Reply) ExtensionMap() map[int32][]byte {
+func (this *Reply) ExtensionMap() map[int32]proto.Extension {
if this.XXX_extensions == nil {
- this.XXX_extensions = make(map[int32][]byte)
+ this.XXX_extensions = make(map[int32]proto.Extension)
}
return this.XXX_extensions
}
@@ -189,12 +189,12 @@
ExtendedType: (*Reply)(nil),
ExtensionType: (*float64)(nil),
Field: 101,
- Name: "my_test.time",
+ Name: "my_test.ReplyExtensions.time",
Tag: "fixed64,101,opt,name=time",
}
type OldReply struct {
- XXX_extensions map[int32][]byte
+ XXX_extensions map[int32]proto.Extension
XXX_unrecognized []byte
}
@@ -212,15 +212,15 @@
var _ proto.Unmarshaler = (*OldReply)(nil)
var extRange_OldReply = []proto.ExtensionRange{
- proto.ExtensionRange{100, 536870911},
+ {100, 536870911},
}
func (*OldReply) ExtensionRangeArray() []proto.ExtensionRange {
return extRange_OldReply
}
-func (this *OldReply) ExtensionMap() map[int32][]byte {
+func (this *OldReply) ExtensionMap() map[int32]proto.Extension {
if this.XXX_extensions == nil {
- this.XXX_extensions = make(map[int32][]byte)
+ this.XXX_extensions = make(map[int32]proto.Extension)
}
return this.XXX_extensions
}
diff --git a/proto/decode.go b/proto/decode.go
index 1c2bb3b..fe28821 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -358,7 +358,7 @@
iv := unsafe.Unreflect(t, unsafe.Pointer(&o.ptr))
if e, ok := iv.(extendableProto); ok && isExtensionField(e, int32(tag)) {
if err = o.skip(st, tag, wire); err == nil {
- e.ExtensionMap()[int32(tag)] = append([]byte(nil), o.buf[oi:o.index]...)
+ e.ExtensionMap()[int32(tag)] = Extension{enc: append([]byte(nil), o.buf[oi:o.index]...)}
}
continue
}
diff --git a/proto/encode.go b/proto/encode.go
index 3952dc4..5736153 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -576,9 +576,12 @@
// Encode an extension map.
func (o *Buffer) enc_map(p *Properties, base uintptr) os.Error {
- v := *(*map[int32][]byte)(unsafe.Pointer(base + p.offset))
- for _, b := range v {
- o.buf = append(o.buf, b...)
+ v := *(*map[int32]Extension)(unsafe.Pointer(base + p.offset))
+ if err := encodeExtensionMap(v); err != nil {
+ return err
+ }
+ for _, e := range v {
+ o.buf = append(o.buf, e.enc...)
}
return nil
}
diff --git a/proto/extensions.go b/proto/extensions.go
index f44d741..e0ea0bc 100644
--- a/proto/extensions.go
+++ b/proto/extensions.go
@@ -52,7 +52,7 @@
// extendableProto is an interface implemented by any protocol buffer that may be extended.
type extendableProto interface {
ExtensionRangeArray() []ExtensionRange
- ExtensionMap() map[int32][]byte
+ ExtensionMap() map[int32]Extension
}
// ExtensionDesc represents an extension specification.
@@ -65,6 +65,29 @@
Tag string // protobuf tag style
}
+/*
+Extension represents an extension in a message.
+
+When an extension is stored in a message using SetExtension
+only desc and value are set. When the message is marshaled
+enc will be set to the encoded form of the message.
+
+When a message is unmarshaled and contains extensions, each
+extension will have only enc set. When such an extension is
+accessed using GetExtension (or GetExtensions) desc and value
+will be set.
+*/
+type Extension struct {
+ desc *ExtensionDesc
+ value interface{}
+ enc []byte
+}
+
+// SetRawExtension is for testing only.
+func SetRawExtension(base extendableProto, id int32, b []byte) {
+ base.ExtensionMap()[id] = Extension{enc: b}
+}
+
// isExtensionField returns true iff the given field number is in an extension range.
func isExtensionField(pb extendableProto, field int32) bool {
for _, er := range pb.ExtensionRangeArray() {
@@ -88,6 +111,36 @@
return nil
}
+// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m.
+func encodeExtensionMap(m map[int32]Extension) os.Error {
+ for k, e := range m {
+ if e.value == nil || e.desc == nil {
+ // Extension is only in its encoded form.
+ continue
+ }
+
+ // We don't skip extensions that have an encoded form set,
+ // because the extension value may have been mutated after
+ // the last time this function was called.
+
+ et := reflect.TypeOf(e.desc.ExtensionType)
+ props := new(Properties)
+ props.Init(et, "unknown_name", e.desc.Tag, 0)
+
+ p := NewBuffer(nil)
+ // The encoder must be passed a pointer to e.value.
+ // Allocate a copy of value so that we can use its address.
+ x := reflect.New(et)
+ x.Elem().Set(reflect.ValueOf(e.value))
+ if err := props.enc(p, props, x.Pointer()); err != nil {
+ return err
+ }
+ e.enc = p.buf
+ m[k] = e
+ }
+ return nil
+}
+
// HasExtension returns whether the given extension is present in pb.
func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
// TODO: Check types, field numbers, etc.?
@@ -98,7 +151,7 @@
// ClearExtension removes the given extension from pb.
func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
// TODO: Check types, field numbers, etc.?
- pb.ExtensionMap()[extension.Field] = nil, false
+ pb.ExtensionMap()[extension.Field] = Extension{}, false
}
// GetExtension parses and returns the given extension of pb.
@@ -108,14 +161,24 @@
return nil, err
}
- b, ok := pb.ExtensionMap()[extension.Field]
+ e, ok := pb.ExtensionMap()[extension.Field]
if !ok {
return nil, nil // not an error
}
+ if e.value != nil {
+ // Already decoded. Check the descriptor, though.
+ if e.desc != extension {
+ // This shouldn't happen. If it does, it means that
+ // GetExtension was called twice with two different
+ // descriptors with the same field number.
+ return nil, os.NewError("proto: descriptor conflict")
+ }
+ return e.value, nil
+ }
// Discard wire type and field number varint. It isn't needed.
- _, n := DecodeVarint(b)
- o := NewBuffer(b[n:])
+ _, n := DecodeVarint(e.enc)
+ o := NewBuffer(e.enc[n:])
t := reflect.TypeOf(extension.ExtensionType)
props := &Properties{}
@@ -132,7 +195,12 @@
if err := props.dec(o, props, uintptr(base), sbase); err != nil {
return nil, err
}
- return unsafe.Unreflect(t, base), nil
+ // Remember the decoded version and drop the encoded version.
+ // That way it is safe to mutate what we return.
+ e.value = unsafe.Unreflect(t, base)
+ e.desc = extension
+ e.enc = nil
+ return e.value, nil
}
// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
@@ -167,18 +235,7 @@
return os.NewError("bad extension value type")
}
- props := new(Properties)
- props.Init(reflect.TypeOf(extension.ExtensionType), "unknown_name", extension.Tag, 0)
-
- p := NewBuffer(nil)
- // The encoder must be passed a pointer to value.
- // Allocate a copy of value so that we can use its address.
- x := reflect.New(typ)
- x.Elem().Set(reflect.ValueOf(value))
- if err := props.enc(p, props, x.Pointer()); err != nil {
- return err
- }
- pb.ExtensionMap()[extension.Field] = p.buf
+ pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
return nil
}
diff --git a/proto/message_set.go b/proto/message_set.go
index 036a8aa..05e18a8 100644
--- a/proto/message_set.go
+++ b/proto/message_set.go
@@ -38,6 +38,7 @@
import (
"bytes"
"os"
+ "reflect"
)
// ErrNoMessageTypeId occurs when a protocol buffer does not have a message type ID.
@@ -145,16 +146,20 @@
// MarshalMessageSet encodes the extension map represented by m in the message set wire format.
// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
-func MarshalMessageSet(m map[int32][]byte) ([]byte, os.Error) {
+func MarshalMessageSet(m map[int32]Extension) ([]byte, os.Error) {
+ if err := encodeExtensionMap(m); err != nil {
+ return nil, err
+ }
+
ms := &MessageSet{Item: make([]*_MessageSet_Item, len(m))}
i := 0
- for k, v := range m {
+ for k, e := range m {
// Remove the wire type and field number varint, as well as the length varint.
- v = skipVarint(skipVarint(v))
+ msg := skipVarint(skipVarint(e.enc))
ms.Item[i] = &_MessageSet_Item{
TypeId: Int32(k),
- Message: v,
+ Message: msg,
}
i++
}
@@ -163,7 +168,7 @@
// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
-func UnmarshalMessageSet(buf []byte, m map[int32][]byte) os.Error {
+func UnmarshalMessageSet(buf []byte, m map[int32]Extension) os.Error {
ms := new(MessageSet)
if err := Unmarshal(buf, ms); err != nil {
return err
@@ -174,7 +179,24 @@
b = append(b, EncodeVarint(uint64(len(item.Message)))...)
b = append(b, item.Message...)
- m[*item.TypeId] = b
+ m[*item.TypeId] = Extension{enc: b}
}
return nil
}
+
+// A global registry of types that can be used in a MessageSet.
+
+var messageSetMap = make(map[int32]messageSetDesc)
+
+type messageSetDesc struct {
+ t reflect.Type // pointer to struct
+ name string
+}
+
+// RegisterMessageSetType is called from the generated code.
+func RegisterMessageSetType(i messageTypeIder, name string) {
+ messageSetMap[i.MessageTypeId()] = messageSetDesc{
+ t: reflect.TypeOf(i),
+ name: name,
+ }
+}
diff --git a/proto/text.go b/proto/text.go
index 54a89f7..e1a08a9 100644
--- a/proto/text.go
+++ b/proto/text.go
@@ -32,7 +32,6 @@
package proto
// Functions for writing the text protocol buffer format.
-// TODO: message sets.
import (
"bytes"
@@ -41,6 +40,7 @@
"log"
"os"
"reflect"
+ "sort"
"strconv"
"strings"
)
@@ -105,17 +105,33 @@
}
}
-var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
+var (
+ messageSetType = reflect.TypeOf((*MessageSet)(nil)).Elem()
+ extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
+)
func writeStruct(w *textWriter, sv reflect.Value) {
+ if sv.Type() == messageSetType {
+ writeMessageSet(w, sv.Addr().Interface().(*MessageSet))
+ return
+ }
+
st := sv.Type()
sprops := GetProperties(st)
for i := 0; i < sv.NumField(); i++ {
- if strings.HasPrefix(st.Field(i).Name, "XXX_") {
+ fv := sv.Field(i)
+ if name := st.Field(i).Name; strings.HasPrefix(name, "XXX_") {
+ // There's only two XXX_ fields:
+ // XXX_unrecognized []byte
+ // XXX_extensions map[int32]proto.Extension
+ // The first is handled here;
+ // the second is handled at the bottom of this function.
+ if name == "XXX_unrecognized" && !fv.IsNil() {
+ writeUnknownStruct(w, fv.Interface().([]byte))
+ }
continue
}
props := sprops.Prop[i]
- fv := sv.Field(i)
if fv.Kind() == reflect.Ptr && fv.IsNil() {
// Field not filled in. This could be an optional field or
// a required field that wasn't filled in. Either way, there
@@ -152,7 +168,7 @@
w.WriteByte('\n')
}
- // Extensions.
+ // Extensions (the XXX_extensions field).
pv := sv.Addr()
if pv.Type().Implements(extendableProtoType) {
writeExtensions(w, pv)
@@ -211,19 +227,121 @@
}
}
+func writeMessageSet(w *textWriter, ms *MessageSet) {
+ for _, item := range ms.Item {
+ id := *item.TypeId
+ if msd, ok := messageSetMap[id]; ok {
+ // Known message set type.
+ fmt.Fprintf(w, "[%s]: <\n", msd.name)
+ w.indent()
+
+ pb := reflect.New(msd.t.Elem())
+ if err := Unmarshal(item.Message, pb.Interface()); err != nil {
+ fmt.Fprintf(w, "/* bad message: %v */\n", err)
+ } else {
+ writeStruct(w, pb.Elem())
+ }
+ } else {
+ // Unknown type.
+ fmt.Fprintf(w, "[%d]: <\n", id)
+ w.indent()
+ writeUnknownStruct(w, item.Message)
+ }
+ w.unindent()
+ w.Write([]byte(">\n"))
+ }
+}
+
+func writeUnknownStruct(w *textWriter, data []byte) {
+ if !w.compact {
+ fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data))
+ }
+ b := NewBuffer(data)
+ for b.index < len(b.buf) {
+ x, err := b.DecodeVarint()
+ if err != nil {
+ fmt.Fprintf(w, "/* %v */\n", err)
+ return
+ }
+ wire, tag := x&7, x>>3
+ if wire == WireEndGroup {
+ w.unindent()
+ w.Write([]byte("}\n"))
+ continue
+ }
+ fmt.Fprintf(w, "tag%d", tag)
+ if wire != WireStartGroup {
+ w.WriteByte(':')
+ }
+ if !w.compact || wire == WireStartGroup {
+ w.WriteByte(' ')
+ }
+ switch wire {
+ case WireBytes:
+ buf, err := b.DecodeRawBytes(false)
+ if err == nil {
+ fmt.Fprintf(w, "%q", buf)
+ } else {
+ fmt.Fprintf(w, "/* %v */", err)
+ }
+ case WireFixed32:
+ x, err := b.DecodeFixed32()
+ writeUnknownInt(w, x, err)
+ case WireFixed64:
+ x, err := b.DecodeFixed64()
+ writeUnknownInt(w, x, err)
+ case WireStartGroup:
+ fmt.Fprint(w, "{")
+ w.indent()
+ case WireVarint:
+ x, err := b.DecodeVarint()
+ writeUnknownInt(w, x, err)
+ default:
+ fmt.Fprintf(w, "/* unknown wire type %d */", wire)
+ }
+ w.WriteByte('\n')
+ }
+}
+
+func writeUnknownInt(w *textWriter, x uint64, err os.Error) {
+ if err == nil {
+ fmt.Fprint(w, x)
+ } else {
+ fmt.Fprintf(w, "/* %v */", err)
+ }
+}
+
+type int32Slice []int32
+
+func (s int32Slice) Len() int { return len(s) }
+func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] }
+func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
// writeExtensions writes all the extensions in pv.
// pv is assumed to be a pointer to a protocol message struct that is extendable.
func writeExtensions(w *textWriter, pv reflect.Value) {
emap := extensionMaps[pv.Type().Elem()]
ep := pv.Interface().(extendableProto)
- for extNum := range ep.ExtensionMap() {
+
+ // Order the extensions by ID.
+ // This isn't strictly necessary, but it will give us
+ // canonical output, which will also make testing easier.
+ m := ep.ExtensionMap()
+ ids := make([]int32, 0, len(m))
+ for id := range m {
+ ids = append(ids, id)
+ }
+ sort.Sort(int32Slice(ids))
+
+ for _, extNum := range ids {
+ ext := m[extNum]
var desc *ExtensionDesc
if emap != nil {
desc = emap[extNum]
}
if desc == nil {
- // TODO: Handle printing unknown extensions.
- fmt.Fprintln(os.Stderr, "proto: unknown extension: ", extNum)
+ // Unknown extension.
+ writeUnknownStruct(w, ext.enc)
continue
}
diff --git a/proto/text_test.go b/proto/text_test.go
index 6c3f546..00f17f8 100644
--- a/proto/text_test.go
+++ b/proto/text_test.go
@@ -33,6 +33,7 @@
import (
"bytes"
+ "strings"
"testing"
"goprotobuf.googlecode.com/hg/proto"
@@ -68,6 +69,9 @@
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: proto.Int32(8),
},
+ // One normally wouldn't do this.
+ // This is an undeclared tag 13, as a varint (wire type 0) with value 4.
+ XXX_unrecognized: []byte{13<<3 | 0, 4},
}
ext := &pb.Ext{
Data: proto.String("Big gobs for big rats"),
@@ -75,6 +79,19 @@
if err := proto.SetExtension(msg, pb.E_Ext_More, ext); err != nil {
panic(err)
}
+
+ // Add an unknown extension. We marshal a pb.Ext, and fake the ID.
+ b, err := proto.Marshal(&pb.Ext{Data: proto.String("3G skiing")})
+ if err != nil {
+ panic(err)
+ }
+ b = append(proto.EncodeVarint(104<<3|proto.WireBytes), b...)
+ proto.SetRawExtension(msg, 104, b)
+
+ // Extensions can be plain fields, too, so let's test that.
+ b = append(proto.EncodeVarint(105<<3|proto.WireVarint), 19)
+ proto.SetRawExtension(msg, 105, b)
+
return msg
}
@@ -104,9 +121,15 @@
SomeGroup {
group_field: 8
}
-[test_proto.more]: <
+/* 2 unknown bytes */
+tag13: 4
+[test_proto.Ext.more]: <
data: "Big gobs for big rats"
>
+/* 13 unknown bytes */
+tag104: "\t3G skiing"
+/* 3 unknown bytes */
+tag105: 19
`
func TestMarshalTextFull(t *testing.T) {
@@ -121,9 +144,22 @@
func compact(src string) string {
// s/[ \n]+/ /g; s/ $//;
dst := make([]byte, len(src))
- space := false
+ space, comment := false, false
j := 0
for i := 0; i < len(src); i++ {
+ if strings.HasPrefix(src[i:], "/*") {
+ comment = true
+ i++
+ continue
+ }
+ if comment && strings.HasPrefix(src[i:], "*/") {
+ comment = false
+ i++
+ continue
+ }
+ if comment {
+ continue
+ }
c := src[i]
if c == ' ' || c == '\n' {
space = true
@@ -158,3 +194,5 @@
t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v===\n", s, compactText)
}
}
+
+