goprotobuf: More efficient text marshaling, plus returning errors from underlying writer.

R=golang-dev, iant
CC=golang-dev
https://codereview.appspot.com/6894047
diff --git a/proto/text.go b/proto/text.go
index 507c214..fc3763e 100644
--- a/proto/text.go
+++ b/proto/text.go
@@ -34,6 +34,7 @@
 // Functions for writing the text protocol buffer format.
 
 import (
+	"bufio"
 	"bytes"
 	"fmt"
 	"io"
@@ -56,13 +57,17 @@
 	backslashBS     = []byte{'\\', '\\'}
 )
 
+type writer interface {
+	io.Writer
+	WriteByte(byte) error
+}
+
 // textWriter is an io.Writer that tracks its indentation level.
 type textWriter struct {
-	ind       int
-	complete  bool // if the current position is a complete line
-	compact   bool // whether to write out as a one-liner
-	writer    io.Writer
-	writeByte func(byte) error // write a single byte to writer
+	ind      int
+	complete bool // if the current position is a complete line
+	compact  bool // whether to write out as a one-liner
+	w        writer
 }
 
 func (w *textWriter) WriteString(s string) (n int, err error) {
@@ -71,7 +76,7 @@
 			w.writeIndent()
 		}
 		w.complete = false
-		return io.WriteString(w.writer, s)
+		return io.WriteString(w.w, s)
 	}
 	// WriteString is typically called without newlines, so this
 	// codepath and its copy are rare.  We copy to avoid
@@ -80,40 +85,52 @@
 }
 
 func (w *textWriter) Write(p []byte) (n int, err error) {
-	n, err = len(p), nil
-
 	newlines := bytes.Count(p, newline)
 	if newlines == 0 {
 		if !w.compact && w.complete {
 			w.writeIndent()
 		}
-		w.writer.Write(p)
+		n, err = w.w.Write(p)
 		w.complete = false
-		return
+		return n, err
 	}
 
 	frags := bytes.SplitN(p, newline, newlines+1)
 	if w.compact {
 		for i, frag := range frags {
 			if i > 0 {
-				w.writeByte(' ')
+				if err := w.w.WriteByte(' '); err != nil {
+					return n, err
+				}
+				n++
 			}
-			w.writer.Write(frag)
+			nn, err := w.w.Write(frag)
+			n += nn
+			if err != nil {
+				return n, err
+			}
 		}
-		return
+		return n, nil
 	}
 
 	for i, frag := range frags {
 		if w.complete {
 			w.writeIndent()
 		}
-		w.writer.Write(frag)
+		nn, err := w.w.Write(frag)
+		n += nn
+		if err != nil {
+			return n, err
+		}
 		if i+1 < len(frags) {
-			w.writeByte('\n')
+			if err := w.w.WriteByte('\n'); err != nil {
+				return n, err
+			}
+			n++
 		}
 	}
 	w.complete = len(frags[len(frags)-1]) == 0
-	return
+	return n, nil
 }
 
 func (w *textWriter) WriteByte(c byte) error {
@@ -123,7 +140,7 @@
 	if !w.compact && w.complete {
 		w.writeIndent()
 	}
-	err := w.writeByte(c)
+	err := w.w.WriteByte(c)
 	w.complete = c == '\n'
 	return err
 }
@@ -138,11 +155,14 @@
 	w.ind--
 }
 
