update proto library prior to avoid New* methods.
diffs generated automatically from Google-internal copy.
all tests pass
R=rsc, dsymonds1
CC=dsymonds
http://codereview.appspot.com/1908042
diff --git a/README b/README
index 6c12f05..9a23306 100644
--- a/README
+++ b/README
@@ -82,8 +82,6 @@
- The zero value for a struct is its correct initialization state.
All desired fields must be set before marshaling.
- A Reset() method will restore a protobuf struct to its zero state.
- - Each type T has a method NewT() to create a new instance. It
- is equivalent to new(T).
- Non-repeated fields are pointers to the values; nil means unset.
That is, optional or required field int32 f becomes F *int32.
- Repeated fields are slices.
@@ -151,7 +149,7 @@
if err != nil {
log.Exit("marshaling error:", err)
}
- newTest := example.NewTest()
+ newTest := &example.Test()
err = proto.Unmarshal(data, newTest)
if err != nil {
log.Exit("unmarshaling error:", err)
diff --git a/compiler/generator/generator.go b/compiler/generator/generator.go
index c2c6297..33bb9c0 100644
--- a/compiler/generator/generator.go
+++ b/compiler/generator/generator.go
@@ -909,17 +909,12 @@
g.Out()
g.P("}")
- // Reset and New functions
+ // Reset function
g.P("func (this *", ccTypeName, ") Reset() {")
g.In()
g.P("*this = ", ccTypeName, "{}")
g.Out()
g.P("}")
- g.P("func New", ccTypeName, "() *", ccTypeName, " {")
- g.In()
- g.P("return new(", ccTypeName, ")")
- g.Out()
- g.P("}")
// Extension support methods
if len(message.ExtensionRange) > 0 {
diff --git a/compiler/testdata/extension_test.go b/compiler/testdata/extension_test.go
index 6ff10ca..85c7d56 100644
--- a/compiler/testdata/extension_test.go
+++ b/compiler/testdata/extension_test.go
@@ -42,7 +42,7 @@
)
func TestSingleFieldExtension(t *testing.T) {
- bm := base.NewBaseMessage()
+ bm := &base.BaseMessage{}
bm.Height = proto.Int32(178)
// Use extension within scope of another type.
@@ -55,7 +55,7 @@
if err != nil {
t.Fatal("Failed encoding message with extension:", err)
}
- bm_new := base.NewBaseMessage()
+ bm_new := &base.BaseMessage{}
if err := proto.Unmarshal(buf, bm_new); err != nil {
t.Fatal("Failed decoding message with extension:", err)
}
@@ -76,7 +76,7 @@
}
func TestMessageExtension(t *testing.T) {
- bm := base.NewBaseMessage()
+ bm := &base.BaseMessage{}
bm.Height = proto.Int32(179)
// Use extension that is itself a message.
@@ -92,7 +92,7 @@
if err != nil {
t.Fatal("Failed encoding message with extension:", err)
}
- bm_new := base.NewBaseMessage()
+ bm_new := &base.BaseMessage{}
if err := proto.Unmarshal(buf, bm_new); err != nil {
t.Fatal("Failed decoding message with extension:", err)
}
@@ -116,7 +116,7 @@
}
func TestTopLevelExtension(t *testing.T) {
- bm := base.NewBaseMessage()
+ bm := &base.BaseMessage{}
bm.Height = proto.Int32(179)
width := proto.Int32(17)
@@ -128,7 +128,7 @@
if err != nil {
t.Fatal("Failed encoding message with extension:", err)
}
- bm_new := base.NewBaseMessage()
+ bm_new := &base.BaseMessage{}
if err := proto.Unmarshal(buf, bm_new); err != nil {
t.Fatal("Failed decoding message with extension:", err)
}
diff --git a/compiler/testdata/main.go b/compiler/testdata/main.go
index 29100be..ed81320 100644
--- a/compiler/testdata/main.go
+++ b/compiler/testdata/main.go
@@ -39,6 +39,6 @@
)
func main() {
- _ = my_test.NewRequest()
- _ = multitest.NewMulti1()
+ _ = &my_test.Request{}
+ _ = &multitest.Multi1{}
}
diff --git a/compiler/testdata/test.pb.go.golden b/compiler/testdata/test.pb.go.golden
index 397bea4..36284f0 100644
--- a/compiler/testdata/test.pb.go.golden
+++ b/compiler/testdata/test.pb.go.golden
@@ -98,9 +98,6 @@
func (this *Request) Reset() {
*this = Request{}
}
-func NewRequest() *Request {
- return new(Request)
-}
const Default_Request_Hat HatType = HatType_FEDORA
type Reply struct {
@@ -111,9 +108,6 @@
func (this *Reply) Reset() {
*this = Reply{}
}
-func NewReply() *Reply {
- return new(Reply)
-}
var extRange_Reply = []proto.ExtensionRange{
proto.ExtensionRange{100, 536870911},
@@ -137,9 +131,6 @@
func (this *Reply_Entry) Reset() {
*this = Reply_Entry{}
}
-func NewReply_Entry() *Reply_Entry {
- return new(Reply_Entry)
-}
const Default_Reply_Entry_Value int64 = 7
type ReplyExtensions struct {
@@ -148,9 +139,6 @@
func (this *ReplyExtensions) Reset() {
*this = ReplyExtensions{}
}
-func NewReplyExtensions() *ReplyExtensions {
- return new(ReplyExtensions)
-}
var E_ReplyExtensions_Time = &proto.ExtensionDesc{
ExtendedType: (*Reply)(nil),
diff --git a/proto/all_test.go b/proto/all_test.go
index 10e2297..39abc77 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -68,7 +68,7 @@
}
func initGoTestField() *GoTestField {
- f := NewGoTestField()
+ f := new(GoTestField)
f.Label = String("label")
f.Type = String("type")
return f
@@ -78,25 +78,25 @@
// (It's remarkable that required, optional, and repeated all have
// 8 letters.)
func initGoTest_RequiredGroup() *GoTest_RequiredGroup {
- f := NewGoTest_RequiredGroup()
- f.RequiredField = String("required")
- return f
+ return &GoTest_RequiredGroup{
+ RequiredField: String("required"),
+ }
}
func initGoTest_OptionalGroup() *GoTest_OptionalGroup {
- f := NewGoTest_OptionalGroup()
- f.RequiredField = String("optional")
- return f
+ return &GoTest_OptionalGroup{
+ RequiredField: String("optional"),
+ }
}
func initGoTest_RepeatedGroup() *GoTest_RepeatedGroup {
- f := NewGoTest_RepeatedGroup()
- f.RequiredField = String("repeated")
- return f
+ return &GoTest_RepeatedGroup{
+ RequiredField: String("repeated"),
+ }
}
func initGoTest(setdefaults bool) *GoTest {
- pb := NewGoTest()
+ pb := new(GoTest)
if setdefaults {
pb.F_BoolDefaulted = Bool(Default_GoTest_F_BoolDefaulted)
pb.F_Int32Defaulted = Int32(Default_GoTest_F_Int32Defaulted)
@@ -302,7 +302,7 @@
}
// Now test Unmarshal by recreating the original buffer.
- pbd := NewGoTest()
+ pbd := new(GoTest)
err = o.Unmarshal(pbd)
if err != nil {
t.Fatalf("overify unmarshal err = %v", err)
@@ -426,7 +426,7 @@
// Do we catch the "required bit not set" case?
func TestRequiredBit(t *testing.T) {
o := old()
- pb := NewGoTest()
+ pb := new(GoTest)
if o.Marshal(pb) != ErrRequiredNotSet {
t.Errorf("did not catch missing required fields")
}
@@ -781,24 +781,25 @@
o.Marshal(pb)
// Now new a GoSkipTest record.
- skipgroup := NewGoSkipTest_SkipGroup()
- skipgroup.GroupInt32 = Int32(75)
- skipgroup.GroupString = String("wxyz")
- skip := NewGoSkipTest()
- skip.SkipInt32 = Int32(32)
- skip.SkipFixed32 = Uint32(3232)
- skip.SkipFixed64 = Uint64(6464)
- skip.SkipString = String("skipper")
- skip.Skipgroup = skipgroup
+ skip := &GoSkipTest{
+ SkipInt32: Int32(32),
+ SkipFixed32: Uint32(3232),
+ SkipFixed64: Uint64(6464),
+ SkipString: String("skipper"),
+ Skipgroup: &GoSkipTest_SkipGroup{
+ GroupInt32: Int32(75),
+ GroupString: String("wxyz"),
+ },
+ }
// Marshal it into same buffer.
o.Marshal(skip)
- pbd := NewGoTestField()
+ pbd := new(GoTestField)
o.Unmarshal(pbd)
// The __unrecognized field should be a marshaling of GoSkipTest
- skipd := NewGoSkipTest()
+ skipd := new(GoSkipTest)
o.SetBuf(pbd.XXX_unrecognized)
o.Unmarshal(skipd)
@@ -876,7 +877,7 @@
buf, _ := Marshal(pb)
// Now test Unmarshal by recreating the original buffer.
- pbd := NewGoTest()
+ pbd := new(GoTest)
Unmarshal(buf, pbd)
// Check the checkable values
@@ -961,13 +962,16 @@
}
func TestProto1RepeatedGroup(t *testing.T) {
- pb := NewMessageList()
-
- pb.Message = make([]*MessageList_Message, 2)
- pb.Message[0] = NewMessageList_Message()
- pb.Message[0].Name = String("blah")
- pb.Message[0].Count = Int32(7)
- // NOTE: pb.Message[1] is a nil
+ pb := &MessageList{
+ Message: []*MessageList_Message{
+ &MessageList_Message{
+ Name: String("blah"),
+ Count: Int32(7),
+ },
+ // NOTE: pb.Message[1] is a nil
+ nil,
+ },
+ }
o := old()
if err := o.Marshal(pb); err != ErrRepeatedHasNil {
@@ -995,6 +999,24 @@
}
}
+// Verify that absent required fields cause Marshal/Unmarshal to return errors.
+func TestRequiredFieldEnforcement(t *testing.T) {
+ pb := new(GoTestField)
+ _, err := Marshal(pb)
+ if err == nil || err != ErrRequiredNotSet {
+ t.Errorf("marshal: expected %q, got %q", ErrRequiredNotSet, err)
+ }
+
+ // A slightly sneaky, yet valid, proto. It encodes the same required field twice,
+ // so simply counting the required fields is insufficient.
+ // field 1, encoding 2, value "hi"
+ buf := []byte("\x0A\x02hi\x0A\x02hi")
+ err = Unmarshal(buf, pb)
+ if err == nil || err != ErrRequiredNotSet {
+ t.Errorf("unmarshal: expected %q, got %q", ErrRequiredNotSet, err)
+ }
+}
+
func BenchmarkMarshal(b *testing.B) {
b.StopTimer()
@@ -1030,7 +1052,7 @@
for i := 0; i < N; i++ {
pb.F_Int32Repeated[i] = int32(i)
}
- pbd := NewGoTest()
+ pbd := new(GoTest)
p := NewBuffer(nil)
p.Marshal(pb)
p2 := NewBuffer(nil)
diff --git a/proto/decode.go b/proto/decode.go
index 0c237e1..5ad6625 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -333,6 +333,7 @@
func (o *Buffer) unmarshalType(t *reflect.PtrType, is_group bool, base uintptr) os.Error {
st := t.Elem().(*reflect.StructType)
prop := GetProperties(st)
+ required, reqFields := prop.reqCount, uint64(0)
sbase := getsbase(prop) // scratch area for data items
var err os.Error
@@ -367,19 +368,41 @@
}
p := prop.Prop[fieldnum]
- if p.dec != nil {
- if wire != WireStartGroup && wire != p.WireType {
- err = ErrWrongType
- continue
- }
- err = p.dec(o, p, base, sbase)
+ if p.dec == nil {
+ fmt.Fprintf(os.Stderr, "no protobuf decoder for %s.%s\n", t, st.Field(fieldnum).Name)
continue
}
-
- fmt.Fprintf(os.Stderr, "no protobuf decoder for %s.%s\n", t, st.Field(fieldnum).Name)
+ if wire != WireStartGroup && wire != p.WireType {
+ err = ErrWrongType
+ continue
+ }
+ err = p.dec(o, p, base, sbase)
+ if err == nil && p.Required {
+ // Successfully decoded a required field.
+ if tag <= 64 {
+ // use bitmap for fields 1-64 to catch field reuse.
+ var mask uint64 = 1 << uint64(tag-1)
+ if reqFields&mask == 0 {
+ // new required field
+ reqFields |= mask
+ required--
+ }
+ } else {
+ // This is imprecise. It can be fooled by a required field
+ // with a tag > 64 that is encoded twice; that's very rare.
+ // A fully correct implementation would require allocating
+ // a data structure, which we would like to avoid.
+ required--
+ }
+ }
}
- if err == nil && is_group {
- return io.ErrUnexpectedEOF
+ if err == nil {
+ if is_group {
+ return io.ErrUnexpectedEOF
+ }
+ if required > 0 {
+ return ErrRequiredNotSet
+ }
}
return err
}
diff --git a/proto/encode.go b/proto/encode.go
index 9c323b2..fac4b5e 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -45,7 +45,9 @@
// ErrRequiredNotSet is the error returned if Marshal is called with
// a protocol buffer struct whose required fields have not
-// all been initialized.
+// all been initialized. It is also the error returned if Unmarshal is
+// called with an encoded protocol buffer that does not include all the
+// required fields.
var ErrRequiredNotSet = os.NewError("required fields not set")
// ErrRepeatedHasNil is the error returned if Marshal is called with
diff --git a/proto/testdata/test.pb.go b/proto/testdata/test.pb.go
index 23c1f0e..a60b001 100644
--- a/proto/testdata/test.pb.go
+++ b/proto/testdata/test.pb.go
@@ -102,9 +102,6 @@
func (this *GoEnum) Reset() {
*this = GoEnum{}
}
-func NewGoEnum() *GoEnum {
- return new(GoEnum)
-}
type GoTestField struct {
Label *string "PB(bytes,1,req)"
@@ -114,9 +111,6 @@
func (this *GoTestField) Reset() {
*this = GoTestField{}
}
-func NewGoTestField() *GoTestField {
- return new(GoTestField)
-}
type GoTest struct {
Kind *int32 "PB(varint,1,req)"
@@ -185,9 +179,6 @@
func (this *GoTest) Reset() {
*this = GoTest{}
}
-func NewGoTest() *GoTest {
- return new(GoTest)
-}
const Default_GoTest_F_BoolDefaulted bool = true
const Default_GoTest_F_Int32Defaulted int32 = 32
const Default_GoTest_F_Int64Defaulted int64 = 64
@@ -209,9 +200,6 @@
func (this *GoTest_RequiredGroup) Reset() {
*this = GoTest_RequiredGroup{}
}
-func NewGoTest_RequiredGroup() *GoTest_RequiredGroup {
- return new(GoTest_RequiredGroup)
-}
type GoTest_RepeatedGroup struct {
RequiredField *string "PB(bytes,81,req)"
@@ -220,9 +208,6 @@
func (this *GoTest_RepeatedGroup) Reset() {
*this = GoTest_RepeatedGroup{}
}
-func NewGoTest_RepeatedGroup() *GoTest_RepeatedGroup {
- return new(GoTest_RepeatedGroup)
-}
type GoTest_OptionalGroup struct {
RequiredField *string "PB(bytes,91,req)"
@@ -231,9 +216,6 @@
func (this *GoTest_OptionalGroup) Reset() {
*this = GoTest_OptionalGroup{}
}
-func NewGoTest_OptionalGroup() *GoTest_OptionalGroup {
- return new(GoTest_OptionalGroup)
-}
type GoSkipTest struct {
SkipInt32 *int32 "PB(varint,11,req,name=skip_int32)"
@@ -246,9 +228,6 @@
func (this *GoSkipTest) Reset() {
*this = GoSkipTest{}
}
-func NewGoSkipTest() *GoSkipTest {
- return new(GoSkipTest)
-}
type GoSkipTest_SkipGroup struct {
GroupInt32 *int32 "PB(varint,16,req,name=group_int32)"
@@ -258,9 +237,6 @@
func (this *GoSkipTest_SkipGroup) Reset() {
*this = GoSkipTest_SkipGroup{}
}
-func NewGoSkipTest_SkipGroup() *GoSkipTest_SkipGroup {
- return new(GoSkipTest_SkipGroup)
-}
type InnerMessage struct {
Host *string "PB(bytes,1,req,name=host)"
@@ -271,9 +247,6 @@
func (this *InnerMessage) Reset() {
*this = InnerMessage{}
}
-func NewInnerMessage() *InnerMessage {
- return new(InnerMessage)
-}
const Default_InnerMessage_Port int32 = 4000
type OtherMessage struct {
@@ -286,9 +259,6 @@
func (this *OtherMessage) Reset() {
*this = OtherMessage{}
}
-func NewOtherMessage() *OtherMessage {
- return new(OtherMessage)
-}
type MyMessage struct {
Count *int32 "PB(varint,1,req,name=count)"
@@ -303,9 +273,6 @@
func (this *MyMessage) Reset() {
*this = MyMessage{}
}
-func NewMyMessage() *MyMessage {
- return new(MyMessage)
-}
type MessageList struct {
Message []*MessageList_Message "PB(group,1,rep,name=message)"
@@ -314,9 +281,6 @@
func (this *MessageList) Reset() {
*this = MessageList{}
}
-func NewMessageList() *MessageList {
- return new(MessageList)
-}
type MessageList_Message struct {
Name *string "PB(bytes,2,req,name=name)"
@@ -326,9 +290,6 @@
func (this *MessageList_Message) Reset() {
*this = MessageList_Message{}
}
-func NewMessageList_Message() *MessageList_Message {
- return new(MessageList_Message)
-}
func init() {
proto.RegisterEnum("test_proto.FOO", FOO_name, FOO_value)