Implement oneof support.
This includes the code generation changes,
and the infrastructure to wire it up to the encode/decode machinery.
The overall API changes are these:
- oneofs in a message are replaced by a single interface field
- each field in a oneof gets a distinguished type that satisfies
the corresponding interface
- a type switch may be used to distinguish between oneof fields
Fixes #29.
diff --git a/Makefile b/Makefile
index 43b9499..d24239e 100644
--- a/Makefile
+++ b/Makefile
@@ -51,5 +51,6 @@
regenerate:
make -C protoc-gen-go/descriptor regenerate
make -C protoc-gen-go/plugin regenerate
+ make -C protoc-gen-go/testdata regenerate
make -C proto/testdata regenerate
make -C jsonpb/jsonpb_test_proto regenerate
diff --git a/jsonpb/jsonpb.go b/jsonpb/jsonpb.go
index fb5e087..2663c92 100644
--- a/jsonpb/jsonpb.go
+++ b/jsonpb/jsonpb.go
@@ -111,6 +111,18 @@
}
}
+ // Oneof fields need special handling.
+ if valueField.Tag.Get("protobuf_oneof") != "" {
+ // value is an interface containing &T{real_value}.
+ sv := value.Elem().Elem() // interface -> *T -> T
+ value = sv.Field(0)
+ valueField = sv.Type().Field(0)
+
+ var p proto.Properties
+ p.Parse(sv.Type().Field(0).Tag.Get("protobuf"))
+ fieldName = p.OrigName
+ }
+
out.write(writeBeforeField)
if m.Indent != "" {
out.write(indent)
@@ -319,6 +331,44 @@
delete(jsonFields, fieldName)
}
}
+ // Check for any oneof fields.
+ // This might be slow; we can optimise it if it becomes a problem.
+ type oneofMessage interface {
+ XXX_OneofFuncs() (func(proto.Message, *proto.Buffer) error, func(proto.Message, int, int, *proto.Buffer) (bool, error), []interface{})
+ }
+ var oneofTypes []interface{}
+ if om, ok := reflect.Zero(reflect.PtrTo(targetType)).Interface().(oneofMessage); ok {
+ _, _, oneofTypes = om.XXX_OneofFuncs()
+ }
+ for fname, raw := range jsonFields {
+ for _, oot := range oneofTypes {
+ sp := reflect.ValueOf(oot).Type() // *T
+ var props proto.Properties
+ props.Parse(sp.Elem().Field(0).Tag.Get("protobuf"))
+ if props.OrigName != fname {
+ continue
+ }
+ nv := reflect.New(sp.Elem())
+ // There will be exactly one interface field that
+ // this new value is assignable to.
+ for i := 0; i < targetType.NumField(); i++ {
+ f := targetType.Field(i)
+ if f.Type.Kind() != reflect.Interface {
+ continue
+ }
+ if !nv.Type().AssignableTo(f.Type) {
+ continue
+ }
+ target.Field(i).Set(nv)
+ break
+ }
+ if err := unmarshalValue(nv.Elem().Field(0), raw); err != nil {
+ return err
+ }
+ delete(jsonFields, fname)
+ break
+ }
+ }
if len(jsonFields) > 0 {
// Pick any field to be the scapegoat.
var f string
diff --git a/jsonpb/jsonpb_test.go b/jsonpb/jsonpb_test.go
index 180d5d2..200e02c 100644
--- a/jsonpb/jsonpb_test.go
+++ b/jsonpb/jsonpb_test.go
@@ -32,6 +32,7 @@
package jsonpb
import (
+ "reflect"
"testing"
pb "github.com/golang/protobuf/jsonpb/jsonpb_test_proto"
@@ -291,6 +292,8 @@
{"proto2 map<bool, Object>", marshaler,
&pb.Maps{MBoolSimple: map[bool]*pb.Simple{true: &pb.Simple{OInt32: proto.Int32(1)}}},
`{"m_bool_simple":{"true":{"o_int32":1}}}`},
+ {"oneof, not set", marshaler, &pb.MsgWithOneof{}, `{}`},
+ {"oneof, set", marshaler, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Title{"Grand Poobah"}}, `{"title":"Grand Poobah"}`},
}
func TestMarshaling(t *testing.T) {
@@ -325,13 +328,13 @@
{"map<int64, int32>", `{"nummy":{"1":2,"3":4}}`, &pb.Mappy{Nummy: map[int64]int32{1: 2, 3: 4}}},
{"map<string, string>", `{"strry":{"\"one\"":"two","three":"four"}}`, &pb.Mappy{Strry: map[string]string{`"one"`: "two", "three": "four"}}},
{"map<int32, Object>", `{"objjy":{"1":{"dub":1}}}`, &pb.Mappy{Objjy: map[int32]*pb.Simple3{1: &pb.Simple3{Dub: 1}}}},
+ {"oneof", `{"salary":31000}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Salary{31000}}},
}
func TestUnmarshaling(t *testing.T) {
for _, tt := range unmarshalingTests {
// Make a new instance of the type of our expected object.
- p := proto.Clone(tt.pb)
- p.Reset()
+ p := reflect.New(reflect.TypeOf(tt.pb).Elem()).Interface().(proto.Message)
err := UnmarshalString(tt.json, p)
if err != nil {
diff --git a/jsonpb/jsonpb_test_proto/more_test_objects.pb.go b/jsonpb/jsonpb_test_proto/more_test_objects.pb.go
index 615d57a..2634853 100644
--- a/jsonpb/jsonpb_test_proto/more_test_objects.pb.go
+++ b/jsonpb/jsonpb_test_proto/more_test_objects.pb.go
@@ -16,9 +16,13 @@
package jsonpb
import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
+import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
type Simple3 struct {
Dub float64 `protobuf:"fixed64,1,opt,name=dub" json:"dub,omitempty"`
@@ -74,6 +78,3 @@
}
return nil
}
-
-func init() {
-}
diff --git a/jsonpb/jsonpb_test_proto/test_objects.pb.go b/jsonpb/jsonpb_test_proto/test_objects.pb.go
index 1f2f061..8c5b025 100644
--- a/jsonpb/jsonpb_test_proto/test_objects.pb.go
+++ b/jsonpb/jsonpb_test_proto/test_objects.pb.go
@@ -5,10 +5,12 @@
package jsonpb
import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
+var _ = fmt.Errorf
var _ = math.Inf
type Widget_Color int32
@@ -322,6 +324,100 @@
return nil
}
+type MsgWithOneof struct {
+ // Types that are valid to be assigned to Union:
+ // *MsgWithOneof_Title
+ // *MsgWithOneof_Salary
+ Union isMsgWithOneof_Union `protobuf_oneof:"union"`
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *MsgWithOneof) Reset() { *m = MsgWithOneof{} }
+func (m *MsgWithOneof) String() string { return proto.CompactTextString(m) }
+func (*MsgWithOneof) ProtoMessage() {}
+
+type isMsgWithOneof_Union interface {
+ isMsgWithOneof_Union()
+}
+
+type MsgWithOneof_Title struct {
+ Title string `protobuf:"bytes,1,opt,name=title"`
+}
+type MsgWithOneof_Salary struct {
+ Salary int64 `protobuf:"varint,2,opt,name=salary"`
+}
+
+func (*MsgWithOneof_Title) isMsgWithOneof_Union() {}
+func (*MsgWithOneof_Salary) isMsgWithOneof_Union() {}
+
+func (m *MsgWithOneof) GetUnion() isMsgWithOneof_Union {
+ if m != nil {
+ return m.Union
+ }
+ return nil
+}
+
+func (m *MsgWithOneof) GetTitle() string {
+ if x, ok := m.GetUnion().(*MsgWithOneof_Title); ok {
+ return x.Title
+ }
+ return ""
+}
+
+func (m *MsgWithOneof) GetSalary() int64 {
+ if x, ok := m.GetUnion().(*MsgWithOneof_Salary); ok {
+ return x.Salary
+ }
+ return 0
+}
+
+// XXX_OneofFuncs is for the internal use of the proto package.
+func (*MsgWithOneof) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), []interface{}) {
+ return _MsgWithOneof_OneofMarshaler, _MsgWithOneof_OneofUnmarshaler, []interface{}{
+ (*MsgWithOneof_Title)(nil),
+ (*MsgWithOneof_Salary)(nil),
+ }
+}
+
+func _MsgWithOneof_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
+ m := msg.(*MsgWithOneof)
+ // union
+ switch x := m.Union.(type) {
+ case *MsgWithOneof_Title:
+ b.EncodeVarint(1<<3 | proto.WireBytes)
+ b.EncodeStringBytes(x.Title)
+ case *MsgWithOneof_Salary:
+ b.EncodeVarint(2<<3 | proto.WireVarint)
+ b.EncodeVarint(uint64(x.Salary))
+ case nil:
+ default:
+ return fmt.Errorf("MsgWithOneof.Union has unexpected type %T", x)
+ }
+ return nil
+}
+
+func _MsgWithOneof_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
+ m := msg.(*MsgWithOneof)
+ switch tag {
+ case 1: // union.title
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeStringBytes()
+ m.Union = &MsgWithOneof_Title{x}
+ return true, err
+ case 2: // union.salary
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeVarint()
+ m.Union = &MsgWithOneof_Salary{int64(x)}
+ return true, err
+ default:
+ return false, nil
+ }
+}
+
func init() {
proto.RegisterEnum("jsonpb.Widget_Color", Widget_Color_name, Widget_Color_value)
}
diff --git a/jsonpb/jsonpb_test_proto/test_objects.proto b/jsonpb/jsonpb_test_proto/test_objects.proto
index e48a3e8..85700bf 100644
--- a/jsonpb/jsonpb_test_proto/test_objects.proto
+++ b/jsonpb/jsonpb_test_proto/test_objects.proto
@@ -84,3 +84,10 @@
map<int64, string> m_int64_str = 1;
map<bool, Simple> m_bool_simple = 2;
}
+
+message MsgWithOneof {
+ oneof union {
+ string title = 1;
+ int64 salary = 2;
+ }
+}
diff --git a/proto/all_test.go b/proto/all_test.go
index b787d58..49ce01f 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -1958,6 +1958,58 @@
}
}
+func TestOneof(t *testing.T) {
+ m := &Communique{}
+ b, err := Marshal(m)
+ if err != nil {
+ t.Fatalf("Marshal of empty message with oneof: %v", err)
+ }
+ if len(b) != 0 {
+ t.Errorf("Marshal of empty message yielded too many bytes: %v", b)
+ }
+
+ m = &Communique{
+ Union: &Communique_Name{"Barry"},
+ }
+
+ // Round-trip.
+ b, err = Marshal(m)
+ if err != nil {
+ t.Fatalf("Marshal of message with oneof: %v", err)
+ }
+ if len(b) != 7 { // name tag/wire (1) + name len (1) + name (5)
+ t.Errorf("Incorrect marshal of message with oneof: %v", b)
+ }
+ m.Reset()
+ if err := Unmarshal(b, m); err != nil {
+ t.Fatalf("Unmarshal of message with oneof: %v", err)
+ }
+ if x, ok := m.Union.(*Communique_Name); !ok || x.Name != "Barry" {
+ t.Errorf("After round trip, Union = %+v", m.Union)
+ }
+ if name := m.GetName(); name != "Barry" {
+ t.Errorf("After round trip, GetName = %q, want %q", name, "Barry")
+ }
+
+ // Let's try with a message in the oneof.
+ m.Union = &Communique_Msg{&Strings{StringField: String("deep deep string")}}
+ b, err = Marshal(m)
+ if err != nil {
+ t.Fatalf("Marshal of message with oneof set to message: %v", err)
+ }
+ if len(b) != 20 { // msg tag/wire (1) + msg len (1) + msg (1 + 1 + 16)
+ t.Errorf("Incorrect marshal of message with oneof set to message: %v", b)
+ }
+ m.Reset()
+ if err := Unmarshal(b, m); err != nil {
+ t.Fatalf("Unmarshal of message with oneof set to message: %v", err)
+ }
+ ss, ok := m.Union.(*Communique_Msg)
+ if !ok || ss.Msg.GetStringField() != "deep deep string" {
+ t.Errorf("After round trip with oneof set to message, Union = %+v", m.Union)
+ }
+}
+
// Benchmarks
func testMsg() *GoTest {
diff --git a/proto/clone.go b/proto/clone.go
index 915a68b..0155cc7 100644
--- a/proto/clone.go
+++ b/proto/clone.go
@@ -120,6 +120,13 @@
return
}
out.Set(in)
+ case reflect.Interface:
+ // Probably a oneof field; copy non-nil values.
+ if in.IsNil() {
+ return
+ }
+ out.Set(reflect.New(in.Elem().Elem().Type())) // interface -> *T -> T -> new(T)
+ mergeAny(out.Elem(), in.Elem(), false, nil)
case reflect.Map:
if in.Len() == 0 {
return
diff --git a/proto/clone_test.go b/proto/clone_test.go
index a1c697b..76720f1 100644
--- a/proto/clone_test.go
+++ b/proto/clone_test.go
@@ -232,6 +232,28 @@
Data: []byte("texas!"),
},
},
+ // Oneof fields should merge by assignment.
+ {
+ src: &pb.Communique{
+ Union: &pb.Communique_Number{41},
+ },
+ dst: &pb.Communique{
+ Union: &pb.Communique_Name{"Bobby Tables"},
+ },
+ want: &pb.Communique{
+ Union: &pb.Communique_Number{41},
+ },
+ },
+ // Oneof nil is the same as not set.
+ {
+ src: &pb.Communique{},
+ dst: &pb.Communique{
+ Union: &pb.Communique_Name{"Bobby Tables"},
+ },
+ want: &pb.Communique{
+ Union: &pb.Communique_Name{"Bobby Tables"},
+ },
+ },
}
func TestMerge(t *testing.T) {
diff --git a/proto/decode.go b/proto/decode.go
index bf71dca..8486635 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -46,6 +46,10 @@
// errOverflow is returned when an integer is too large to be represented.
var errOverflow = errors.New("proto: integer overflow")
+// ErrInternalBadWireType is returned by generated code when an incorrect
+// wire type is encountered. It does not get returned to user code.
+var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for oneof")
+
// The fundamental decoders that interpret bytes on the wire.
// Those that take integer types all return uint64 and are
// therefore of type valueDecoder.
@@ -314,6 +318,24 @@
return NewBuffer(buf).Unmarshal(pb)
}
+// DecodeMessage reads a count-delimited message from the Buffer.
+func (p *Buffer) DecodeMessage(pb Message) error {
+ enc, err := p.DecodeRawBytes(false)
+ if err != nil {
+ return err
+ }
+ return NewBuffer(enc).Unmarshal(pb)
+}
+
+// DecodeGroup reads a tag-delimited group from the Buffer.
+func (p *Buffer) DecodeGroup(pb Message) error {
+ typ, base, err := getbase(pb)
+ if err != nil {
+ return err
+ }
+ return p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), true, base)
+}
+
// Unmarshal parses the protocol buffer representation in the
// Buffer and places the decoded result in pb. If the struct
// underlying pb does not match the data in the buffer, the results can be
@@ -377,6 +399,20 @@
continue
}
}
+ // Maybe it's a oneof?
+ if prop.oneofUnmarshaler != nil {
+ m := structPointer_Interface(base, st).(Message)
+ // First return value indicates whether tag is a oneof field.
+ ok, err = prop.oneofUnmarshaler(m, tag, wire, o)
+ if err == ErrInternalBadWireType {
+ // Map the error to something more descriptive.
+ // Do the formatting here to save generated code space.
+ err = fmt.Errorf("bad wiretype for oneof field in %T", m)
+ }
+ if ok {
+ continue
+ }
+ }
err = o.skipAndSave(st, tag, wire, base, prop.unrecField)
continue
}
diff --git a/proto/encode.go b/proto/encode.go
index 91f3f07..fe48cc7 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -228,6 +228,20 @@
return p.buf, err
}
+// EncodeMessage writes the protocol buffer to the Buffer,
+// prefixed by a varint-encoded length.
+func (p *Buffer) EncodeMessage(pb Message) error {
+ t, base, err := getbase(pb)
+ if structPointer_IsNil(base) {
+ return ErrNil
+ }
+ if err == nil {
+ var state errorState
+ err = p.enc_len_struct(GetProperties(t.Elem()), base, &state)
+ }
+ return err
+}
+
// Marshal takes the protocol buffer
// and encodes it into the wire format, writing the result to the
// Buffer.
@@ -1201,6 +1215,14 @@
}
}
+ // Do oneof fields.
+ if prop.oneofMarshaler != nil {
+ m := structPointer_Interface(base, prop.stype).(Message)
+ if err := prop.oneofMarshaler(m, o); err != nil {
+ return err
+ }
+ }
+
// Add unrecognized fields at the end.
if prop.unrecField.IsValid() {
v := *structPointer_Bytes(base, prop.unrecField)
@@ -1226,6 +1248,27 @@
n += len(v)
}
+ // Factor in any oneof fields.
+ // TODO: This could be faster and use less reflection.
+ if prop.oneofMarshaler != nil {
+ sv := reflect.ValueOf(structPointer_Interface(base, prop.stype)).Elem()
+ for i := 0; i < prop.stype.NumField(); i++ {
+ fv := sv.Field(i)
+ if fv.Kind() != reflect.Interface || fv.IsNil() {
+ continue
+ }
+ if prop.stype.Field(i).Tag.Get("protobuf_oneof") == "" {
+ continue
+ }
+ spv := fv.Elem() // interface -> *T
+ sv := spv.Elem() // *T -> T
+ sf := sv.Type().Field(0) // StructField inside T
+ var prop Properties
+ prop.Init(sf.Type, "whatever", sf.Tag.Get("protobuf"), &sf)
+ n += prop.size(&prop, toStructPointer(spv))
+ }
+ }
+
return
}
diff --git a/proto/equal.go b/proto/equal.go
index d8673a3..5475c3d 100644
--- a/proto/equal.go
+++ b/proto/equal.go
@@ -154,6 +154,17 @@
return v1.Float() == v2.Float()
case reflect.Int32, reflect.Int64:
return v1.Int() == v2.Int()
+ case reflect.Interface:
+ // Probably a oneof field; compare the inner values.
+ n1, n2 := v1.IsNil(), v2.IsNil()
+ if n1 || n2 {
+ return n1 == n2
+ }
+ e1, e2 := v1.Elem(), v2.Elem()
+ if e1.Type() != e2.Type() {
+ return false
+ }
+ return equalAny(e1, e2)
case reflect.Map:
if v1.Len() != v2.Len() {
return false
diff --git a/proto/equal_test.go b/proto/equal_test.go
index b322f65..0e0db8a 100644
--- a/proto/equal_test.go
+++ b/proto/equal_test.go
@@ -180,6 +180,24 @@
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob"}},
false,
},
+ {
+ "oneof same",
+ &pb.Communique{Union: &pb.Communique_Number{41}},
+ &pb.Communique{Union: &pb.Communique_Number{41}},
+ true,
+ },
+ {
+ "oneof one nil",
+ &pb.Communique{Union: &pb.Communique_Number{41}},
+ &pb.Communique{},
+ false,
+ },
+ {
+ "oneof different",
+ &pb.Communique{Union: &pb.Communique_Number{41}},
+ &pb.Communique{Union: &pb.Communique_Name{"Bobby Tables"}},
+ false,
+ },
}
func TestEqual(t *testing.T) {
diff --git a/proto/lib.go b/proto/lib.go
index 95f7975..187bd65 100644
--- a/proto/lib.go
+++ b/proto/lib.go
@@ -66,6 +66,8 @@
that contain it (if any) followed by the CamelCased name of the
extension field itself. HasExtension, ClearExtension, GetExtension
and SetExtension are functions for manipulating extensions.
+ - Oneof field sets are given a single field in their message,
+ with distinguished wrapper types for each possible field value.
- Marshal and Unmarshal are functions to encode and decode the wire format.
The simplest way to describe this is to see an example.
diff --git a/proto/properties.go b/proto/properties.go
index d74844a..5685445 100644
--- a/proto/properties.go
+++ b/proto/properties.go
@@ -84,6 +84,12 @@
// A valueDecoder decodes a single integer in a particular encoding.
type valueDecoder func(o *Buffer) (x uint64, err error)
+// A oneofMarshaler does the marshaling for all oneof fields in a message.
+type oneofMarshaler func(Message, *Buffer) error
+
+// A oneofUnmarshaler does the unmarshaling for a oneof field in a message.
+type oneofUnmarshaler func(Message, int, int, *Buffer) (bool, error)
+
// tagMap is an optimization over map[int]int for typical protocol buffer
// use-cases. Encoded protocol buffers are often in tag order with small tag
// numbers.
@@ -132,6 +138,11 @@
order []int // list of struct field numbers in tag order
unrecField field // field id of the XXX_unrecognized []byte field
extendable bool // is this an extendable proto
+
+ oneofMarshaler oneofMarshaler
+ oneofUnmarshaler oneofUnmarshaler
+ stype reflect.Type
+ oneofTypes []interface{}
}
// Implement the sorting interface so we can sort the fields in tag order, as recommended by the spec.
@@ -665,6 +676,7 @@
if f.Name == "XXX_unrecognized" { // special case
prop.unrecField = toField(&f)
}
+ oneof := f.Tag.Get("protobuf_oneof") != "" // special case
prop.Prop[i] = p
prop.order[i] = i
if debug {
@@ -674,7 +686,7 @@
}
print("\n")
}
- if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") {
+ if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") && !oneof {
fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]")
}
}
@@ -682,6 +694,14 @@
// Re-order prop.order.
sort.Sort(prop)
+ type oneofMessage interface {
+ XXX_OneofFuncs() (func(Message, *Buffer) error, func(Message, int, int, *Buffer) (bool, error), []interface{})
+ }
+ if om, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); ok {
+ prop.oneofMarshaler, prop.oneofUnmarshaler, prop.oneofTypes = om.XXX_OneofFuncs()
+ prop.stype = t
+ }
+
// build required counts
// build tags
reqCount := 0
diff --git a/proto/size_test.go b/proto/size_test.go
index db5614f..53806a3 100644
--- a/proto/size_test.go
+++ b/proto/size_test.go
@@ -124,6 +124,10 @@
{"map field with big entry", &pb.MessageWithMap{NameMapping: map[int32]string{8: strings.Repeat("x", 125)}}},
{"map field with big key and val", &pb.MessageWithMap{StrToStr: map[string]string{strings.Repeat("x", 70): strings.Repeat("y", 70)}}},
{"map field with big numeric key", &pb.MessageWithMap{NameMapping: map[int32]string{0xf00d: "om nom nom"}}},
+
+ {"oneof not set", &pb.Communique{}},
+ {"oneof int32", &pb.Communique{Union: &pb.Communique_Number{3}}},
+ {"oneof string", &pb.Communique{Union: &pb.Communique_Name{"Rhythmic Fman"}}},
}
func TestSize(t *testing.T) {
diff --git a/proto/testdata/test.pb.go b/proto/testdata/test.pb.go
index 13674a4..c22fe31 100644
--- a/proto/testdata/test.pb.go
+++ b/proto/testdata/test.pb.go
@@ -35,14 +35,17 @@
GroupNew
FloatingPoint
MessageWithMap
+ Communique
*/
package testdata
import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
+var _ = fmt.Errorf
var _ = math.Inf
type FOO int32
@@ -1986,6 +1989,205 @@
return nil
}
+type Communique struct {
+ MakeMeCry *bool `protobuf:"varint,1,opt,name=make_me_cry" json:"make_me_cry,omitempty"`
+ // This is a oneof, called "union".
+ //
+ // Types that are valid to be assigned to Union:
+ // *Communique_Number
+ // *Communique_Name
+ // *Communique_Data
+ // *Communique_TempC
+ // *Communique_Col
+ // *Communique_Msg
+ Union isCommunique_Union `protobuf_oneof:"union"`
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *Communique) Reset() { *m = Communique{} }
+func (m *Communique) String() string { return proto.CompactTextString(m) }
+func (*Communique) ProtoMessage() {}
+
+type isCommunique_Union interface {
+ isCommunique_Union()
+}
+
+type Communique_Number struct {
+ Number int32 `protobuf:"varint,5,opt,name=number"`
+}
+type Communique_Name struct {
+ Name string `protobuf:"bytes,6,opt,name=name"`
+}
+type Communique_Data struct {
+ Data []byte `protobuf:"bytes,7,opt,name=data"`
+}
+type Communique_TempC struct {
+ TempC float64 `protobuf:"fixed64,8,opt,name=temp_c"`
+}
+type Communique_Col struct {
+ Col MyMessage_Color `protobuf:"varint,9,opt,name=col,enum=testdata.MyMessage_Color"`
+}
+type Communique_Msg struct {
+ Msg *Strings `protobuf:"bytes,10,opt,name=msg"`
+}
+
+func (*Communique_Number) isCommunique_Union() {}
+func (*Communique_Name) isCommunique_Union() {}
+func (*Communique_Data) isCommunique_Union() {}
+func (*Communique_TempC) isCommunique_Union() {}
+func (*Communique_Col) isCommunique_Union() {}
+func (*Communique_Msg) isCommunique_Union() {}
+
+func (m *Communique) GetUnion() isCommunique_Union {
+ if m != nil {
+ return m.Union
+ }
+ return nil
+}
+
+func (m *Communique) GetMakeMeCry() bool {
+ if m != nil && m.MakeMeCry != nil {
+ return *m.MakeMeCry
+ }
+ return false
+}
+
+func (m *Communique) GetNumber() int32 {
+ if x, ok := m.GetUnion().(*Communique_Number); ok {
+ return x.Number
+ }
+ return 0
+}
+
+func (m *Communique) GetName() string {
+ if x, ok := m.GetUnion().(*Communique_Name); ok {
+ return x.Name
+ }
+ return ""
+}
+
+func (m *Communique) GetData() []byte {
+ if x, ok := m.GetUnion().(*Communique_Data); ok {
+ return x.Data
+ }
+ return nil
+}
+
+func (m *Communique) GetTempC() float64 {
+ if x, ok := m.GetUnion().(*Communique_TempC); ok {
+ return x.TempC
+ }
+ return 0
+}
+
+func (m *Communique) GetCol() MyMessage_Color {
+ if x, ok := m.GetUnion().(*Communique_Col); ok {
+ return x.Col
+ }
+ return MyMessage_RED
+}
+
+func (m *Communique) GetMsg() *Strings {
+ if x, ok := m.GetUnion().(*Communique_Msg); ok {
+ return x.Msg
+ }
+ return nil
+}
+
+// XXX_OneofFuncs is for the internal use of the proto package.
+func (*Communique) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), []interface{}) {
+ return _Communique_OneofMarshaler, _Communique_OneofUnmarshaler, []interface{}{
+ (*Communique_Number)(nil),
+ (*Communique_Name)(nil),
+ (*Communique_Data)(nil),
+ (*Communique_TempC)(nil),
+ (*Communique_Col)(nil),
+ (*Communique_Msg)(nil),
+ }
+}
+
+func _Communique_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
+ m := msg.(*Communique)
+ // union
+ switch x := m.Union.(type) {
+ case *Communique_Number:
+ b.EncodeVarint(5<<3 | proto.WireVarint)
+ b.EncodeVarint(uint64(x.Number))
+ case *Communique_Name:
+ b.EncodeVarint(6<<3 | proto.WireBytes)
+ b.EncodeStringBytes(x.Name)
+ case *Communique_Data:
+ b.EncodeVarint(7<<3 | proto.WireBytes)
+ b.EncodeRawBytes(x.Data)
+ case *Communique_TempC:
+ b.EncodeVarint(8<<3 | proto.WireFixed64)
+ b.EncodeFixed64(math.Float64bits(x.TempC))
+ case *Communique_Col:
+ b.EncodeVarint(9<<3 | proto.WireVarint)
+ b.EncodeVarint(uint64(x.Col))
+ case *Communique_Msg:
+ b.EncodeVarint(10<<3 | proto.WireBytes)
+ if err := b.EncodeMessage(x.Msg); err != nil {
+ return err
+ }
+ case nil:
+ default:
+ return fmt.Errorf("Communique.Union has unexpected type %T", x)
+ }
+ return nil
+}
+
+func _Communique_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
+ m := msg.(*Communique)
+ switch tag {
+ case 5: // union.number
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeVarint()
+ m.Union = &Communique_Number{int32(x)}
+ return true, err
+ case 6: // union.name
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeStringBytes()
+ m.Union = &Communique_Name{x}
+ return true, err
+ case 7: // union.data
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeRawBytes(true)
+ m.Union = &Communique_Data{x}
+ return true, err
+ case 8: // union.temp_c
+ if wire != proto.WireFixed64 {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeFixed64()
+ m.Union = &Communique_TempC{math.Float64frombits(x)}
+ return true, err
+ case 9: // union.col
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeVarint()
+ m.Union = &Communique_Col{MyMessage_Color(x)}
+ return true, err
+ case 10: // union.msg
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ msg := new(Strings)
+ err := b.DecodeMessage(msg)
+ m.Union = &Communique_Msg{msg}
+ return true, err
+ default:
+ return false, nil
+ }
+}
+
var E_Greeting = &proto.ExtensionDesc{
ExtendedType: (*MyMessage)(nil),
ExtensionType: ([]string)(nil),
diff --git a/proto/testdata/test.proto b/proto/testdata/test.proto
index 440dba3..3d1cbb6 100644
--- a/proto/testdata/test.proto
+++ b/proto/testdata/test.proto
@@ -478,3 +478,17 @@
map<bool, bytes> byte_mapping = 3;
map<string, string> str_to_str = 4;
}
+
+message Communique {
+ optional bool make_me_cry = 1;
+
+ // This is a oneof, called "union".
+ oneof union {
+ int32 number = 5;
+ string name = 6;
+ bytes data = 7;
+ double temp_c = 8;
+ MyMessage.Color col = 9;
+ Strings msg = 10;
+ }
+}
diff --git a/proto/text.go b/proto/text.go
index f3db2cf..51202bd 100644
--- a/proto/text.go
+++ b/proto/text.go
@@ -322,6 +322,23 @@
}
}
+ if fv.Kind() == reflect.Interface {
+ // Check if it is a oneof.
+ if st.Field(i).Tag.Get("protobuf_oneof") != "" {
+ // fv is nil, or holds a pointer to generated struct.
+ // That generated struct has exactly one field,
+ // which has a protobuf struct tag.
+ if fv.IsNil() {
+ continue
+ }
+ inner := fv.Elem().Elem() // interface -> *T -> T
+ tag := inner.Type().Field(0).Tag.Get("protobuf")
+ props.Parse(tag) // Overwrite the outer props.
+ // Write the value in the oneof, not the oneof itself.
+ fv = inner.Field(0)
+ }
+ }
+
if err := writeName(w, props); err != nil {
return err
}
diff --git a/proto/text_parser.go b/proto/text_parser.go
index 7d0c757..2e2a67f 100644
--- a/proto/text_parser.go
+++ b/proto/text_parser.go
@@ -385,8 +385,7 @@
}
// Returns the index in the struct for the named field, as well as the parsed tag properties.
-func structFieldByName(st reflect.Type, name string) (int, *Properties, bool) {
- sprops := GetProperties(st)
+func structFieldByName(sprops *StructProperties, name string) (int, *Properties, bool) {
i, ok := sprops.decoderOrigNames[name]
if ok {
return i, sprops.Prop[i], true
@@ -438,7 +437,8 @@
func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
st := sv.Type()
- reqCount := GetProperties(st).reqCount
+ sprops := GetProperties(st)
+ reqCount := sprops.reqCount
var reqFieldErr error
fieldSet := make(map[string]bool)
// A struct is a sequence of "name: value", terminated by one of
@@ -520,95 +520,129 @@
sl = reflect.Append(sl, ext)
SetExtension(ep, desc, sl.Interface())
}
+ if err := p.consumeOptionalSeparator(); err != nil {
+ return err
+ }
+ continue
+ }
+
+ // This is a normal, non-extension field.
+ name := tok.value
+ var dst reflect.Value
+ fi, props, ok := structFieldByName(sprops, name)
+ if ok {
+ dst = sv.Field(fi)
} else {
- // This is a normal, non-extension field.
- name := tok.value
- fi, props, ok := structFieldByName(st, name)
- if !ok {
- return p.errorf("unknown field name %q in %v", name, st)
+ // Maybe it is a oneof.
+ // TODO: If this shows in profiles, cache the mapping.
+ for _, oot := range sprops.oneofTypes {
+ sp := reflect.ValueOf(oot).Type() // *T
+ var p Properties
+ p.Parse(sp.Elem().Field(0).Tag.Get("protobuf"))
+ if p.OrigName != name {
+ continue
+ }
+ nv := reflect.New(sp.Elem())
+ dst = nv.Elem().Field(0)
+ props = &p
+ // There will be exactly one interface field that
+ // this new value is assignable to.
+ for i := 0; i < st.NumField(); i++ {
+ f := st.Field(i)
+ if f.Type.Kind() != reflect.Interface {
+ continue
+ }
+ if !nv.Type().AssignableTo(f.Type) {
+ continue
+ }
+ sv.Field(i).Set(nv)
+ break
+ }
+ break
}
+ }
+ if !dst.IsValid() {
+ return p.errorf("unknown field name %q in %v", name, st)
+ }
- dst := sv.Field(fi)
-
- if dst.Kind() == reflect.Map {
- // Consume any colon.
- if err := p.checkForColon(props, dst.Type()); err != nil {
- return err
- }
-
- // Construct the map if it doesn't already exist.
- if dst.IsNil() {
- dst.Set(reflect.MakeMap(dst.Type()))
- }
- key := reflect.New(dst.Type().Key()).Elem()
- val := reflect.New(dst.Type().Elem()).Elem()
-
- // The map entry should be this sequence of tokens:
- // < key : KEY value : VALUE >
- // Technically the "key" and "value" could come in any order,
- // but in practice they won't.
-
- tok := p.next()
- var terminator string
- switch tok.value {
- case "<":
- terminator = ">"
- case "{":
- terminator = "}"
- default:
- return p.errorf("expected '{' or '<', found %q", tok.value)
- }
- if err := p.consumeToken("key"); err != nil {
- return err
- }
- if err := p.consumeToken(":"); err != nil {
- return err
- }
- if err := p.readAny(key, props.mkeyprop); err != nil {
- return err
- }
- if err := p.consumeOptionalSeparator(); err != nil {
- return err
- }
- if err := p.consumeToken("value"); err != nil {
- return err
- }
- if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil {
- return err
- }
- if err := p.readAny(val, props.mvalprop); err != nil {
- return err
- }
- if err := p.consumeOptionalSeparator(); err != nil {
- return err
- }
- if err := p.consumeToken(terminator); err != nil {
- return err
- }
-
- dst.SetMapIndex(key, val)
- continue
- }
-
- // Check that it's not already set if it's not a repeated field.
- if !props.Repeated && fieldSet[name] {
- return p.errorf("non-repeated field %q was repeated", name)
- }
-
- if err := p.checkForColon(props, st.Field(fi).Type); err != nil {
+ if dst.Kind() == reflect.Map {
+ // Consume any colon.
+ if err := p.checkForColon(props, dst.Type()); err != nil {
return err
}
- // Parse into the field.
- fieldSet[name] = true
- if err := p.readAny(dst, props); err != nil {
- if _, ok := err.(*RequiredNotSetError); !ok {
- return err
- }
- reqFieldErr = err
- } else if props.Required {
- reqCount--
+ // Construct the map if it doesn't already exist.
+ if dst.IsNil() {
+ dst.Set(reflect.MakeMap(dst.Type()))
}
+ key := reflect.New(dst.Type().Key()).Elem()
+ val := reflect.New(dst.Type().Elem()).Elem()
+
+ // The map entry should be this sequence of tokens:
+ // < key : KEY value : VALUE >
+ // Technically the "key" and "value" could come in any order,
+ // but in practice they won't.
+
+ tok := p.next()
+ var terminator string
+ switch tok.value {
+ case "<":
+ terminator = ">"
+ case "{":
+ terminator = "}"
+ default:
+ return p.errorf("expected '{' or '<', found %q", tok.value)
+ }
+ if err := p.consumeToken("key"); err != nil {
+ return err
+ }
+ if err := p.consumeToken(":"); err != nil {
+ return err
+ }
+ if err := p.readAny(key, props.mkeyprop); err != nil {
+ return err
+ }
+ if err := p.consumeOptionalSeparator(); err != nil {
+ return err
+ }
+ if err := p.consumeToken("value"); err != nil {
+ return err
+ }
+ if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil {
+ return err
+ }
+ if err := p.readAny(val, props.mvalprop); err != nil {
+ return err
+ }
+ if err := p.consumeOptionalSeparator(); err != nil {
+ return err
+ }
+ if err := p.consumeToken(terminator); err != nil {
+ return err
+ }
+
+ dst.SetMapIndex(key, val)
+ continue
+ }
+
+ // Check that it's not already set if it's not a repeated field.
+ if !props.Repeated && fieldSet[name] {
+ return p.errorf("non-repeated field %q was repeated", name)
+ }
+
+ if err := p.checkForColon(props, dst.Type()); err != nil {
+ return err
+ }
+
+ // Parse into the field.
+ fieldSet[name] = true
+ if err := p.readAny(dst, props); err != nil {
+ if _, ok := err.(*RequiredNotSetError); !ok {
+ return err
+ }
+ reqFieldErr = err
+ } else if props.Required {
+ reqCount--
}
if err := p.consumeOptionalSeparator(); err != nil {
diff --git a/proto/text_parser_test.go b/proto/text_parser_test.go
index 0754b26..a2a9604 100644
--- a/proto/text_parser_test.go
+++ b/proto/text_parser_test.go
@@ -486,6 +486,18 @@
}
}
+func TestOneofParsing(t *testing.T) {
+ const in = `name:"Shrek"`
+ m := new(Communique)
+ want := &Communique{Union: &Communique_Name{"Shrek"}}
+ if err := UnmarshalText(in, m); err != nil {
+ t.Fatal(err)
+ }
+ if !Equal(m, want) {
+ t.Errorf("\n got %v\nwant %v", m, want)
+ }
+}
+
var benchInput string
func init() {
diff --git a/proto/text_test.go b/proto/text_test.go
index 64579e9..7ff180d 100644
--- a/proto/text_test.go
+++ b/proto/text_test.go
@@ -208,6 +208,28 @@
}
}
+func TestTextOneof(t *testing.T) {
+ tests := []struct {
+ m proto.Message
+ want string
+ }{
+ // zero message
+ {&pb.Communique{}, ``},
+ // scalar field
+ {&pb.Communique{Union: &pb.Communique_Number{4}}, `number:4`},
+ // message field
+ {&pb.Communique{Union: &pb.Communique_Msg{
+ &pb.Strings{StringField: proto.String("why hello!")},
+ }}, `msg:<string_field:"why hello!" >`},
+ }
+ for _, test := range tests {
+ got := strings.TrimSpace(test.m.String())
+ if got != test.want {
+ t.Errorf("\n got %s\nwant %s", got, test.want)
+ }
+ }
+}
+
func BenchmarkMarshalTextBuffered(b *testing.B) {
buf := new(bytes.Buffer)
m := newTestMessage()
diff --git a/protoc-gen-go/descriptor/descriptor.pb.go b/protoc-gen-go/descriptor/descriptor.pb.go
index b5c59d2..f36533d 100644
--- a/protoc-gen-go/descriptor/descriptor.pb.go
+++ b/protoc-gen-go/descriptor/descriptor.pb.go
@@ -31,10 +31,12 @@
package descriptor
import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
+var _ = fmt.Errorf
var _ = math.Inf
type FieldDescriptorProto_Type int32
diff --git a/protoc-gen-go/generator/generator.go b/protoc-gen-go/generator/generator.go
index d00ce88..096f992 100644
--- a/protoc-gen-go/generator/generator.go
+++ b/protoc-gen-go/generator/generator.go
@@ -109,6 +109,7 @@
*descriptor.DescriptorProto
parent *Descriptor // The containing message, if any.
nested []*Descriptor // Inner messages, if any.
+ enums []*EnumDescriptor // Inner enums, if any.
ext []*ExtensionDescriptor // Extensions, if any.
typename []string // Cached typename vector.
index int // The index into the container, whether the file or another message.
@@ -658,6 +659,7 @@
// Register the support package names. They might collide with the
// name of a package we import.
g.Pkg = map[string]string{
+ "fmt": RegisterUniquePackageName("fmt", nil),
"math": RegisterUniquePackageName("math", nil),
"proto": RegisterUniquePackageName("proto", nil),
}
@@ -691,6 +693,7 @@
descs := wrapDescriptors(f)
g.buildNestedDescriptors(descs)
enums := wrapEnumDescriptors(f, descs)
+ g.buildNestedEnums(descs, enums)
exts := wrapExtensions(f)
imps := wrapImported(f, g)
fd := &FileDescriptor{
@@ -726,21 +729,33 @@
func (g *Generator) buildNestedDescriptors(descs []*Descriptor) {
for _, desc := range descs {
if len(desc.NestedType) != 0 {
- desc.nested = make([]*Descriptor, len(desc.NestedType))
- n := 0
for _, nest := range descs {
if nest.parent == desc {
- desc.nested[n] = nest
- n++
+ desc.nested = append(desc.nested, nest)
}
}
- if n != len(desc.NestedType) {
+ if len(desc.nested) != len(desc.NestedType) {
g.Fail("internal error: nesting failure for", desc.GetName())
}
}
}
}
+func (g *Generator) buildNestedEnums(descs []*Descriptor, enums []*EnumDescriptor) {
+ for _, desc := range descs {
+ if len(desc.EnumType) != 0 {
+ for _, enum := range enums {
+ if enum.parent == desc {
+ desc.enums = append(desc.enums, enum)
+ }
+ }
+ if len(desc.enums) != len(desc.EnumType) {
+ g.Fail("internal error: enum nesting failure for", desc.GetName())
+ }
+ }
+ }
+}
+
// Construct the Descriptor
func newDescriptor(desc *descriptor.DescriptorProto, parent *Descriptor, file *descriptor.FileDescriptorProto, index int) *Descriptor {
d := &Descriptor{
@@ -1131,14 +1146,17 @@
// PrintComments prints any comments from the source .proto file.
// The path is a comma-separated list of integers.
+// It returns an indication of whether any comments were printed.
// See descriptor.proto for its format.
-func (g *Generator) PrintComments(path string) {
+func (g *Generator) PrintComments(path string) bool {
if loc, ok := g.file.comments[path]; ok {
text := strings.TrimSuffix(loc.GetLeadingComments(), "\n")
for _, line := range strings.Split(text, "\n") {
g.P("// ", strings.TrimPrefix(line, " "))
}
+ return true
}
+ return false
}
func (g *Generator) fileByName(filename string) *FileDescriptor {
@@ -1164,12 +1182,10 @@
func (g *Generator) generateImports() {
// We almost always need a proto import. Rather than computing when we
// do, which is tricky when there's a plugin, just import it and
- // reference it later. The same argument applies to the math package,
- // for handling bit patterns for floating-point numbers.
+ // reference it later. The same argument applies to the fmt and math packages.
g.P("import " + g.Pkg["proto"] + " " + strconv.Quote(g.ImportPrefix+"github.com/golang/protobuf/proto"))
- if !g.file.proto3 {
- g.P("import " + g.Pkg["math"] + ` "math"`)
- }
+ g.P("import " + g.Pkg["fmt"] + ` "fmt"`)
+ g.P("import " + g.Pkg["math"] + ` "math"`)
for i, s := range g.file.Dependency {
fd := g.fileByName(s)
// Do not import our own package.
@@ -1206,9 +1222,8 @@
}
g.P("// Reference imports to suppress errors if they are not otherwise used.")
g.P("var _ = ", g.Pkg["proto"], ".Marshal")
- if !g.file.proto3 {
- g.P("var _ = ", g.Pkg["math"], ".Inf")
- }
+ g.P("var _ = ", g.Pkg["fmt"], ".Errorf")
+ g.P("var _ = ", g.Pkg["math"], ".Inf")
g.P()
}
@@ -1501,6 +1516,8 @@
typ = "[]" + typ
} else if message != nil && message.proto3() {
return
+ } else if field.OneofIndex != nil && message != nil {
+ return
} else if needsStar(*field.Type) {
typ = "*" + typ
}
@@ -1541,25 +1558,59 @@
}
fieldNames := make(map[*descriptor.FieldDescriptorProto]string)
fieldGetterNames := make(map[*descriptor.FieldDescriptorProto]string)
+ fieldTypes := make(map[*descriptor.FieldDescriptorProto]string)
mapFieldTypes := make(map[*descriptor.FieldDescriptorProto]string)
+ oneofFieldName := make(map[int32]string) // indexed by oneof_index field of FieldDescriptorProto
+ oneofDisc := make(map[int32]string) // name of discriminator method
+ oneofTypeName := make(map[*descriptor.FieldDescriptorProto]string) // without star
+ oneofInsertPoints := make(map[int32]int) // oneof_index => offset of g.Buffer
+
g.PrintComments(message.path)
g.P("type ", ccTypeName, " struct {")
g.In()
- for i, field := range message.Field {
- g.PrintComments(fmt.Sprintf("%s,%d,%d", message.path, messageFieldPath, i))
-
- fieldName := CamelCase(*field.Name)
- for usedNames[fieldName] {
- fieldName += "_"
+ allocName := func(basis string) string {
+ n := CamelCase(basis)
+ for usedNames[n] {
+ n += "_"
}
+ usedNames[n] = true
+ return n
+ }
+
+ for i, field := range message.Field {
+ fieldName := allocName(*field.Name)
fieldGetterName := fieldName
- usedNames[fieldName] = true
typename, wiretype := g.GoType(message, field)
jsonName := *field.Name
tag := fmt.Sprintf("protobuf:%s json:%q", g.goTag(message, field, wiretype), jsonName+",omitempty")
+ fieldNames[field] = fieldName
+ fieldGetterNames[field] = fieldGetterName
+
+ oneof := field.OneofIndex != nil
+ if oneof && oneofFieldName[*field.OneofIndex] == "" {
+ odp := message.OneofDecl[int(*field.OneofIndex)]
+ fname := allocName(odp.GetName())
+
+ // This is the first field of a oneof we haven't seen before.
+ // Generate the union field.
+ com := g.PrintComments(fmt.Sprintf("%s,%d,%d", message.path, messageOneofPath, *field.OneofIndex))
+ if com {
+ g.P("//")
+ }
+ g.P("// Types that are valid to be assigned to ", fname, ":")
+ // Generate the rest of this comment later,
+ // when we've computed any disambiguation.
+ oneofInsertPoints[*field.OneofIndex] = g.Buffer.Len()
+
+ dname := "is" + ccTypeName + "_" + fname
+ oneofFieldName[*field.OneofIndex] = fname
+ oneofDisc[*field.OneofIndex] = dname
+ g.P(fname, " ", dname, " `protobuf_oneof:\"", odp.GetName(), "\"`")
+ }
+
if *field.Type == descriptor.FieldDescriptorProto_TYPE_MESSAGE {
desc := g.ObjectNamed(field.GetTypeName())
if d, ok := desc.(*Descriptor); ok && d.GetOptions().GetMapEntry() {
@@ -1590,8 +1641,38 @@
}
}
- fieldNames[field] = fieldName
- fieldGetterNames[field] = fieldGetterName
+ fieldTypes[field] = typename
+
+ if oneof {
+ tname := ccTypeName + "_" + fieldName
+ // It is possible for this to collide with a message or enum
+ // nested in this message. Check for collisions.
+ for {
+ ok := true
+ for _, desc := range message.nested {
+ if strings.Join(desc.TypeName(), "_") == tname {
+ ok = false
+ break
+ }
+ }
+ for _, enum := range message.enums {
+ if strings.Join(enum.TypeName(), "_") == tname {
+ ok = false
+ break
+ }
+ }
+ if !ok {
+ tname += "_"
+ continue
+ }
+ break
+ }
+
+ oneofTypeName[field] = tname
+ continue
+ }
+
+ g.PrintComments(fmt.Sprintf("%s,%d,%d", message.path, messageFieldPath, i))
g.P(fieldName, "\t", typename, "\t`", tag, "`")
g.RecordTypeUse(field.GetTypeName())
}
@@ -1604,6 +1685,23 @@
g.Out()
g.P("}")
+ // Update g.Buffer to list valid oneof types.
+ // We do this down here, after we've disambiguated the oneof type names.
+ // We go in reverse order of insertion point to avoid invalidating offsets.
+ for oi := int32(len(message.OneofDecl)); oi >= 0; oi-- {
+ ip := oneofInsertPoints[oi]
+ all := g.Buffer.Bytes()
+ rem := all[ip:]
+ g.Buffer = bytes.NewBuffer(all[:ip:ip]) // set cap so we don't scribble on rem
+ for _, field := range message.Field {
+ if field.OneofIndex == nil || *field.OneofIndex != oi {
+ continue
+ }
+ g.P("//\t*", oneofTypeName[field])
+ }
+ g.Buffer.Write(rem)
+ }
+
// Reset, String and ProtoMessage methods.
g.P("func (m *", ccTypeName, ") Reset() { *m = ", ccTypeName, "{} }")
g.P("func (m *", ccTypeName, ") String() string { return ", g.Pkg["proto"], ".CompactTextString(m) }")
@@ -1724,9 +1822,48 @@
}
g.P()
+ // Oneof per-field types, discriminants and getters.
+ //
+ // Generate unexported named types for the discriminant interfaces.
+ // We shouldn't have to do this, but there was (~19 Aug 2015) a compiler/linker bug
+ // that was triggered by using anonymous interfaces here.
+ // TODO: Revisit this and consider reverting back to anonymous interfaces.
+ for oi := range message.OneofDecl {
+ dname := oneofDisc[int32(oi)]
+ g.P("type ", dname, " interface { ", dname, "() }")
+ }
+ g.P()
+ for _, field := range message.Field {
+ if field.OneofIndex == nil {
+ continue
+ }
+ _, wiretype := g.GoType(message, field)
+ tag := "protobuf:" + g.goTag(message, field, wiretype)
+ g.P("type ", oneofTypeName[field], " struct{ ", fieldNames[field], " ", fieldTypes[field], " `", tag, "` }")
+ g.RecordTypeUse(field.GetTypeName())
+ }
+ g.P()
+ for _, field := range message.Field {
+ if field.OneofIndex == nil {
+ continue
+ }
+ g.P("func (*", oneofTypeName[field], ") ", oneofDisc[*field.OneofIndex], "() {}")
+ }
+ g.P()
+ for oi := range message.OneofDecl {
+ fname := oneofFieldName[int32(oi)]
+ g.P("func (m *", ccTypeName, ") Get", fname, "() ", oneofDisc[int32(oi)], " {")
+ g.P("if m != nil { return m.", fname, " }")
+ g.P("return nil")
+ g.P("}")
+ }
+ g.P()
+
// Field getters
var getters []getterSymbol
for _, field := range message.Field {
+ oneof := field.OneofIndex != nil
+
fname := fieldNames[field]
typename, _ := g.GoType(message, field)
if t, ok := mapFieldTypes[field]; ok {
@@ -1739,8 +1876,8 @@
star = "*"
}
- // In proto3, only generate getters for message fields.
- if message.proto3() && *field.Type != descriptor.FieldDescriptorProto_TYPE_MESSAGE {
+ // In proto3, only generate getters for message fields and oneof fields.
+ if message.proto3() && *field.Type != descriptor.FieldDescriptorProto_TYPE_MESSAGE && !oneof {
continue
}
@@ -1787,7 +1924,7 @@
if isRepeated(field) {
typeDefaultIsNil = true
}
- if typeDefaultIsNil {
+ if typeDefaultIsNil && !oneof {
// A bytes field with no explicit default needs less generated code,
// as does a message or group field, or a repeated field.
g.P("if m != nil {")
@@ -1801,11 +1938,19 @@
g.P()
continue
}
- g.P("if m != nil && m." + fname + " != nil {")
- g.In()
- g.P("return " + star + "m." + fname)
- g.Out()
- g.P("}")
+ if !oneof {
+ g.P("if m != nil && m." + fname + " != nil {")
+ g.In()
+ g.P("return " + star + "m." + fname)
+ g.Out()
+ g.P("}")
+ } else {
+ uname := oneofFieldName[*field.OneofIndex]
+ tname := oneofTypeName[field]
+ g.P("if x, ok := m.Get", uname, "().(*", tname, "); ok {")
+ g.P("return x.", fname)
+ g.P("}")
+ }
if hasDef {
if *field.Type != descriptor.FieldDescriptorProto_TYPE_BYTES {
g.P("return " + def)
@@ -1820,6 +1965,11 @@
g.P("return false")
case descriptor.FieldDescriptorProto_TYPE_STRING:
g.P(`return ""`)
+ case descriptor.FieldDescriptorProto_TYPE_GROUP,
+ descriptor.FieldDescriptorProto_TYPE_MESSAGE,
+ descriptor.FieldDescriptorProto_TYPE_BYTES:
+ // This is only possible for oneof fields.
+ g.P("return nil")
case descriptor.FieldDescriptorProto_TYPE_ENUM:
// The default default for an enum is the first value in the enum,
// not zero.
@@ -1855,6 +2005,207 @@
g.file.addExport(message, ms)
}
+ // Oneof functions
+ if len(message.OneofDecl) > 0 {
+ fieldWire := make(map[*descriptor.FieldDescriptorProto]string)
+
+ // method
+ enc := "_" + ccTypeName + "_OneofMarshaler"
+ dec := "_" + ccTypeName + "_OneofUnmarshaler"
+ encSig := "(msg " + g.Pkg["proto"] + ".Message, b *" + g.Pkg["proto"] + ".Buffer) error"
+ decSig := "(msg " + g.Pkg["proto"] + ".Message, tag, wire int, b *" + g.Pkg["proto"] + ".Buffer) (bool, error)"
+
+ g.P("// XXX_OneofFuncs is for the internal use of the proto package.")
+ g.P("func (*", ccTypeName, ") XXX_OneofFuncs() (func", encSig, ", func", decSig, ", []interface{}) {")
+ g.P("return ", enc, ", ", dec, ", []interface{}{")
+ for _, field := range message.Field {
+ if field.OneofIndex == nil {
+ continue
+ }
+ g.P("(*", oneofTypeName[field], ")(nil),")
+ }
+ g.P("}")
+ g.P("}")
+ g.P()
+
+ // marshaler
+ g.P("func ", enc, encSig, " {")
+ g.P("m := msg.(*", ccTypeName, ")")
+ for oi, odp := range message.OneofDecl {
+ g.P("// ", odp.GetName())
+ fname := oneofFieldName[int32(oi)]
+ g.P("switch x := m.", fname, ".(type) {")
+ for _, field := range message.Field {
+ if field.OneofIndex == nil || int(*field.OneofIndex) != oi {
+ continue
+ }
+ g.P("case *", oneofTypeName[field], ":")
+ var wire, pre, post string
+ val := "x." + fieldNames[field] // overridden for TYPE_BOOL
+ canFail := false // only TYPE_MESSAGE and TYPE_GROUP can fail
+ switch *field.Type {
+ case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
+ wire = "WireFixed64"
+ pre = "b.EncodeFixed64(" + g.Pkg["math"] + ".Float64bits("
+ post = "))"
+ case descriptor.FieldDescriptorProto_TYPE_FLOAT:
+ wire = "WireFixed32"
+ pre = "b.EncodeFixed32(uint64(" + g.Pkg["math"] + ".Float32bits("
+ post = ")))"
+ case descriptor.FieldDescriptorProto_TYPE_INT64,
+ descriptor.FieldDescriptorProto_TYPE_UINT64:
+ wire = "WireVarint"
+ pre, post = "b.EncodeVarint(uint64(", "))"
+ case descriptor.FieldDescriptorProto_TYPE_INT32,
+ descriptor.FieldDescriptorProto_TYPE_UINT32,
+ descriptor.FieldDescriptorProto_TYPE_ENUM:
+ wire = "WireVarint"
+ pre, post = "b.EncodeVarint(uint64(", "))"
+ case descriptor.FieldDescriptorProto_TYPE_FIXED64,
+ descriptor.FieldDescriptorProto_TYPE_SFIXED64:
+ wire = "WireFixed64"
+ pre, post = "b.EncodeFixed64(uint64(", "))"
+ case descriptor.FieldDescriptorProto_TYPE_FIXED32,
+ descriptor.FieldDescriptorProto_TYPE_SFIXED32:
+ wire = "WireFixed32"
+ pre, post = "b.EncodeFixed32(uint64(", "))"
+ case descriptor.FieldDescriptorProto_TYPE_BOOL:
+ // bool needs special handling.
+ g.P("t := uint64(0)")
+ g.P("if ", val, " { t = 1 }")
+ val = "t"
+ wire = "WireVarint"
+ pre, post = "b.EncodeVarint(", ")"
+ case descriptor.FieldDescriptorProto_TYPE_STRING:
+ wire = "WireBytes"
+ pre, post = "b.EncodeStringBytes(", ")"
+ case descriptor.FieldDescriptorProto_TYPE_GROUP:
+ wire = "WireStartGroup"
+ pre, post = "b.Marshal(", ")"
+ canFail = true
+ case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
+ wire = "WireBytes"
+ pre, post = "b.EncodeMessage(", ")"
+ canFail = true
+ case descriptor.FieldDescriptorProto_TYPE_BYTES:
+ wire = "WireBytes"
+ pre, post = "b.EncodeRawBytes(", ")"
+ case descriptor.FieldDescriptorProto_TYPE_SINT32:
+ wire = "WireVarint"
+ pre, post = "b.EncodeZigzag32(uint64(", "))"
+ case descriptor.FieldDescriptorProto_TYPE_SINT64:
+ wire = "WireVarint"
+ pre, post = "b.EncodeZigzag64(uint64(", "))"
+ default:
+ g.Fail("unhandled oneof field type ", field.Type.String())
+ }
+ fieldWire[field] = wire
+ g.P("b.EncodeVarint(", field.Number, "<<3|", g.Pkg["proto"], ".", wire, ")")
+ if !canFail {
+ g.P(pre, val, post)
+ } else {
+ g.P("if err := ", pre, val, post, "; err != nil {")
+ g.P("return err")
+ g.P("}")
+ }
+ if *field.Type == descriptor.FieldDescriptorProto_TYPE_GROUP {
+ g.P("b.EncodeVarint(", field.Number, "<<3|", g.Pkg["proto"], ".WireEndGroup)")
+ }
+ }
+ g.P("case nil:")
+ g.P("default: return ", g.Pkg["fmt"], `.Errorf("`, ccTypeName, ".", fname, ` has unexpected type %T", x)`)
+ g.P("}")
+ }
+ g.P("return nil")
+ g.P("}")
+ g.P()
+
+ // unmarshaler
+ g.P("func ", dec, decSig, " {")
+ g.P("m := msg.(*", ccTypeName, ")")
+ g.P("switch tag {")
+ for _, field := range message.Field {
+ if field.OneofIndex == nil {
+ continue
+ }
+ odp := message.OneofDecl[int(*field.OneofIndex)]
+ g.P("case ", field.Number, ": // ", odp.GetName(), ".", *field.Name)
+ g.P("if wire != ", g.Pkg["proto"], ".", fieldWire[field], " {")
+ g.P("return true, ", g.Pkg["proto"], ".ErrInternalBadWireType")
+ g.P("}")
+ lhs := "x, err" // overridden for TYPE_MESSAGE and TYPE_GROUP
+ var dec, cast, cast2 string
+ switch *field.Type {
+ case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
+ dec, cast = "b.DecodeFixed64()", g.Pkg["math"]+".Float64frombits"
+ case descriptor.FieldDescriptorProto_TYPE_FLOAT:
+ dec, cast, cast2 = "b.DecodeFixed32()", "uint32", g.Pkg["math"]+".Float32frombits"
+ case descriptor.FieldDescriptorProto_TYPE_INT64:
+ dec, cast = "b.DecodeVarint()", "int64"
+ case descriptor.FieldDescriptorProto_TYPE_UINT64:
+ dec = "b.DecodeVarint()"
+ case descriptor.FieldDescriptorProto_TYPE_INT32:
+ dec, cast = "b.DecodeVarint()", "int32"
+ case descriptor.FieldDescriptorProto_TYPE_FIXED64:
+ dec = "b.DecodeFixed64()"
+ case descriptor.FieldDescriptorProto_TYPE_FIXED32:
+ dec, cast = "b.DecodeFixed32()", "uint32"
+ case descriptor.FieldDescriptorProto_TYPE_BOOL:
+ dec = "b.DecodeVarint()"
+ // handled specially below
+ case descriptor.FieldDescriptorProto_TYPE_STRING:
+ dec = "b.DecodeStringBytes()"
+ case descriptor.FieldDescriptorProto_TYPE_GROUP:
+ g.P("msg := new(", fieldTypes[field][1:], ")") // drop star
+ lhs = "err"
+ dec = "b.DecodeGroup(msg)"
+ // handled specially below
+ case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
+ g.P("msg := new(", fieldTypes[field][1:], ")") // drop star
+ lhs = "err"
+ dec = "b.DecodeMessage(msg)"
+ // handled specially below
+ case descriptor.FieldDescriptorProto_TYPE_BYTES:
+ dec = "b.DecodeRawBytes(true)"
+ case descriptor.FieldDescriptorProto_TYPE_UINT32:
+ dec, cast = "b.DecodeVarint()", "uint32"
+ case descriptor.FieldDescriptorProto_TYPE_ENUM:
+ dec, cast = "b.DecodeVarint()", fieldTypes[field]
+ case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
+ dec, cast = "b.DecodeFixed32()", "int32"
+ case descriptor.FieldDescriptorProto_TYPE_SFIXED64:
+ dec, cast = "b.DecodeFixed64()", "int64"
+ case descriptor.FieldDescriptorProto_TYPE_SINT32:
+ dec, cast = "b.DecodeZigzag32()", "int32"
+ case descriptor.FieldDescriptorProto_TYPE_SINT64:
+ dec, cast = "b.DecodeZigzag64()", "int64"
+ default:
+ g.Fail("unhandled oneof field type ", field.Type.String())
+ }
+ g.P(lhs, " := ", dec)
+ val := "x"
+ if cast != "" {
+ val = cast + "(" + val + ")"
+ }
+ if cast2 != "" {
+ val = cast2 + "(" + val + ")"
+ }
+ switch *field.Type {
+ case descriptor.FieldDescriptorProto_TYPE_BOOL:
+ val += " != 0"
+ case descriptor.FieldDescriptorProto_TYPE_GROUP,
+ descriptor.FieldDescriptorProto_TYPE_MESSAGE:
+ val = "msg"
+ }
+ g.P("m.", oneofFieldName[*field.OneofIndex], " = &", oneofTypeName[field], "{", val, "}")
+ g.P("return true, err")
+ }
+ g.P("default: return false, nil")
+ g.P("}")
+ g.P("}")
+ g.P()
+ }
+
for _, ext := range message.ext {
g.generateExtension(ext)
}
@@ -2093,6 +2444,7 @@
messageFieldPath = 2 // field
messageMessagePath = 3 // nested_type
messageEnumPath = 4 // enum_type
+ messageOneofPath = 8 // oneof_decl
// tag numbers in EnumDescriptorProto
enumValuePath = 2 // value
)
diff --git a/protoc-gen-go/plugin/plugin.pb.go b/protoc-gen-go/plugin/plugin.pb.go
index 31ff1f5..52bff88 100644
--- a/protoc-gen-go/plugin/plugin.pb.go
+++ b/protoc-gen-go/plugin/plugin.pb.go
@@ -15,11 +15,13 @@
package google_protobuf_compiler
import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
import math "math"
import google_protobuf "github.com/golang/protobuf/protoc-gen-go/descriptor"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
+var _ = fmt.Errorf
var _ = math.Inf
// An encoded CodeGeneratorRequest is written to the plugin's stdin.
@@ -184,6 +186,3 @@
}
return ""
}
-
-func init() {
-}
diff --git a/protoc-gen-go/testdata/my_test/test.pb.go b/protoc-gen-go/testdata/my_test/test.pb.go
index 4500414..d7ea8c8 100644
--- a/protoc-gen-go/testdata/my_test/test.pb.go
+++ b/protoc-gen-go/testdata/my_test/test.pb.go
@@ -17,16 +17,19 @@
ReplyExtensions
OtherReplyExtensions
OldReply
+ Communique
*/
package my_test
import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
import math "math"
// discarding unused import multitest2 "multi"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
+var _ = fmt.Errorf
var _ = math.Inf
type HatType int32
@@ -458,6 +461,329 @@
return m.XXX_extensions
}
+type Communique struct {
+ MakeMeCry *bool `protobuf:"varint,1,opt,name=make_me_cry" json:"make_me_cry,omitempty"`
+ // This is a oneof, called "union".
+ //
+ // Types that are valid to be assigned to Union:
+ // *Communique_Number
+ // *Communique_Name
+ // *Communique_Data
+ // *Communique_TempC
+ // *Communique_Height
+ // *Communique_Today
+ // *Communique_Maybe
+ // *Communique_Delta_
+ // *Communique_Msg
+ // *Communique_Somegroup
+ Union isCommunique_Union `protobuf_oneof:"union"`
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *Communique) Reset() { *m = Communique{} }
+func (m *Communique) String() string { return proto.CompactTextString(m) }
+func (*Communique) ProtoMessage() {}
+
+type isCommunique_Union interface {
+ isCommunique_Union()
+}
+
+type Communique_Number struct {
+ Number int32 `protobuf:"varint,5,opt,name=number"`
+}
+type Communique_Name struct {
+ Name string `protobuf:"bytes,6,opt,name=name"`
+}
+type Communique_Data struct {
+ Data []byte `protobuf:"bytes,7,opt,name=data"`
+}
+type Communique_TempC struct {
+ TempC float64 `protobuf:"fixed64,8,opt,name=temp_c"`
+}
+type Communique_Height struct {
+ Height float32 `protobuf:"fixed32,9,opt,name=height"`
+}
+type Communique_Today struct {
+ Today Days `protobuf:"varint,10,opt,name=today,enum=my.test.Days"`
+}
+type Communique_Maybe struct {
+ Maybe bool `protobuf:"varint,11,opt,name=maybe"`
+}
+type Communique_Delta_ struct {
+ Delta int32 `protobuf:"zigzag32,12,opt,name=delta"`
+}
+type Communique_Msg struct {
+ Msg *Reply `protobuf:"bytes,13,opt,name=msg"`
+}
+type Communique_Somegroup struct {
+ Somegroup *Communique_SomeGroup `protobuf:"group,14,opt,name=SomeGroup"`
+}
+
+func (*Communique_Number) isCommunique_Union() {}
+func (*Communique_Name) isCommunique_Union() {}
+func (*Communique_Data) isCommunique_Union() {}
+func (*Communique_TempC) isCommunique_Union() {}
+func (*Communique_Height) isCommunique_Union() {}
+func (*Communique_Today) isCommunique_Union() {}
+func (*Communique_Maybe) isCommunique_Union() {}
+func (*Communique_Delta_) isCommunique_Union() {}
+func (*Communique_Msg) isCommunique_Union() {}
+func (*Communique_Somegroup) isCommunique_Union() {}
+
+func (m *Communique) GetUnion() isCommunique_Union {
+ if m != nil {
+ return m.Union
+ }
+ return nil
+}
+
+func (m *Communique) GetMakeMeCry() bool {
+ if m != nil && m.MakeMeCry != nil {
+ return *m.MakeMeCry
+ }
+ return false
+}
+
+func (m *Communique) GetNumber() int32 {
+ if x, ok := m.GetUnion().(*Communique_Number); ok {
+ return x.Number
+ }
+ return 0
+}
+
+func (m *Communique) GetName() string {
+ if x, ok := m.GetUnion().(*Communique_Name); ok {
+ return x.Name
+ }
+ return ""
+}
+
+func (m *Communique) GetData() []byte {
+ if x, ok := m.GetUnion().(*Communique_Data); ok {
+ return x.Data
+ }
+ return nil
+}
+
+func (m *Communique) GetTempC() float64 {
+ if x, ok := m.GetUnion().(*Communique_TempC); ok {
+ return x.TempC
+ }
+ return 0
+}
+
+func (m *Communique) GetHeight() float32 {
+ if x, ok := m.GetUnion().(*Communique_Height); ok {
+ return x.Height
+ }
+ return 0
+}
+
+func (m *Communique) GetToday() Days {
+ if x, ok := m.GetUnion().(*Communique_Today); ok {
+ return x.Today
+ }
+ return Days_MONDAY
+}
+
+func (m *Communique) GetMaybe() bool {
+ if x, ok := m.GetUnion().(*Communique_Maybe); ok {
+ return x.Maybe
+ }
+ return false
+}
+
+func (m *Communique) GetDelta() int32 {
+ if x, ok := m.GetUnion().(*Communique_Delta_); ok {
+ return x.Delta
+ }
+ return 0
+}
+
+func (m *Communique) GetMsg() *Reply {
+ if x, ok := m.GetUnion().(*Communique_Msg); ok {
+ return x.Msg
+ }
+ return nil
+}
+
+func (m *Communique) GetSomegroup() *Communique_SomeGroup {
+ if x, ok := m.GetUnion().(*Communique_Somegroup); ok {
+ return x.Somegroup
+ }
+ return nil
+}
+
+// XXX_OneofFuncs is for the internal use of the proto package.
+func (*Communique) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), []interface{}) {
+ return _Communique_OneofMarshaler, _Communique_OneofUnmarshaler, []interface{}{
+ (*Communique_Number)(nil),
+ (*Communique_Name)(nil),
+ (*Communique_Data)(nil),
+ (*Communique_TempC)(nil),
+ (*Communique_Height)(nil),
+ (*Communique_Today)(nil),
+ (*Communique_Maybe)(nil),
+ (*Communique_Delta_)(nil),
+ (*Communique_Msg)(nil),
+ (*Communique_Somegroup)(nil),
+ }
+}
+
+func _Communique_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
+ m := msg.(*Communique)
+ // union
+ switch x := m.Union.(type) {
+ case *Communique_Number:
+ b.EncodeVarint(5<<3 | proto.WireVarint)
+ b.EncodeVarint(uint64(x.Number))
+ case *Communique_Name:
+ b.EncodeVarint(6<<3 | proto.WireBytes)
+ b.EncodeStringBytes(x.Name)
+ case *Communique_Data:
+ b.EncodeVarint(7<<3 | proto.WireBytes)
+ b.EncodeRawBytes(x.Data)
+ case *Communique_TempC:
+ b.EncodeVarint(8<<3 | proto.WireFixed64)
+ b.EncodeFixed64(math.Float64bits(x.TempC))
+ case *Communique_Height:
+ b.EncodeVarint(9<<3 | proto.WireFixed32)
+ b.EncodeFixed32(uint64(math.Float32bits(x.Height)))
+ case *Communique_Today:
+ b.EncodeVarint(10<<3 | proto.WireVarint)
+ b.EncodeVarint(uint64(x.Today))
+ case *Communique_Maybe:
+ t := uint64(0)
+ if x.Maybe {
+ t = 1
+ }
+ b.EncodeVarint(11<<3 | proto.WireVarint)
+ b.EncodeVarint(t)
+ case *Communique_Delta_:
+ b.EncodeVarint(12<<3 | proto.WireVarint)
+ b.EncodeZigzag32(uint64(x.Delta))
+ case *Communique_Msg:
+ b.EncodeVarint(13<<3 | proto.WireBytes)
+ if err := b.EncodeMessage(x.Msg); err != nil {
+ return err
+ }
+ case *Communique_Somegroup:
+ b.EncodeVarint(14<<3 | proto.WireStartGroup)
+ if err := b.Marshal(x.Somegroup); err != nil {
+ return err
+ }
+ b.EncodeVarint(14<<3 | proto.WireEndGroup)
+ case nil:
+ default:
+ return fmt.Errorf("Communique.Union has unexpected type %T", x)
+ }
+ return nil
+}
+
+func _Communique_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
+ m := msg.(*Communique)
+ switch tag {
+ case 5: // union.number
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeVarint()
+ m.Union = &Communique_Number{int32(x)}
+ return true, err
+ case 6: // union.name
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeStringBytes()
+ m.Union = &Communique_Name{x}
+ return true, err
+ case 7: // union.data
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeRawBytes(true)
+ m.Union = &Communique_Data{x}
+ return true, err
+ case 8: // union.temp_c
+ if wire != proto.WireFixed64 {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeFixed64()
+ m.Union = &Communique_TempC{math.Float64frombits(x)}
+ return true, err
+ case 9: // union.height
+ if wire != proto.WireFixed32 {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeFixed32()
+ m.Union = &Communique_Height{math.Float32frombits(uint32(x))}
+ return true, err
+ case 10: // union.today
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeVarint()
+ m.Union = &Communique_Today{Days(x)}
+ return true, err
+ case 11: // union.maybe
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeVarint()
+ m.Union = &Communique_Maybe{x != 0}
+ return true, err
+ case 12: // union.delta
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeZigzag32()
+ m.Union = &Communique_Delta_{int32(x)}
+ return true, err
+ case 13: // union.msg
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ msg := new(Reply)
+ err := b.DecodeMessage(msg)
+ m.Union = &Communique_Msg{msg}
+ return true, err
+ case 14: // union.somegroup
+ if wire != proto.WireStartGroup {
+ return true, proto.ErrInternalBadWireType
+ }
+ msg := new(Communique_SomeGroup)
+ err := b.DecodeGroup(msg)
+ m.Union = &Communique_Somegroup{msg}
+ return true, err
+ default:
+ return false, nil
+ }
+}
+
+type Communique_SomeGroup struct {
+ Member *string `protobuf:"bytes,15,opt,name=member" json:"member,omitempty"`
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *Communique_SomeGroup) Reset() { *m = Communique_SomeGroup{} }
+func (m *Communique_SomeGroup) String() string { return proto.CompactTextString(m) }
+func (*Communique_SomeGroup) ProtoMessage() {}
+
+func (m *Communique_SomeGroup) GetMember() string {
+ if m != nil && m.Member != nil {
+ return *m.Member
+ }
+ return ""
+}
+
+type Communique_Delta struct {
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *Communique_Delta) Reset() { *m = Communique_Delta{} }
+func (m *Communique_Delta) String() string { return proto.CompactTextString(m) }
+func (*Communique_Delta) ProtoMessage() {}
+
var E_Tag = &proto.ExtensionDesc{
ExtendedType: (*Reply)(nil),
ExtensionType: (*string)(nil),
diff --git a/protoc-gen-go/testdata/my_test/test.pb.go.golden b/protoc-gen-go/testdata/my_test/test.pb.go.golden
index 4500414..d7ea8c8 100644
--- a/protoc-gen-go/testdata/my_test/test.pb.go.golden
+++ b/protoc-gen-go/testdata/my_test/test.pb.go.golden
@@ -17,16 +17,19 @@
ReplyExtensions
OtherReplyExtensions
OldReply
+ Communique
*/
package my_test
import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
import math "math"
// discarding unused import multitest2 "multi"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
+var _ = fmt.Errorf
var _ = math.Inf
type HatType int32
@@ -458,6 +461,329 @@
return m.XXX_extensions
}
+type Communique struct {
+ MakeMeCry *bool `protobuf:"varint,1,opt,name=make_me_cry" json:"make_me_cry,omitempty"`
+ // This is a oneof, called "union".
+ //
+ // Types that are valid to be assigned to Union:
+ // *Communique_Number
+ // *Communique_Name
+ // *Communique_Data
+ // *Communique_TempC
+ // *Communique_Height
+ // *Communique_Today
+ // *Communique_Maybe
+ // *Communique_Delta_
+ // *Communique_Msg
+ // *Communique_Somegroup
+ Union isCommunique_Union `protobuf_oneof:"union"`
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *Communique) Reset() { *m = Communique{} }
+func (m *Communique) String() string { return proto.CompactTextString(m) }
+func (*Communique) ProtoMessage() {}
+
+type isCommunique_Union interface {
+ isCommunique_Union()
+}
+
+type Communique_Number struct {
+ Number int32 `protobuf:"varint,5,opt,name=number"`
+}
+type Communique_Name struct {
+ Name string `protobuf:"bytes,6,opt,name=name"`
+}
+type Communique_Data struct {
+ Data []byte `protobuf:"bytes,7,opt,name=data"`
+}
+type Communique_TempC struct {
+ TempC float64 `protobuf:"fixed64,8,opt,name=temp_c"`
+}
+type Communique_Height struct {
+ Height float32 `protobuf:"fixed32,9,opt,name=height"`
+}
+type Communique_Today struct {
+ Today Days `protobuf:"varint,10,opt,name=today,enum=my.test.Days"`
+}
+type Communique_Maybe struct {
+ Maybe bool `protobuf:"varint,11,opt,name=maybe"`
+}
+type Communique_Delta_ struct {
+ Delta int32 `protobuf:"zigzag32,12,opt,name=delta"`
+}
+type Communique_Msg struct {
+ Msg *Reply `protobuf:"bytes,13,opt,name=msg"`
+}
+type Communique_Somegroup struct {
+ Somegroup *Communique_SomeGroup `protobuf:"group,14,opt,name=SomeGroup"`
+}
+
+func (*Communique_Number) isCommunique_Union() {}
+func (*Communique_Name) isCommunique_Union() {}
+func (*Communique_Data) isCommunique_Union() {}
+func (*Communique_TempC) isCommunique_Union() {}
+func (*Communique_Height) isCommunique_Union() {}
+func (*Communique_Today) isCommunique_Union() {}
+func (*Communique_Maybe) isCommunique_Union() {}
+func (*Communique_Delta_) isCommunique_Union() {}
+func (*Communique_Msg) isCommunique_Union() {}
+func (*Communique_Somegroup) isCommunique_Union() {}
+
+func (m *Communique) GetUnion() isCommunique_Union {
+ if m != nil {
+ return m.Union
+ }
+ return nil
+}
+
+func (m *Communique) GetMakeMeCry() bool {
+ if m != nil && m.MakeMeCry != nil {
+ return *m.MakeMeCry
+ }
+ return false
+}
+
+func (m *Communique) GetNumber() int32 {
+ if x, ok := m.GetUnion().(*Communique_Number); ok {
+ return x.Number
+ }
+ return 0
+}
+
+func (m *Communique) GetName() string {
+ if x, ok := m.GetUnion().(*Communique_Name); ok {
+ return x.Name
+ }
+ return ""
+}
+
+func (m *Communique) GetData() []byte {
+ if x, ok := m.GetUnion().(*Communique_Data); ok {
+ return x.Data
+ }
+ return nil
+}
+
+func (m *Communique) GetTempC() float64 {
+ if x, ok := m.GetUnion().(*Communique_TempC); ok {
+ return x.TempC
+ }
+ return 0
+}
+
+func (m *Communique) GetHeight() float32 {
+ if x, ok := m.GetUnion().(*Communique_Height); ok {
+ return x.Height
+ }
+ return 0
+}
+
+func (m *Communique) GetToday() Days {
+ if x, ok := m.GetUnion().(*Communique_Today); ok {
+ return x.Today
+ }
+ return Days_MONDAY
+}
+
+func (m *Communique) GetMaybe() bool {
+ if x, ok := m.GetUnion().(*Communique_Maybe); ok {
+ return x.Maybe
+ }
+ return false
+}
+
+func (m *Communique) GetDelta() int32 {
+ if x, ok := m.GetUnion().(*Communique_Delta_); ok {
+ return x.Delta
+ }
+ return 0
+}
+
+func (m *Communique) GetMsg() *Reply {
+ if x, ok := m.GetUnion().(*Communique_Msg); ok {
+ return x.Msg
+ }
+ return nil
+}
+
+func (m *Communique) GetSomegroup() *Communique_SomeGroup {
+ if x, ok := m.GetUnion().(*Communique_Somegroup); ok {
+ return x.Somegroup
+ }
+ return nil
+}
+
+// XXX_OneofFuncs is for the internal use of the proto package.
+func (*Communique) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), []interface{}) {
+ return _Communique_OneofMarshaler, _Communique_OneofUnmarshaler, []interface{}{
+ (*Communique_Number)(nil),
+ (*Communique_Name)(nil),
+ (*Communique_Data)(nil),
+ (*Communique_TempC)(nil),
+ (*Communique_Height)(nil),
+ (*Communique_Today)(nil),
+ (*Communique_Maybe)(nil),
+ (*Communique_Delta_)(nil),
+ (*Communique_Msg)(nil),
+ (*Communique_Somegroup)(nil),
+ }
+}
+
+func _Communique_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
+ m := msg.(*Communique)
+ // union
+ switch x := m.Union.(type) {
+ case *Communique_Number:
+ b.EncodeVarint(5<<3 | proto.WireVarint)
+ b.EncodeVarint(uint64(x.Number))
+ case *Communique_Name:
+ b.EncodeVarint(6<<3 | proto.WireBytes)
+ b.EncodeStringBytes(x.Name)
+ case *Communique_Data:
+ b.EncodeVarint(7<<3 | proto.WireBytes)
+ b.EncodeRawBytes(x.Data)
+ case *Communique_TempC:
+ b.EncodeVarint(8<<3 | proto.WireFixed64)
+ b.EncodeFixed64(math.Float64bits(x.TempC))
+ case *Communique_Height:
+ b.EncodeVarint(9<<3 | proto.WireFixed32)
+ b.EncodeFixed32(uint64(math.Float32bits(x.Height)))
+ case *Communique_Today:
+ b.EncodeVarint(10<<3 | proto.WireVarint)
+ b.EncodeVarint(uint64(x.Today))
+ case *Communique_Maybe:
+ t := uint64(0)
+ if x.Maybe {
+ t = 1
+ }
+ b.EncodeVarint(11<<3 | proto.WireVarint)
+ b.EncodeVarint(t)
+ case *Communique_Delta_:
+ b.EncodeVarint(12<<3 | proto.WireVarint)
+ b.EncodeZigzag32(uint64(x.Delta))
+ case *Communique_Msg:
+ b.EncodeVarint(13<<3 | proto.WireBytes)
+ if err := b.EncodeMessage(x.Msg); err != nil {
+ return err
+ }
+ case *Communique_Somegroup:
+ b.EncodeVarint(14<<3 | proto.WireStartGroup)
+ if err := b.Marshal(x.Somegroup); err != nil {
+ return err
+ }
+ b.EncodeVarint(14<<3 | proto.WireEndGroup)
+ case nil:
+ default:
+ return fmt.Errorf("Communique.Union has unexpected type %T", x)
+ }
+ return nil
+}
+
+func _Communique_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
+ m := msg.(*Communique)
+ switch tag {
+ case 5: // union.number
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeVarint()
+ m.Union = &Communique_Number{int32(x)}
+ return true, err
+ case 6: // union.name
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeStringBytes()
+ m.Union = &Communique_Name{x}
+ return true, err
+ case 7: // union.data
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeRawBytes(true)
+ m.Union = &Communique_Data{x}
+ return true, err
+ case 8: // union.temp_c
+ if wire != proto.WireFixed64 {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeFixed64()
+ m.Union = &Communique_TempC{math.Float64frombits(x)}
+ return true, err
+ case 9: // union.height
+ if wire != proto.WireFixed32 {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeFixed32()
+ m.Union = &Communique_Height{math.Float32frombits(uint32(x))}
+ return true, err
+ case 10: // union.today
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeVarint()
+ m.Union = &Communique_Today{Days(x)}
+ return true, err
+ case 11: // union.maybe
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeVarint()
+ m.Union = &Communique_Maybe{x != 0}
+ return true, err
+ case 12: // union.delta
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeZigzag32()
+ m.Union = &Communique_Delta_{int32(x)}
+ return true, err
+ case 13: // union.msg
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ msg := new(Reply)
+ err := b.DecodeMessage(msg)
+ m.Union = &Communique_Msg{msg}
+ return true, err
+ case 14: // union.somegroup
+ if wire != proto.WireStartGroup {
+ return true, proto.ErrInternalBadWireType
+ }
+ msg := new(Communique_SomeGroup)
+ err := b.DecodeGroup(msg)
+ m.Union = &Communique_Somegroup{msg}
+ return true, err
+ default:
+ return false, nil
+ }
+}
+
+type Communique_SomeGroup struct {
+ Member *string `protobuf:"bytes,15,opt,name=member" json:"member,omitempty"`
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *Communique_SomeGroup) Reset() { *m = Communique_SomeGroup{} }
+func (m *Communique_SomeGroup) String() string { return proto.CompactTextString(m) }
+func (*Communique_SomeGroup) ProtoMessage() {}
+
+func (m *Communique_SomeGroup) GetMember() string {
+ if m != nil && m.Member != nil {
+ return *m.Member
+ }
+ return ""
+}
+
+type Communique_Delta struct {
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *Communique_Delta) Reset() { *m = Communique_Delta{} }
+func (m *Communique_Delta) String() string { return proto.CompactTextString(m) }
+func (*Communique_Delta) ProtoMessage() {}
+
var E_Tag = &proto.ExtensionDesc{
ExtendedType: (*Reply)(nil),
ExtensionType: (*string)(nil),
diff --git a/protoc-gen-go/testdata/my_test/test.proto b/protoc-gen-go/testdata/my_test/test.proto
index 975520f..9acd4ce 100644
--- a/protoc-gen-go/testdata/my_test/test.proto
+++ b/protoc-gen-go/testdata/my_test/test.proto
@@ -130,3 +130,25 @@
extensions 100 to max;
}
+message Communique {
+ optional bool make_me_cry = 1;
+
+ // This is a oneof, called "union".
+ oneof union {
+ int32 number = 5;
+ string name = 6;
+ bytes data = 7;
+ double temp_c = 8;
+ float height = 9;
+ Days today = 10;
+ bool maybe = 11;
+ sint32 delta = 12; // name will conflict with Delta below
+ Reply msg = 13;
+ group SomeGroup = 14 {
+ optional string member = 15;
+ }
+ }
+
+ message Delta {}
+}
+