Big sync from internal version:
- include MessageSet support code
- support message_set_wire_format for extensions
- use append throughout encode.go
R=r
CC=golang-dev
http://codereview.appspot.com/3023041
diff --git a/compiler/generator/generator.go b/compiler/generator/generator.go
index e53580a..fcf6a4f 100644
--- a/compiler/generator/generator.go
+++ b/compiler/generator/generator.go
@@ -630,8 +630,9 @@
func (g *Generator) generateImports() {
// We almost always need a proto import. Rather than computing when we
// do, which is tricky when there's a plugin, just import it and
- // reference it later.
+ // reference it later. The same argument applies to the os package.
g.P("import " + g.ProtoPkg + " " + Quote(g.ImportPrefix+"goprotobuf.googlecode.com/hg/proto"))
+ g.P(`import "os"`)
for _, s := range g.file.Dependency {
// Need to find the descriptor for this file
for _, fd := range g.allFiles {
@@ -663,8 +664,9 @@
p.GenerateImports(g.file)
g.P()
}
- g.P("// Reference proto import to suppress error if it's not otherwise used.")
+ g.P("// Reference proto & os imports to suppress error if it's not otherwise used.")
g.P("var _ = ", g.ProtoPkg, ".GetString")
+ g.P("var _ os.Error")
g.P()
}
@@ -898,6 +900,24 @@
// Extension support methods
if len(message.ExtensionRange) > 0 {
+ // message_set_wire_format only makes sense when extensions are defined.
+ if opts := message.Options; opts != nil && proto.GetBool(opts.MessageSetWireFormat) {
+ g.P()
+ g.P("func (this *", ccTypeName, ") Marshal() ([]byte, os.Error) {")
+ g.In()
+ g.P("return ", g.ProtoPkg, ".MarshalMessageSet(this.ExtensionMap())")
+ g.Out()
+ g.P("}")
+ g.P("func (this *", ccTypeName, ") Unmarshal(buf []byte) os.Error {")
+ g.In()
+ g.P("return ", g.ProtoPkg, ".UnmarshalMessageSet(buf, this.ExtensionMap())")
+ g.Out()
+ g.P("}")
+ g.P("// ensure ", ccTypeName, " satisfies proto.Marshaler and proto.Unmarshaler")
+ g.P("var _ ", g.ProtoPkg, ".Marshaler = (*", ccTypeName, ")(nil)")
+ g.P("var _ ", g.ProtoPkg, ".Unmarshaler = (*", ccTypeName, ")(nil)")
+ }
+
g.P()
g.P("var extRange_", ccTypeName, " = []", g.ProtoPkg, ".ExtensionRange{")
g.In()
diff --git a/compiler/testdata/extension_base.proto b/compiler/testdata/extension_base.proto
index 5f91628..bb35b74 100644
--- a/compiler/testdata/extension_base.proto
+++ b/compiler/testdata/extension_base.proto
@@ -36,3 +36,9 @@
extensions 4 to 9;
extensions 16 to max;
}
+
+ // Another extendable message; it sets message_set_wire_format and accepts extensions numbered 100 and up.
+ message OldStyleMessage {
+ option message_set_wire_format = true;
+ extensions 100 to max;
+ }
diff --git a/compiler/testdata/extension_test.go b/compiler/testdata/extension_test.go
index f0831fc..eb0f853 100644
--- a/compiler/testdata/extension_test.go
+++ b/compiler/testdata/extension_test.go
@@ -34,6 +34,7 @@
package main
import (
+ "bytes"
"regexp"
"testing"
@@ -152,6 +153,47 @@
}
}
+ func TestMessageSetWireFormat(t *testing.T) { // Round-trips an OldStyleMessage carrying one extension through proto.Marshal and proto.Unmarshal.
+ osm := new(base.OldStyleMessage)
+ osp := &user.OldStyleParcel{
+ Name: proto.String("Dave"),
+ Height: proto.Int32(178),
+ }
+ // Attach the parcel to the message as its message-set extension.
+ err := proto.SetExtension(osm, user.E_OldStyleParcel_MessageSetExtension, osp)
+ if err != nil {
+ t.Fatal("Failed setting extension:", err)
+ }
+ // Encode the message; the result is compared against a fixed byte sequence below.
+ buf, err := proto.Marshal(osm)
+ if err != nil {
+ t.Fatal("Failed encoding message:", err)
+ }
+ // Expected encoding: a field-1 group holding varint 2001 (field 2) and a 9-byte payload (field 3).
+ // Data generated from the Python implementation.
+ expected := []byte{
+ 11, 16, 209, 15, 26, 9, 10, 4, 68, 97, 118, 101, 16, 178, 1, 12,
+ }
+ // The Go encoding must match the reference bytes exactly.
+ if !bytes.Equal(expected, buf) {
+ t.Errorf("Encoding mismatch.\nwant %+v\n got %+v", expected, buf)
+ }
+ // Decode into a fresh message so nothing carries over from the one encoded above.
+ // Check that the extension is restored correctly.
+ osm = new(base.OldStyleMessage)
+ if err := proto.Unmarshal(buf, osm); err != nil {
+ t.Fatal("Failed decoding message:", err)
+ }
+ osp_out, err := proto.GetExtension(osm, user.E_OldStyleParcel_MessageSetExtension)
+ if err != nil {
+ t.Fatal("Failed getting extension:", err)
+ }
+ osp = osp_out.(*user.OldStyleParcel)
+ if *osp.Name != "Dave" || *osp.Height != 178 {
+ t.Errorf("Retrieved extension from decoded message is not correct: %+v", osp)
+ }
+ }
+
func main() {
// simpler than rigging up gotest
testing.Main(regexp.MatchString, []testing.InternalTest{
diff --git a/compiler/testdata/extension_user.proto b/compiler/testdata/extension_user.proto
index 6fb3d72..f1d28cd 100644
--- a/compiler/testdata/extension_user.proto
+++ b/compiler/testdata/extension_user.proto
@@ -71,3 +71,13 @@
optional Announcement loud_ext = 100;
}
}
+
+ // A message that registers itself as extension 2001 of extension_base.OldStyleMessage, so it can be put in a message set.
+ message OldStyleParcel {
+ extend extension_base.OldStyleMessage {
+ optional OldStyleParcel message_set_extension = 2001;
+ }
+
+ required string name = 1;
+ optional int32 height = 2;
+ }
diff --git a/compiler/testdata/test.pb.go.golden b/compiler/testdata/test.pb.go.golden
index 36284f0..e24e5da 100644
--- a/compiler/testdata/test.pb.go.golden
+++ b/compiler/testdata/test.pb.go.golden
@@ -4,10 +4,12 @@
package my_test
import proto "goprotobuf.googlecode.com/hg/proto"
+import "os"
import imp "imp.pb"
-// Reference proto import to suppress error if it's not otherwise used.
+// Reference proto & os imports to suppress error if it's not otherwise used.
var _ = proto.GetString
+var _ os.Error
type HatType int32
const (