-func writeName(w *textWriter, props *Properties) {
-	io.WriteString(w, props.OrigName)
-	if props.Wire != "group" {
-		w.WriteByte(':')
+func writeName(w *textWriter, props *Properties) error {
+	if _, err := w.WriteString(props.OrigName); err != nil {
+		return err
 	}
+	if props.Wire != "group" {
+		return w.WriteByte(':')
+	}
+	return nil
 }
 
 var (
@@ -154,10 +174,9 @@
 	Bytes() []byte
 }
 
-func writeStruct(w *textWriter, sv reflect.Value) {
+func writeStruct(w *textWriter, sv reflect.Value) error {
 	if sv.Type() == messageSetType {
-		writeMessageSet(w, sv.Addr().Interface().(*MessageSet))
-		return
+		return writeMessageSet(w, sv.Addr().Interface().(*MessageSet))
 	}
 
 	st := sv.Type()
@@ -171,7 +190,9 @@
 			// The first is handled here;
 			// the second is handled at the bottom of this function.
 			if name == "XXX_unrecognized" && !fv.IsNil() {
-				writeUnknownStruct(w, fv.Interface().([]byte))
+				if err := writeUnknownStruct(w, fv.Interface().([]byte)); err != nil {
+					return err
+				}
 			}
 			continue
 		}
@@ -190,72 +211,111 @@
 		if props.Repeated && fv.Kind() == reflect.Slice {
 			// Repeated field.
 			for j := 0; j < fv.Len(); j++ {
-				writeName(w, props)
-				if !w.compact {
-					w.WriteByte(' ')
+				if err := writeName(w, props); err != nil {
+					return err
 				}
-				writeAny(w, fv.Index(j), props)
-				w.WriteByte('\n')
+				if !w.compact {
+					if err := w.WriteByte(' '); err != nil {
+						return err
+					}
+				}
+				if err := writeAny(w, fv.Index(j), props); err != nil {
+					return err
+				}
+				if err := w.WriteByte('\n'); err != nil {
+					return err
+				}
 			}
 			continue
 		}
 
-		writeName(w, props)
+		if err := writeName(w, props); err != nil {
+			return err
+		}
 		if !w.compact {
-			w.WriteByte(' ')
+			if err := w.WriteByte(' '); err != nil {
+				return err
+			}
 		}
 		if b, ok := fv.Interface().(raw); ok {
-			writeRaw(w, b.Bytes())
+			if err := writeRaw(w, b.Bytes()); err != nil {
+				return err
+			}
 			continue
 		}
-		if props.Enum != "" && tryWriteEnum(w, props.Enum, fv) {
-			// Enum written.
-		} else {
-			writeAny(w, fv, props)
+
+		var written bool
+		var err error
+		if props.Enum != "" {
+			written, err = tryWriteEnum(w, props.Enum, fv)
+			if err != nil {
+				return err
+			}
 		}
-		w.WriteByte('\n')
+		if !written {
+			if err := writeAny(w, fv, props); err != nil {
+				return err
+			}
+		}
+
+		if err := w.WriteByte('\n'); err != nil {
+			return err
+		}
 	}
 
 	// Extensions (the XXX_extensions field).
 	pv := sv.Addr()
 	if pv.Type().Implements(extendableProtoType) {
-		writeExtensions(w, pv)
+		if err := writeExtensions(w, pv); err != nil {
+			return err
+		}
 	}
+
+	return nil
 }
 
 // writeRaw writes an uninterpreted raw message.
-func writeRaw(w *textWriter, b []byte) {
-	w.WriteByte('<')
+func writeRaw(w *textWriter, b []byte) error {
+	if err := w.WriteByte('<'); err != nil {
+		return err
+	}
 	if !w.compact {
-		w.WriteByte('\n')
+		if err := w.WriteByte('\n'); err != nil {
+			return err
+		}
 	}
 	w.indent()
-	writeUnknownStruct(w, b)
+	if err := writeUnknownStruct(w, b); err != nil {
+		return err
+	}
 	w.unindent()
-	w.WriteByte('>')
+	if err := w.WriteByte('>'); err != nil {
+		return err
+	}
+	return nil
 }
 
 // tryWriteEnum attempts to write an enum value as a symbolic constant.
 // If the enum is unregistered, nothing is written and false is returned.
-func tryWriteEnum(w *textWriter, enum string, v reflect.Value) bool {
+func tryWriteEnum(w *textWriter, enum string, v reflect.Value) (bool, error) {
 	v = reflect.Indirect(v)
 	if v.Type().Kind() != reflect.Int32 {
-		return false
+		return false, nil
 	}
 	m, ok := enumNameMaps[enum]
 	if !ok {
-		return false
+		return false, nil
 	}
 	str, ok := m[int32(v.Int())]
 	if !ok {
-		return false
+		return false, nil
 	}
-	fmt.Fprintf(w, str)
-	return true
+	_, err := fmt.Fprint(w, str) // str is data, not a format string; avoid interpreting '%'
+	return true, err
 }
 
 // writeAny writes an arbitrary field.
-func writeAny(w *textWriter, v reflect.Value, props *Properties) {
+func writeAny(w *textWriter, v reflect.Value, props *Properties) error {
 	v = reflect.Indirect(v)
 
 	// We don't attempt to serialise every possible value type; only those
@@ -263,26 +323,40 @@
 	switch v.Kind() {
 	case reflect.Slice:
 		// Should only be a []byte; repeated fields are handled in writeStruct.
-		writeString(w, string(v.Interface().([]byte)))
+		if err := writeString(w, string(v.Interface().([]byte))); err != nil {
+			return err
+		}
 	case reflect.String:
-		writeString(w, v.String())
+		if err := writeString(w, v.String()); err != nil {
+			return err
+		}
 	case reflect.Struct:
 		// Required/optional group/message.
 		var bra, ket byte = '<', '>'
 		if props != nil && props.Wire == "group" {
 			bra, ket = '{', '}'
 		}
-		w.WriteByte(bra)
+		if err := w.WriteByte(bra); err != nil {
+			return err
+		}
 		if !w.compact {
-			w.WriteByte('\n')
+			if err := w.WriteByte('\n'); err != nil {
+				return err
+			}
 		}
 		w.indent()
-		writeStruct(w, v)
+		if err := writeStruct(w, v); err != nil {
+			return err
+		}
 		w.unindent()
-		w.WriteByte(ket)
+		if err := w.WriteByte(ket); err != nil {
+			return err
+		}
 	default:
-		fmt.Fprint(w, v.Interface())
+		_, err := fmt.Fprint(w, v.Interface())
+		return err
 	}
+	return nil
 }
 
 // equivalent to C's isprint.
@@ -295,117 +369,154 @@
 // we treat the string as a byte sequence, and we use octal escapes.
 // These differences are to maintain interoperability with the other
 // languages' implementations of the text format.
-func writeString(w *textWriter, s string) {
-	w.WriteByte('"') // use WriteByte here to get any needed indent
+func writeString(w *textWriter, s string) error {
+	// use WriteByte here to get any needed indent
+	if err := w.WriteByte('"'); err != nil {
+		return err
+	}
 	// Loop over the bytes, not the runes.
 	for i := 0; i < len(s); i++ {
+		var err error
 		// Divergence from C++: we don't escape apostrophes.
 		// There's no need to escape them, and the C++ parser
 		// copes with a naked apostrophe.
 		switch c := s[i]; c {
 		case '\n':
-			w.writer.Write(backslashN)
+			_, err = w.w.Write(backslashN)
 		case '\r':
-			w.writer.Write(backslashR)
+			_, err = w.w.Write(backslashR)
 		case '\t':
-			w.writer.Write(backslashT)
+			_, err = w.w.Write(backslashT)
 		case '"':
-			w.writer.Write(backslashDQ)
+			_, err = w.w.Write(backslashDQ)
 		case '\\':
-			w.writer.Write(backslashBS)
+			_, err = w.w.Write(backslashBS)
 		default:
 			if isprint(c) {
-				w.writeByte(c)
+				err = w.w.WriteByte(c)
 			} else {
-				fmt.Fprintf(w.writer, "\\%03o", c)
+				_, err = fmt.Fprintf(w.w, "\\%03o", c)
 			}
 		}
+		if err != nil {
+			return err
+		}
 	}
-	w.WriteByte('"')
+	return w.WriteByte('"')
 }
 
-func writeMessageSet(w *textWriter, ms *MessageSet) {
+func writeMessageSet(w *textWriter, ms *MessageSet) error {
 	for _, item := range ms.Item {
 		id := *item.TypeId
 		if msd, ok := messageSetMap[id]; ok {
 			// Known message set type.
-			fmt.Fprintf(w, "[%s]: <\n", msd.name)
+			if _, err := fmt.Fprintf(w, "[%s]: <\n", msd.name); err != nil {
+				return err
+			}
 			w.indent()
 
 			pb := reflect.New(msd.t.Elem())
 			if err := Unmarshal(item.Message, pb.Interface().(Message)); err != nil {
-				fmt.Fprintf(w, "/* bad message: %v */\n", err)
+				if _, err := fmt.Fprintf(w, "/* bad message: %v */\n", err); err != nil {
+					return err
+				}
 			} else {
-				writeStruct(w, pb.Elem())
+				if err := writeStruct(w, pb.Elem()); err != nil {
+					return err
+				}
 			}
 		} else {
 			// Unknown type.
-			fmt.Fprintf(w, "[%d]: <\n", id)
+			if _, err := fmt.Fprintf(w, "[%d]: <\n", id); err != nil {
+				return err
+			}
 			w.indent()
-			writeUnknownStruct(w, item.Message)
+			if err := writeUnknownStruct(w, item.Message); err != nil {
+				return err
+			}
 		}
 		w.unindent()
-		w.Write(gtNewline)
+		if _, err := w.Write(gtNewline); err != nil {
+			return err
+		}
 	}
+	return nil
 }
 
-func writeUnknownStruct(w *textWriter, data []byte) {
+func writeUnknownStruct(w *textWriter, data []byte) (err error) {
 	if !w.compact {
-		fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data))
+		if _, err := fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data)); err != nil {
+			return err
+		}
 	}
 	b := NewBuffer(data)
 	for b.index < len(b.buf) {
 		x, err := b.DecodeVarint()
 		if err != nil {
-			fmt.Fprintf(w, "/* %v */\n", err)
-			return
+			_, err := fmt.Fprintf(w, "/* %v */\n", err)
+			return err
 		}
 		wire, tag := x&7, x>>3
 		if wire == WireEndGroup {
 			w.unindent()
-			w.Write(endBraceNewline)
+			if _, err := w.Write(endBraceNewline); err != nil {
+				return err
+			}
 			continue
 		}
-		fmt.Fprintf(w, "tag%d", tag)
+		if _, err := fmt.Fprintf(w, "tag%d", tag); err != nil {
+			return err
+		}
 		if wire != WireStartGroup {
-			w.WriteByte(':')
+			if err := w.WriteByte(':'); err != nil {
+				return err
+			}
 		}
 		if !w.compact || wire == WireStartGroup {
-			w.WriteByte(' ')
+			if err := w.WriteByte(' '); err != nil {
+				return err
+			}
 		}
 		switch wire {
 		case WireBytes:
-			buf, err := b.DecodeRawBytes(false)
-			if err == nil {
+			buf, e := b.DecodeRawBytes(false)
+			if e == nil { // test the decode error e; the loop's err is always nil here
-				fmt.Fprintf(w, "%q", buf)
+				_, err = fmt.Fprintf(w, "%q", buf)
 			} else {
-				fmt.Fprintf(w, "/* %v */", err)
+				_, err = fmt.Fprintf(w, "/* %v */", e)
 			}
 		case WireFixed32:
-			x, err := b.DecodeFixed32()
-			writeUnknownInt(w, x, err)
+			x, err = b.DecodeFixed32()
+			err = writeUnknownInt(w, x, err)
 		case WireFixed64:
-			x, err := b.DecodeFixed64()
-			writeUnknownInt(w, x, err)
+			x, err = b.DecodeFixed64()
+			err = writeUnknownInt(w, x, err)
 		case WireStartGroup:
-			w.WriteByte('{')
+			err = w.WriteByte('{')
 			w.indent()
 		case WireVarint:
-			x, err := b.DecodeVarint()
-			writeUnknownInt(w, x, err)
+			x, err = b.DecodeVarint()
+			err = writeUnknownInt(w, x, err)
 		default:
-			fmt.Fprintf(w, "/* unknown wire type %d */", wire)
+			_, err = fmt.Fprintf(w, "/* unknown wire type %d */", wire)
 		}
-		w.WriteByte('\n')
+		if err != nil {
+			return err
+		}
+		if err = w.WriteByte('\n'); err != nil {
+			return err
+		}
 	}
+	return nil
 }
 
-func writeUnknownInt(w *textWriter, x uint64, err error) {
+func writeUnknownInt(w *textWriter, x uint64, err error) error {
 	if err == nil {
-		fmt.Fprint(w, x)
+		_, err = fmt.Fprint(w, x)
 	} else {
-		fmt.Fprintf(w, "/* %v */", err)
+		_, err = fmt.Fprintf(w, "/* %v */", err)
 	}
+	return err
 }
 
 type int32Slice []int32
@@ -416,7 +527,7 @@
 
 // writeExtensions writes all the extensions in pv.
 // pv is assumed to be a pointer to a protocol message struct that is extendable.
-func writeExtensions(w *textWriter, pv reflect.Value) {
+func writeExtensions(w *textWriter, pv reflect.Value) error {
 	emap := extensionMaps[pv.Type().Elem()]
 	ep := pv.Interface().(extendableProto)
 
@@ -438,35 +549,53 @@
 		}
 		if desc == nil {
 			// Unknown extension.
-			writeUnknownStruct(w, ext.enc)
+			if err := writeUnknownStruct(w, ext.enc); err != nil {
+				return err
+			}
 			continue
 		}
 
 		pb, err := GetExtension(ep, desc)
 		if err != nil {
-			fmt.Fprintln(os.Stderr, "proto: failed getting extension: ", err)
+			if _, err := fmt.Fprintln(os.Stderr, "proto: failed getting extension: ", err); err != nil {
+				return err
+			}
 			continue
 		}
 
 		// Repeated extensions will appear as a slice.
 		if !desc.repeated() {
-			writeExtension(w, desc.Name, pb)
+			if err := writeExtension(w, desc.Name, pb); err != nil {
+				return err
+			}
 		} else {
 			v := reflect.ValueOf(pb)
 			for i := 0; i < v.Len(); i++ {
-				writeExtension(w, desc.Name, v.Index(i).Interface())
+				if err := writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil {
+					return err
+				}
 			}
 		}
 	}
+	return nil
 }
 
-func writeExtension(w *textWriter, name string, pb interface{}) {
-	fmt.Fprintf(w, "[%s]:", name)
-	if !w.compact {
-		w.WriteByte(' ')
+func writeExtension(w *textWriter, name string, pb interface{}) error {
+	if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil {
+		return err
 	}
-	writeAny(w, reflect.ValueOf(pb), nil)
-	w.WriteByte('\n')
+	if !w.compact {
+		if err := w.WriteByte(' '); err != nil {
+			return err
+		}
+	}
+	if err := writeAny(w, reflect.ValueOf(pb), nil); err != nil {
+		return err
+	}
+	if err := w.WriteByte('\n'); err != nil {
+		return err
+	}
+	return nil
 }
 
 func (w *textWriter) writeIndent() {
@@ -479,46 +608,43 @@
 		if n > len(spaces) {
 			n = len(spaces)
 		}
-		w.writer.Write(spaces[:n])
+		w.w.Write(spaces[:n])
 		remain -= n
 	}
 	w.complete = false
 }
 
-type byteWriter interface {
-	WriteByte(byte) error
-}
-
-func marshalText(w io.Writer, pb Message, compact bool) {
+func marshalText(w io.Writer, pb Message, compact bool) error {
 	if pb == nil {
 		w.Write([]byte("<nil>"))
-		return
+		return nil
 	}
-	aw := new(textWriter)
-	aw.writer = w
-	aw.complete = true
-	aw.compact = compact
-
-	if bw, ok := w.(byteWriter); ok {
-		aw.writeByte = func(c byte) error {
-			return bw.WriteByte(c)
-		}
-	} else {
-		var scratch [1]byte
-		aw.writeByte = func(c byte) error {
-			scratch[0] = c
-			_, err := w.Write(scratch[:])
-			return err
-		}
+	var bw *bufio.Writer
+	ww, ok := w.(writer)
+	if !ok {
+		bw = bufio.NewWriter(w)
+		ww = bw
+	}
+	aw := &textWriter{
+		w:        ww,
+		complete: true,
+		compact:  compact,
 	}
 
 	// Dereference the received pointer so we don't have outer < and >.
 	v := reflect.Indirect(reflect.ValueOf(pb))
-	writeStruct(aw, v)
+	if err := writeStruct(aw, v); err != nil {
+		return err
+	}
+	if bw != nil {
+		return bw.Flush()
+	}
+	return nil
 }
 
 // MarshalText writes a given protocol buffer in text format.
-func MarshalText(w io.Writer, pb Message) { marshalText(w, pb, false) }
+// The only errors returned are from w.
+func MarshalText(w io.Writer, pb Message) error { return marshalText(w, pb, false) }
 
 // MarshalTextString is the same as MarshalText, but returns the string directly.
 func MarshalTextString(pb Message) string {
@@ -528,7 +654,7 @@
 }
 
 // CompactText writes a given protocol buffer in compact text format (one line).
-func CompactText(w io.Writer, pb Message) { marshalText(w, pb, true) }
+func CompactText(w io.Writer, pb Message) error { return marshalText(w, pb, true) }
 
 // CompactTextString is the same as CompactText, but returns the string directly.
 func CompactTextString(pb Message) string {