Implement oneof support.
This includes the code generation changes
and the infrastructure that wires oneofs into the encode/decode machinery.
The overall API changes are these:
- oneofs in a message are replaced by a single interface field
- each field in a oneof gets a distinguished type that satisfies
the corresponding interface
- a type switch may be used to distinguish between oneof fields
Fixes #29.
diff --git a/proto/all_test.go b/proto/all_test.go
index b787d58..49ce01f 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -1958,6 +1958,58 @@
}
}
+func TestOneof(t *testing.T) {
+ m := &Communique{} // no oneof member set
+ b, err := Marshal(m)
+ if err != nil {
+ t.Fatalf("Marshal of empty message with oneof: %v", err)
+ }
+ if len(b) != 0 {
+ t.Errorf("Marshal of empty message yielded too many bytes: %v", b)
+ }
+
+ m = &Communique{
+ Union: &Communique_Name{"Barry"},
+ }
+
+ // Round-trip a string-valued oneof member through Marshal and Unmarshal.
+ b, err = Marshal(m)
+ if err != nil {
+ t.Fatalf("Marshal of message with oneof: %v", err)
+ }
+ if len(b) != 7 { // name tag/wire (1) + name len (1) + name (5)
+ t.Errorf("Incorrect marshal of message with oneof: %v", b)
+ }
+ m.Reset()
+ if err := Unmarshal(b, m); err != nil {
+ t.Fatalf("Unmarshal of message with oneof: %v", err)
+ }
+ if x, ok := m.Union.(*Communique_Name); !ok || x.Name != "Barry" {
+ t.Errorf("After round trip, Union = %+v", m.Union)
+ }
+ if name := m.GetName(); name != "Barry" {
+ t.Errorf("After round trip, GetName = %q, want %q", name, "Barry")
+ }
+
+ // Now round-trip a message-valued oneof member.
+ m.Union = &Communique_Msg{&Strings{StringField: String("deep deep string")}}
+ b, err = Marshal(m)
+ if err != nil {
+ t.Fatalf("Marshal of message with oneof set to message: %v", err)
+ }
+ if len(b) != 20 { // msg tag/wire (1) + msg len (1) + msg (1 + 1 + 16)
+ t.Errorf("Incorrect marshal of message with oneof set to message: %v", b)
+ }
+ m.Reset()
+ if err := Unmarshal(b, m); err != nil {
+ t.Fatalf("Unmarshal of message with oneof set to message: %v", err)
+ }
+ ss, ok := m.Union.(*Communique_Msg)
+ if !ok || ss.Msg.GetStringField() != "deep deep string" {
+ t.Errorf("After round trip with oneof set to message, Union = %+v", m.Union)
+ }
+}
+
// Benchmarks
func testMsg() *GoTest {
diff --git a/proto/clone.go b/proto/clone.go
index 915a68b..0155cc7 100644
--- a/proto/clone.go
+++ b/proto/clone.go
@@ -120,6 +120,13 @@
return
}
out.Set(in)
+ case reflect.Interface:
+ // Probably a oneof field. A nil src leaves dst untouched; otherwise dst is replaced by a fresh copy of src's wrapper (assignment semantics).
+ if in.IsNil() {
+ return
+ }
+ out.Set(reflect.New(in.Elem().Elem().Type())) // interface -> *T -> T -> new(T)
+ mergeAny(out.Elem(), in.Elem(), false, nil) // copy src's *T contents into the new allocation
case reflect.Map:
if in.Len() == 0 {
return
diff --git a/proto/clone_test.go b/proto/clone_test.go
index a1c697b..76720f1 100644
--- a/proto/clone_test.go
+++ b/proto/clone_test.go
@@ -232,6 +232,28 @@
Data: []byte("texas!"),
},
},
+ // Oneof fields merge by assignment: src's member replaces dst's, even when it is a different member.
+ {
+ src: &pb.Communique{
+ Union: &pb.Communique_Number{41},
+ },
+ dst: &pb.Communique{
+ Union: &pb.Communique_Name{"Bobby Tables"},
+ },
+ want: &pb.Communique{
+ Union: &pb.Communique_Number{41},
+ },
+ },
+ // A nil oneof in src is the same as not set, so dst keeps its member.
+ {
+ src: &pb.Communique{},
+ dst: &pb.Communique{
+ Union: &pb.Communique_Name{"Bobby Tables"},
+ },
+ want: &pb.Communique{
+ Union: &pb.Communique_Name{"Bobby Tables"},
+ },
+ },
}
func TestMerge(t *testing.T) {
diff --git a/proto/decode.go b/proto/decode.go
index bf71dca..8486635 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -46,6 +46,10 @@
// errOverflow is returned when an integer is too large to be represented.
var errOverflow = errors.New("proto: integer overflow")
+// ErrInternalBadWireType is returned by generated code when an incorrect
+// wire type is encountered for a oneof field. Buffer.Unmarshal replaces it with a more descriptive error, so it does not get returned to user code.
+var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for oneof")
+
// The fundamental decoders that interpret bytes on the wire.
// Those that take integer types all return uint64 and are
// therefore of type valueDecoder.
@@ -314,6 +318,24 @@
return NewBuffer(buf).Unmarshal(pb)
}
+// DecodeMessage reads a count-delimited message from the Buffer and unmarshals it into pb.
+func (p *Buffer) DecodeMessage(pb Message) error {
+ enc, err := p.DecodeRawBytes(false)
+ if err != nil {
+ return err
+ }
+ return NewBuffer(enc).Unmarshal(pb)
+}
+
+// DecodeGroup reads a tag-delimited group from the Buffer into pb; the caller has already read the group's start tag.
+func (p *Buffer) DecodeGroup(pb Message) error {
+ typ, base, err := getbase(pb)
+ if err != nil {
+ return err
+ }
+ return p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), true, base)
+}
+
// Unmarshal parses the protocol buffer representation in the
// Buffer and places the decoded result in pb. If the struct
// underlying pb does not match the data in the buffer, the results can be
@@ -377,6 +399,20 @@
continue
}
}
+ // Not a regular field; maybe it's a oneof member, which the generated unmarshaler decodes directly into m.
+ if prop.oneofUnmarshaler != nil {
+ m := structPointer_Interface(base, st).(Message)
+ // The bool result reports whether tag belongs to a oneof field; if not, fall through and treat the field as unrecognized.
+ ok, err = prop.oneofUnmarshaler(m, tag, wire, o)
+ if err == ErrInternalBadWireType {
+ // Map the error to something more descriptive.
+ // Do the formatting here to save generated code space.
+ err = fmt.Errorf("bad wiretype for oneof field in %T", m)
+ }
+ if ok {
+ continue
+ }
+ }
err = o.skipAndSave(st, tag, wire, base, prop.unrecField)
continue
}
diff --git a/proto/encode.go b/proto/encode.go
index 91f3f07..fe48cc7 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -228,6 +228,20 @@
return p.buf, err
}
+// EncodeMessage writes the protocol buffer to the Buffer,
+// prefixed by a varint-encoded length. A nil message yields ErrNil.
+func (p *Buffer) EncodeMessage(pb Message) error {
+ t, base, err := getbase(pb)
+ if structPointer_IsNil(base) {
+ return ErrNil
+ }
+ if err == nil {
+ var state errorState
+ err = p.enc_len_struct(GetProperties(t.Elem()), base, &state)
+ }
+ return err
+}
+
// Marshal takes the protocol buffer
// and encodes it into the wire format, writing the result to the
// Buffer.
@@ -1201,6 +1215,14 @@
}
}
+ // Do oneof fields: the generated marshaler writes whichever member is set (if any), before the unrecognized fields are appended.
+ if prop.oneofMarshaler != nil {
+ m := structPointer_Interface(base, prop.stype).(Message)
+ if err := prop.oneofMarshaler(m, o); err != nil {
+ return err
+ }
+ }
+
// Add unrecognized fields at the end.
if prop.unrecField.IsValid() {
v := *structPointer_Bytes(base, prop.unrecField)
@@ -1226,6 +1248,21 @@
n += len(v)
}
+ // Factor in any oneof fields.
+ // Count the bytes the generated oneof marshaler actually emits, so Size
+ // always agrees with Marshal. Sizing the wrapper's single field through
+ // Properties.Init instead selects the sizers for non-pointer (proto3-style)
+ // fields, which skip zero values such as 0, "" or empty bytes even though
+ // the oneof marshaler still writes them.
+ if prop.oneofMarshaler != nil {
+ m := structPointer_Interface(base, prop.stype).(Message)
+ var ob Buffer
+ // Size functions cannot report errors; Marshal surfaces any error from
+ // the oneof marshaler itself, so here we just count what was written.
+ _ = prop.oneofMarshaler(m, &ob)
+ n += len(ob.buf)
+ }
+
return
}
diff --git a/proto/equal.go b/proto/equal.go
index d8673a3..5475c3d 100644
--- a/proto/equal.go
+++ b/proto/equal.go
@@ -154,6 +154,17 @@
return v1.Float() == v2.Float()
case reflect.Int32, reflect.Int64:
return v1.Int() == v2.Int()
+ case reflect.Interface:
+ // Probably a oneof field: equal if both are unset, or if both hold the same wrapper type with equal contents.
+ n1, n2 := v1.IsNil(), v2.IsNil()
+ if n1 || n2 {
+ return n1 == n2
+ }
+ e1, e2 := v1.Elem(), v2.Elem()
+ if e1.Type() != e2.Type() {
+ return false
+ }
+ return equalAny(e1, e2)
case reflect.Map:
if v1.Len() != v2.Len() {
return false
diff --git a/proto/equal_test.go b/proto/equal_test.go
index b322f65..0e0db8a 100644
--- a/proto/equal_test.go
+++ b/proto/equal_test.go
@@ -180,6 +180,24 @@
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob"}},
false,
},
+ {
+ "oneof same",
+ &pb.Communique{Union: &pb.Communique_Number{41}},
+ &pb.Communique{Union: &pb.Communique_Number{41}},
+ true,
+ },
+ {
+ "oneof one nil",
+ &pb.Communique{Union: &pb.Communique_Number{41}},
+ &pb.Communique{}, // oneof unset
+ false,
+ },
+ {
+ "oneof different",
+ &pb.Communique{Union: &pb.Communique_Number{41}},
+ &pb.Communique{Union: &pb.Communique_Name{"Bobby Tables"}}, // different member, not just a different value
+ false,
+ },
}
func TestEqual(t *testing.T) {
diff --git a/proto/lib.go b/proto/lib.go
index 95f7975..187bd65 100644
--- a/proto/lib.go
+++ b/proto/lib.go
@@ -66,6 +66,8 @@
that contain it (if any) followed by the CamelCased name of the
extension field itself. HasExtension, ClearExtension, GetExtension
and SetExtension are functions for manipulating extensions.
+ - Oneof field sets are given a single interface field in their message,
+ with a distinguished wrapper type for each possible field value.
- Marshal and Unmarshal are functions to encode and decode the wire format.
The simplest way to describe this is to see an example.
diff --git a/proto/properties.go b/proto/properties.go
index d74844a..5685445 100644
--- a/proto/properties.go
+++ b/proto/properties.go
@@ -84,6 +84,12 @@
// A valueDecoder decodes a single integer in a particular encoding.
type valueDecoder func(o *Buffer) (x uint64, err error)
+// A oneofMarshaler does the marshaling for all oneof fields in a message.
+type oneofMarshaler func(Message, *Buffer) error
+
+// A oneofUnmarshaler does the unmarshaling for a oneof field in a message. Its arguments are the message, field tag, wire type and buffer; the bool result reports whether the tag belonged to a oneof field.
+type oneofUnmarshaler func(Message, int, int, *Buffer) (bool, error)
+
// tagMap is an optimization over map[int]int for typical protocol buffer
// use-cases. Encoded protocol buffers are often in tag order with small tag
// numbers.
@@ -132,6 +138,11 @@
order []int // list of struct field numbers in tag order
unrecField field // field id of the XXX_unrecognized []byte field
extendable bool // is this an extendable proto
+
+ oneofMarshaler oneofMarshaler // from XXX_OneofFuncs; nil if the message has no oneofs
+ oneofUnmarshaler oneofUnmarshaler // from XXX_OneofFuncs; nil if the message has no oneofs
+ stype reflect.Type // the message struct type; set alongside the oneof funcs
+ oneofTypes []interface{} // a typed nil *T for each oneof wrapper type T
}
// Implement the sorting interface so we can sort the fields in tag order, as recommended by the spec.
@@ -665,6 +676,7 @@
if f.Name == "XXX_unrecognized" { // special case
prop.unrecField = toField(&f)
}
+ oneof := f.Tag.Get("protobuf_oneof") != "" // special case: oneof interface fields have no encoder of their own
prop.Prop[i] = p
prop.order[i] = i
if debug {
@@ -674,7 +686,7 @@
}
print("\n")
}
- if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") {
+ if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") && !oneof { // oneof fields are encoded by the message's oneofMarshaler
fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]")
}
}
@@ -682,6 +694,14 @@
// Re-order prop.order.
sort.Sort(prop)
+ type oneofMessage interface { // implemented by generated messages that contain oneofs
+ XXX_OneofFuncs() (func(Message, *Buffer) error, func(Message, int, int, *Buffer) (bool, error), []interface{})
+ }
+ if om, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); ok { // a nil *T is enough: the generated method ignores its receiver
+ prop.oneofMarshaler, prop.oneofUnmarshaler, prop.oneofTypes = om.XXX_OneofFuncs()
+ prop.stype = t
+ }
+
// build required counts
// build tags
reqCount := 0
diff --git a/proto/size_test.go b/proto/size_test.go
index db5614f..53806a3 100644
--- a/proto/size_test.go
+++ b/proto/size_test.go
@@ -124,6 +124,10 @@
{"map field with big entry", &pb.MessageWithMap{NameMapping: map[int32]string{8: strings.Repeat("x", 125)}}},
{"map field with big key and val", &pb.MessageWithMap{StrToStr: map[string]string{strings.Repeat("x", 70): strings.Repeat("y", 70)}}},
{"map field with big numeric key", &pb.MessageWithMap{NameMapping: map[int32]string{0xf00d: "om nom nom"}}},
+
+ {"oneof not set", &pb.Communique{}}, // an unset oneof contributes no bytes
+ {"oneof int32", &pb.Communique{Union: &pb.Communique_Number{3}}},
+ {"oneof string", &pb.Communique{Union: &pb.Communique_Name{"Rhythmic Fman"}}},
}
func TestSize(t *testing.T) {
diff --git a/proto/testdata/test.pb.go b/proto/testdata/test.pb.go
index 13674a4..c22fe31 100644
--- a/proto/testdata/test.pb.go
+++ b/proto/testdata/test.pb.go
@@ -35,14 +35,17 @@
GroupNew
FloatingPoint
MessageWithMap
+ Communique
*/
package testdata
import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
+var _ = fmt.Errorf
var _ = math.Inf
type FOO int32
@@ -1986,6 +1989,205 @@
return nil
}
+type Communique struct {
+ MakeMeCry *bool `protobuf:"varint,1,opt,name=make_me_cry" json:"make_me_cry,omitempty"`
+ // This is a oneof, called "union".
+ //
+ // Types that are valid to be assigned to Union:
+ // *Communique_Number
+ // *Communique_Name
+ // *Communique_Data
+ // *Communique_TempC
+ // *Communique_Col
+ // *Communique_Msg
+ Union isCommunique_Union `protobuf_oneof:"union"`
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *Communique) Reset() { *m = Communique{} }
+func (m *Communique) String() string { return proto.CompactTextString(m) }
+func (*Communique) ProtoMessage() {}
+
+type isCommunique_Union interface {
+ isCommunique_Union()
+}
+
+type Communique_Number struct {
+ Number int32 `protobuf:"varint,5,opt,name=number"`
+}
+type Communique_Name struct {
+ Name string `protobuf:"bytes,6,opt,name=name"`
+}
+type Communique_Data struct {
+ Data []byte `protobuf:"bytes,7,opt,name=data"`
+}
+type Communique_TempC struct {
+ TempC float64 `protobuf:"fixed64,8,opt,name=temp_c"`
+}
+type Communique_Col struct {
+ Col MyMessage_Color `protobuf:"varint,9,opt,name=col,enum=testdata.MyMessage_Color"`
+}
+type Communique_Msg struct {
+ Msg *Strings `protobuf:"bytes,10,opt,name=msg"`
+}
+
+func (*Communique_Number) isCommunique_Union() {}
+func (*Communique_Name) isCommunique_Union() {}
+func (*Communique_Data) isCommunique_Union() {}
+func (*Communique_TempC) isCommunique_Union() {}
+func (*Communique_Col) isCommunique_Union() {}
+func (*Communique_Msg) isCommunique_Union() {}
+
+func (m *Communique) GetUnion() isCommunique_Union {
+ if m != nil {
+ return m.Union
+ }
+ return nil
+}
+
+func (m *Communique) GetMakeMeCry() bool {
+ if m != nil && m.MakeMeCry != nil {
+ return *m.MakeMeCry
+ }
+ return false
+}
+
+func (m *Communique) GetNumber() int32 {
+ if x, ok := m.GetUnion().(*Communique_Number); ok {
+ return x.Number
+ }
+ return 0
+}
+
+func (m *Communique) GetName() string {
+ if x, ok := m.GetUnion().(*Communique_Name); ok {
+ return x.Name
+ }
+ return ""
+}
+
+func (m *Communique) GetData() []byte {
+ if x, ok := m.GetUnion().(*Communique_Data); ok {
+ return x.Data
+ }
+ return nil
+}
+
+func (m *Communique) GetTempC() float64 {
+ if x, ok := m.GetUnion().(*Communique_TempC); ok {
+ return x.TempC
+ }
+ return 0
+}
+
+func (m *Communique) GetCol() MyMessage_Color {
+ if x, ok := m.GetUnion().(*Communique_Col); ok {
+ return x.Col
+ }
+ return MyMessage_RED
+}
+
+func (m *Communique) GetMsg() *Strings {
+ if x, ok := m.GetUnion().(*Communique_Msg); ok {
+ return x.Msg
+ }
+ return nil
+}
+
+// XXX_OneofFuncs is for the internal use of the proto package.
+func (*Communique) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), []interface{}) {
+ return _Communique_OneofMarshaler, _Communique_OneofUnmarshaler, []interface{}{
+ (*Communique_Number)(nil),
+ (*Communique_Name)(nil),
+ (*Communique_Data)(nil),
+ (*Communique_TempC)(nil),
+ (*Communique_Col)(nil),
+ (*Communique_Msg)(nil),
+ }
+}
+
+func _Communique_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
+ m := msg.(*Communique)
+ // union
+ switch x := m.Union.(type) {
+ case *Communique_Number:
+ b.EncodeVarint(5<<3 | proto.WireVarint)
+ b.EncodeVarint(uint64(x.Number))
+ case *Communique_Name:
+ b.EncodeVarint(6<<3 | proto.WireBytes)
+ b.EncodeStringBytes(x.Name)
+ case *Communique_Data:
+ b.EncodeVarint(7<<3 | proto.WireBytes)
+ b.EncodeRawBytes(x.Data)
+ case *Communique_TempC:
+ b.EncodeVarint(8<<3 | proto.WireFixed64)
+ b.EncodeFixed64(math.Float64bits(x.TempC))
+ case *Communique_Col:
+ b.EncodeVarint(9<<3 | proto.WireVarint)
+ b.EncodeVarint(uint64(x.Col))
+ case *Communique_Msg:
+ b.EncodeVarint(10<<3 | proto.WireBytes)
+ if err := b.EncodeMessage(x.Msg); err != nil {
+ return err
+ }
+ case nil:
+ default:
+ return fmt.Errorf("Communique.Union has unexpected type %T", x)
+ }
+ return nil
+}
+
+func _Communique_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
+ m := msg.(*Communique)
+ switch tag {
+ case 5: // union.number
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeVarint()
+ m.Union = &Communique_Number{int32(x)}
+ return true, err
+ case 6: // union.name
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeStringBytes()
+ m.Union = &Communique_Name{x}
+ return true, err
+ case 7: // union.data
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeRawBytes(true)
+ m.Union = &Communique_Data{x}
+ return true, err
+ case 8: // union.temp_c
+ if wire != proto.WireFixed64 {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeFixed64()
+ m.Union = &Communique_TempC{math.Float64frombits(x)}
+ return true, err
+ case 9: // union.col
+ if wire != proto.WireVarint {
+ return true, proto.ErrInternalBadWireType
+ }
+ x, err := b.DecodeVarint()
+ m.Union = &Communique_Col{MyMessage_Color(x)}
+ return true, err
+ case 10: // union.msg
+ if wire != proto.WireBytes {
+ return true, proto.ErrInternalBadWireType
+ }
+ msg := new(Strings)
+ err := b.DecodeMessage(msg)
+ m.Union = &Communique_Msg{msg}
+ return true, err
+ default:
+ return false, nil
+ }
+}
+
var E_Greeting = &proto.ExtensionDesc{
ExtendedType: (*MyMessage)(nil),
ExtensionType: ([]string)(nil),
diff --git a/proto/testdata/test.proto b/proto/testdata/test.proto
index 440dba3..3d1cbb6 100644
--- a/proto/testdata/test.proto
+++ b/proto/testdata/test.proto
@@ -478,3 +478,17 @@
map<bool, bytes> byte_mapping = 3;
map<string, string> str_to_str = 4;
}
+
+message Communique {
+ optional bool make_me_cry = 1;
+
+ // This is a oneof, called "union".
+ oneof union {
+ int32 number = 5;
+ string name = 6;
+ bytes data = 7;
+ double temp_c = 8;
+ MyMessage.Color col = 9;
+ Strings msg = 10;
+ }
+}
diff --git a/proto/text.go b/proto/text.go
index f3db2cf..51202bd 100644
--- a/proto/text.go
+++ b/proto/text.go
@@ -322,6 +322,27 @@
}
}
+ if fv.Kind() == reflect.Interface {
+ // Check if it is a oneof.
+ if st.Field(i).Tag.Get("protobuf_oneof") != "" {
+ // fv is nil, or holds a pointer to generated struct.
+ // That generated struct has exactly one field,
+ // which has a protobuf struct tag.
+ if fv.IsNil() {
+ continue
+ }
+ inner := fv.Elem().Elem() // interface -> *T -> T
+ tag := inner.Type().Field(0).Tag.Get("protobuf")
+ // props presumably points into the cached GetProperties result shared by all
+ // callers; parsing into it would corrupt that cache and race with concurrent
+ // printing. Overwrite the outer props variable, but not its pointee.
+ props = new(Properties)
+ props.Parse(tag)
+ // Write the value in the oneof, not the oneof itself.
+ fv = inner.Field(0)
+ }
+ }
+
if err := writeName(w, props); err != nil {
return err
}
diff --git a/proto/text_parser.go b/proto/text_parser.go
index 7d0c757..2e2a67f 100644
--- a/proto/text_parser.go
+++ b/proto/text_parser.go
@@ -385,8 +385,7 @@
}
// Returns the index in the struct for the named field, as well as the parsed tag properties.
-func structFieldByName(st reflect.Type, name string) (int, *Properties, bool) {
- sprops := GetProperties(st)
+func structFieldByName(sprops *StructProperties, name string) (int, *Properties, bool) { // caller supplies the already-fetched struct properties
i, ok := sprops.decoderOrigNames[name]
if ok {
return i, sprops.Prop[i], true
@@ -438,7 +437,8 @@
func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
st := sv.Type()
- reqCount := GetProperties(st).reqCount
+ sprops := GetProperties(st) // fetched once; also used for field and oneof lookups below
+ reqCount := sprops.reqCount
var reqFieldErr error
fieldSet := make(map[string]bool)
// A struct is a sequence of "name: value", terminated by one of
@@ -520,95 +520,134 @@
sl = reflect.Append(sl, ext)
SetExtension(ep, desc, sl.Interface())
}
+ if err := p.consumeOptionalSeparator(); err != nil {
+ return err
+ }
+ continue
+ }
+
+ // This is a normal, non-extension field.
+ name := tok.value
+ var dst reflect.Value
+ fi, props, ok := structFieldByName(sprops, name)
+ if ok {
+ dst = sv.Field(fi)
} else {
- // This is a normal, non-extension field.
- name := tok.value
- fi, props, ok := structFieldByName(st, name)
- if !ok {
- return p.errorf("unknown field name %q in %v", name, st)
+ // Maybe it is a oneof.
+ // TODO: If this shows in profiles, cache the mapping.
+ for _, oot := range sprops.oneofTypes {
+ sp := reflect.ValueOf(oot).Type() // *T
+ // Named oprops so it does not shadow the parser receiver p.
+ var oprops Properties
+ oprops.Parse(sp.Elem().Field(0).Tag.Get("protobuf"))
+ if oprops.OrigName != name {
+ continue
+ }
+ nv := reflect.New(sp.Elem())
+ dst = nv.Elem().Field(0)
+ props = &oprops
+ // There will be exactly one interface field that
+ // this new value is assignable to.
+ for i := 0; i < st.NumField(); i++ {
+ f := st.Field(i)
+ if f.Type.Kind() != reflect.Interface {
+ continue
+ }
+ if !nv.Type().AssignableTo(f.Type) {
+ continue
+ }
+ // Only one member of a oneof may be set; a second one
+ // would otherwise silently replace the first.
+ if !sv.Field(i).IsNil() {
+ return p.errorf("field '%s' would overwrite already parsed oneof '%s'", name, f.Name)
+ }
+ sv.Field(i).Set(nv)
+ break
+ }
+ break
}
+ }
+ if !dst.IsValid() {
+ return p.errorf("unknown field name %q in %v", name, st)
+ }
- dst := sv.Field(fi)
-
- if dst.Kind() == reflect.Map {
- // Consume any colon.
- if err := p.checkForColon(props, dst.Type()); err != nil {
- return err
- }
-
- // Construct the map if it doesn't already exist.
- if dst.IsNil() {
- dst.Set(reflect.MakeMap(dst.Type()))
- }
- key := reflect.New(dst.Type().Key()).Elem()
- val := reflect.New(dst.Type().Elem()).Elem()
-
- // The map entry should be this sequence of tokens:
- // < key : KEY value : VALUE >
- // Technically the "key" and "value" could come in any order,
- // but in practice they won't.
-
- tok := p.next()
- var terminator string
- switch tok.value {
- case "<":
- terminator = ">"
- case "{":
- terminator = "}"
- default:
- return p.errorf("expected '{' or '<', found %q", tok.value)
- }
- if err := p.consumeToken("key"); err != nil {
- return err
- }
- if err := p.consumeToken(":"); err != nil {
- return err
- }
- if err := p.readAny(key, props.mkeyprop); err != nil {
- return err
- }
- if err := p.consumeOptionalSeparator(); err != nil {
- return err
- }
- if err := p.consumeToken("value"); err != nil {
- return err
- }
- if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil {
- return err
- }
- if err := p.readAny(val, props.mvalprop); err != nil {
- return err
- }
- if err := p.consumeOptionalSeparator(); err != nil {
- return err
- }
- if err := p.consumeToken(terminator); err != nil {
- return err
- }
-
- dst.SetMapIndex(key, val)
- continue
- }
-
- // Check that it's not already set if it's not a repeated field.
- if !props.Repeated && fieldSet[name] {
- return p.errorf("non-repeated field %q was repeated", name)
- }
-
- if err := p.checkForColon(props, st.Field(fi).Type); err != nil {
+ if dst.Kind() == reflect.Map {
+ // Consume any colon.
+ if err := p.checkForColon(props, dst.Type()); err != nil {
return err
}
- // Parse into the field.
- fieldSet[name] = true
- if err := p.readAny(dst, props); err != nil {
- if _, ok := err.(*RequiredNotSetError); !ok {
- return err
- }
- reqFieldErr = err
- } else if props.Required {
- reqCount--
+ // Construct the map if it doesn't already exist.
+ if dst.IsNil() {
+ dst.Set(reflect.MakeMap(dst.Type()))
}
+ key := reflect.New(dst.Type().Key()).Elem()
+ val := reflect.New(dst.Type().Elem()).Elem()
+
+ // The map entry should be this sequence of tokens:
+ // < key : KEY value : VALUE >
+ // Technically the "key" and "value" could come in any order,
+ // but in practice they won't.
+
+ tok := p.next()
+ var terminator string
+ switch tok.value {
+ case "<":
+ terminator = ">"
+ case "{":
+ terminator = "}"
+ default:
+ return p.errorf("expected '{' or '<', found %q", tok.value)
+ }
+ if err := p.consumeToken("key"); err != nil {
+ return err
+ }
+ if err := p.consumeToken(":"); err != nil {
+ return err
+ }
+ if err := p.readAny(key, props.mkeyprop); err != nil {
+ return err
+ }
+ if err := p.consumeOptionalSeparator(); err != nil {
+ return err
+ }
+ if err := p.consumeToken("value"); err != nil {
+ return err
+ }
+ if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil {
+ return err
+ }
+ if err := p.readAny(val, props.mvalprop); err != nil {
+ return err
+ }
+ if err := p.consumeOptionalSeparator(); err != nil {
+ return err
+ }
+ if err := p.consumeToken(terminator); err != nil {
+ return err
+ }
+
+ dst.SetMapIndex(key, val)
+ continue
+ }
+
+ // Check that it's not already set if it's not a repeated field.
+ if !props.Repeated && fieldSet[name] {
+ return p.errorf("non-repeated field %q was repeated", name)
+ }
+
+ if err := p.checkForColon(props, dst.Type()); err != nil {
+ return err
+ }
+
+ // Parse into the field.
+ fieldSet[name] = true
+ if err := p.readAny(dst, props); err != nil {
+ if _, ok := err.(*RequiredNotSetError); !ok {
+ return err
+ }
+ reqFieldErr = err
+ } else if props.Required {
+ reqCount--
}
if err := p.consumeOptionalSeparator(); err != nil {
diff --git a/proto/text_parser_test.go b/proto/text_parser_test.go
index 0754b26..a2a9604 100644
--- a/proto/text_parser_test.go
+++ b/proto/text_parser_test.go
@@ -486,6 +486,18 @@
}
}
+func TestOneofParsing(t *testing.T) {
+ const in = `name:"Shrek"` // oneof members are named directly, not via the oneof's name
+ m := new(Communique)
+ want := &Communique{Union: &Communique_Name{"Shrek"}}
+ if err := UnmarshalText(in, m); err != nil {
+ t.Fatal(err)
+ }
+ if !Equal(m, want) {
+ t.Errorf("\n got %v\nwant %v", m, want)
+ }
+}
+
var benchInput string
func init() {
diff --git a/proto/text_test.go b/proto/text_test.go
index 64579e9..7ff180d 100644
--- a/proto/text_test.go
+++ b/proto/text_test.go
@@ -208,6 +208,28 @@
}
}
+func TestTextOneof(t *testing.T) {
+ tests := []struct {
+ m proto.Message
+ want string
+ }{
+ // zero message: an unset oneof prints nothing
+ {&pb.Communique{}, ``},
+ // scalar oneof member, printed under its own field name
+ {&pb.Communique{Union: &pb.Communique_Number{4}}, `number:4`},
+ // message-valued oneof member
+ {&pb.Communique{Union: &pb.Communique_Msg{
+ &pb.Strings{StringField: proto.String("why hello!")},
+ }}, `msg:<string_field:"why hello!" >`},
+ }
+ for _, test := range tests {
+ got := strings.TrimSpace(test.m.String())
+ if got != test.want {
+ t.Errorf("\n got %s\nwant %s", got, test.want)
+ }
+ }
+}
+
func BenchmarkMarshalTextBuffered(b *testing.B) {
buf := new(bytes.Buffer)
m := newTestMessage()