goprotobuf: fix integer overflows.

1) It's possible to panic the decoder by overflowing a length check.

2) (minor) the decoder was silently truncating varints that were larger than 64 bits. This isn't strictly a problem, but it could lead to a situation where a different decoder could decode a given message differently. Thus, if the message was vetted by one decoder and processed by another, an attacker could exploit this difference.

R=dsymonds
CC=golang-dev
https://codereview.appspot.com/11094044
diff --git a/proto/all_test.go b/proto/all_test.go
index 07163f6..6478c57 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -1413,6 +1413,20 @@
 	Unmarshal(b, new(MyMessage))
 }
 
+func TestLengthOverflow(t *testing.T) {
+	// Overflowing a length should not panic.
+	b := []byte{2<<3 | WireBytes, 1, 1, 3<<3 | WireBytes, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0x01}
+	Unmarshal(b, new(MyMessage))
+}
+
+func TestVarintOverflow(t *testing.T) {
+	// Overflowing a 64-bit length should not be allowed.
+	b := []byte{1<<3 | WireVarint, 0x01, 3<<3 | WireBytes, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01}
+	if err := Unmarshal(b, new(MyMessage)); err == nil {
+		t.Fatalf("Overflowed uint64 length without error")
+	}
+}
+
 func TestUnmarshalFuzz(t *testing.T) {
 	const N = 1000
 	seed := time.Now().UnixNano()
diff --git a/proto/decode.go b/proto/decode.go
index a38d391..f951c01 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -48,6 +48,9 @@
 // to convert an encoded protocol buffer into a struct of the wrong type.
 var ErrWrongType = errors.New("field/encoding mismatch: wrong type for field")
 
+// errOverflow is returned when an integer is too large to be represented.
+var errOverflow = errors.New("proto: integer overflow")
+
 // The fundamental decoders that interpret bytes on the wire.
 // Those that take integer types all return uint64 and are
 // therefore of type valueDecoder.
@@ -60,7 +63,7 @@
 // protocol buffer types.
 func DecodeVarint(buf []byte) (x uint64, n int) {
 	// x, n already 0
-	for shift := uint(0); ; shift += 7 {
+	for shift := uint(0); shift < 64; shift += 7 {
 		if n >= len(buf) {
 			return 0, 0
 		}
@@ -68,10 +71,12 @@
 		n++
 		x |= (b & 0x7F) << shift
 		if (b & 0x80) == 0 {
-			break
+			return x, n
 		}
 	}
-	return x, n
+
+	// The number is too large to represent in a 64-bit value.
+	return 0, 0
 }
 
 // DecodeVarint reads a varint-encoded integer from the Buffer.
@@ -84,7 +89,7 @@
 	i := p.index
 	l := len(p.buf)
 
-	for shift := uint(0); ; shift += 7 {
+	for shift := uint(0); shift < 64; shift += 7 {
 		if i >= l {
 			err = io.ErrUnexpectedEOF
 			return
@@ -93,10 +98,13 @@
 		i++
 		x |= (uint64(b) & 0x7F) << shift
 		if b < 0x80 {
-			break
+			p.index = i
+			return
 		}
 	}
-	p.index = i
+
+	// The number is too large to represent in a 64-bit value.
+	err = errOverflow
 	return
 }
 
@@ -106,7 +114,7 @@
 func (p *Buffer) DecodeFixed64() (x uint64, err error) {
 	// x, err already 0
 	i := p.index + 8
-	if i > len(p.buf) {
+	if i < 0 || i > len(p.buf) {
 		err = io.ErrUnexpectedEOF
 		return
 	}
@@ -129,7 +137,7 @@
 func (p *Buffer) DecodeFixed32() (x uint64, err error) {
 	// x, err already 0
 	i := p.index + 4
-	if i > len(p.buf) {
+	if i < 0 || i > len(p.buf) {
 		err = io.ErrUnexpectedEOF
 		return
 	}
@@ -182,13 +190,14 @@
 	if nb < 0 {
 		return nil, fmt.Errorf("proto: bad byte length %d", nb)
 	}
-	if p.index+nb > len(p.buf) {
+	end := p.index + nb
+	if end < p.index || end > len(p.buf) {
 		return nil, io.ErrUnexpectedEOF
 	}
 
 	if !alloc {
 		// todo: check if can get more uses of alloc=false
-		buf = p.buf[p.index : p.index+nb]
+		buf = p.buf[p.index:end]
 		p.index += nb
 		return
 	}
@@ -213,7 +222,6 @@
 // If the protocol buffer has extensions, and the field matches, add it as an extension.
 // Otherwise, if the XXX_unrecognized field exists, append the skipped data there.
 func (o *Buffer) skipAndSave(t reflect.Type, tag, wire int, base structPointer, unrecField field) error {
-
 	oi := o.index
 
 	err := o.skip(t, tag, wire)
@@ -555,6 +563,9 @@
 	nb := int(nn) // number of bytes of encoded int32s
 
 	fin := o.index + nb
+	if fin < o.index {
+		return errOverflow
+	}
 	for o.index < fin {
 		u, err := p.valDec(o)
 		if err != nil {
@@ -587,6 +598,9 @@
 	nb := int(nn) // number of bytes of encoded int64s
 
 	fin := o.index + nb
+	if fin < o.index {
+		return errOverflow
+	}
 	for o.index < fin {
 		u, err := p.valDec(o)
 		if err != nil {