goprotobuf: More efficient text marshaling; return errors from the underlying writer.
R=golang-dev, iant
CC=golang-dev
https://codereview.appspot.com/6894047
diff --git a/proto/text.go b/proto/text.go
index 507c214..fc3763e 100644
--- a/proto/text.go
+++ b/proto/text.go
@@ -34,6 +34,7 @@
// Functions for writing the text protocol buffer format.
import (
+ "bufio"
"bytes"
"fmt"
"io"
@@ -56,13 +57,17 @@
backslashBS = []byte{'\\', '\\'}
)
+type writer interface {
+ io.Writer
+ WriteByte(byte) error
+}
+
// textWriter is an io.Writer that tracks its indentation level.
type textWriter struct {
- ind int
- complete bool // if the current position is a complete line
- compact bool // whether to write out as a one-liner
- writer io.Writer
- writeByte func(byte) error // write a single byte to writer
+ ind int
+ complete bool // if the current position is a complete line
+ compact bool // whether to write out as a one-liner
+ w writer
}
func (w *textWriter) WriteString(s string) (n int, err error) {
@@ -71,7 +76,7 @@
w.writeIndent()
}
w.complete = false
- return io.WriteString(w.writer, s)
+ return io.WriteString(w.w, s)
}
// WriteString is typically called without newlines, so this
// codepath and its copy are rare. We copy to avoid
@@ -80,40 +85,52 @@
}
func (w *textWriter) Write(p []byte) (n int, err error) {
- n, err = len(p), nil
-
newlines := bytes.Count(p, newline)
if newlines == 0 {
if !w.compact && w.complete {
w.writeIndent()
}
- w.writer.Write(p)
+ n, err = w.w.Write(p)
w.complete = false
- return
+ return n, err
}
frags := bytes.SplitN(p, newline, newlines+1)
if w.compact {
for i, frag := range frags {
if i > 0 {
- w.writeByte(' ')
+ if err := w.w.WriteByte(' '); err != nil {
+ return n, err
+ }
+ n++
}
- w.writer.Write(frag)
+ nn, err := w.w.Write(frag)
+ n += nn
+ if err != nil {
+ return n, err
+ }
}
- return
+ return n, nil
}
for i, frag := range frags {
if w.complete {
w.writeIndent()
}
- w.writer.Write(frag)
+ nn, err := w.w.Write(frag)
+ n += nn
+ if err != nil {
+ return n, err
+ }
if i+1 < len(frags) {
- w.writeByte('\n')
+ if err := w.w.WriteByte('\n'); err != nil {
+ return n, err
+ }
+ n++
}
}
w.complete = len(frags[len(frags)-1]) == 0
- return
+ return n, nil
}
func (w *textWriter) WriteByte(c byte) error {
@@ -123,7 +140,7 @@
if !w.compact && w.complete {
w.writeIndent()
}
- err := w.writeByte(c)
+ err := w.w.WriteByte(c)
w.complete = c == '\n'
return err
}
@@ -138,11 +155,14 @@
w.ind--
}
-func writeName(w *textWriter, props *Properties) {
- io.WriteString(w, props.OrigName)
- if props.Wire != "group" {
- w.WriteByte(':')
+func writeName(w *textWriter, props *Properties) error {
+ if _, err := w.WriteString(props.OrigName); err != nil {
+ return err
}
+ if props.Wire != "group" {
+ return w.WriteByte(':')
+ }
+ return nil
}
var (
@@ -154,10 +174,9 @@
Bytes() []byte
}
-func writeStruct(w *textWriter, sv reflect.Value) {
+func writeStruct(w *textWriter, sv reflect.Value) error {
if sv.Type() == messageSetType {
- writeMessageSet(w, sv.Addr().Interface().(*MessageSet))
- return
+ return writeMessageSet(w, sv.Addr().Interface().(*MessageSet))
}
st := sv.Type()
@@ -171,7 +190,9 @@
// The first is handled here;
// the second is handled at the bottom of this function.
if name == "XXX_unrecognized" && !fv.IsNil() {
- writeUnknownStruct(w, fv.Interface().([]byte))
+ if err := writeUnknownStruct(w, fv.Interface().([]byte)); err != nil {
+ return err
+ }
}
continue
}
@@ -190,72 +211,111 @@
if props.Repeated && fv.Kind() == reflect.Slice {
// Repeated field.
for j := 0; j < fv.Len(); j++ {
- writeName(w, props)
- if !w.compact {
- w.WriteByte(' ')
+ if err := writeName(w, props); err != nil {
+ return err
}
- writeAny(w, fv.Index(j), props)
- w.WriteByte('\n')
+ if !w.compact {
+ if err := w.WriteByte(' '); err != nil {
+ return err
+ }
+ }
+ if err := writeAny(w, fv.Index(j), props); err != nil {
+ return err
+ }
+ if err := w.WriteByte('\n'); err != nil {
+ return err
+ }
}
continue
}
- writeName(w, props)
+ if err := writeName(w, props); err != nil {
+ return err
+ }
if !w.compact {
- w.WriteByte(' ')
+ if err := w.WriteByte(' '); err != nil {
+ return err
+ }
}
if b, ok := fv.Interface().(raw); ok {
- writeRaw(w, b.Bytes())
+ if err := writeRaw(w, b.Bytes()); err != nil {
+ return err
+ }
continue
}
- if props.Enum != "" && tryWriteEnum(w, props.Enum, fv) {
- // Enum written.
- } else {
- writeAny(w, fv, props)
+
+ var written bool
+ var err error
+ if props.Enum != "" {
+ written, err = tryWriteEnum(w, props.Enum, fv)
+ if err != nil {
+ return err
+ }
}
- w.WriteByte('\n')
+ if !written {
+ if err := writeAny(w, fv, props); err != nil {
+ return err
+ }
+ }
+
+ if err := w.WriteByte('\n'); err != nil {
+ return err
+ }
}
// Extensions (the XXX_extensions field).
pv := sv.Addr()
if pv.Type().Implements(extendableProtoType) {
- writeExtensions(w, pv)
+ if err := writeExtensions(w, pv); err != nil {
+ return err
+ }
}
+
+ return nil
}
// writeRaw writes an uninterpreted raw message.
-func writeRaw(w *textWriter, b []byte) {
- w.WriteByte('<')
+func writeRaw(w *textWriter, b []byte) error {
+ if err := w.WriteByte('<'); err != nil {
+ return err
+ }
if !w.compact {
- w.WriteByte('\n')
+ if err := w.WriteByte('\n'); err != nil {
+ return err
+ }
}
w.indent()
- writeUnknownStruct(w, b)
+ if err := writeUnknownStruct(w, b); err != nil {
+ return err
+ }
w.unindent()
- w.WriteByte('>')
+ if err := w.WriteByte('>'); err != nil {
+ return err
+ }
+ return nil
}
// tryWriteEnum attempts to write an enum value as a symbolic constant.
// If the enum is unregistered, nothing is written and false is returned.
-func tryWriteEnum(w *textWriter, enum string, v reflect.Value) bool {
+func tryWriteEnum(w *textWriter, enum string, v reflect.Value) (bool, error) {
v = reflect.Indirect(v)
if v.Type().Kind() != reflect.Int32 {
- return false
+ return false, nil
}
m, ok := enumNameMaps[enum]
if !ok {
- return false
+ return false, nil
}
str, ok := m[int32(v.Int())]
if !ok {
- return false
+ return false, nil
}
- fmt.Fprintf(w, str)
- return true
+ _, err := w.WriteString(str) // str is data, never a format string
+ return true, err
}
// writeAny writes an arbitrary field.
-func writeAny(w *textWriter, v reflect.Value, props *Properties) {
+func writeAny(w *textWriter, v reflect.Value, props *Properties) error {
v = reflect.Indirect(v)
// We don't attempt to serialise every possible value type; only those
@@ -263,26 +323,40 @@
switch v.Kind() {
case reflect.Slice:
// Should only be a []byte; repeated fields are handled in writeStruct.
- writeString(w, string(v.Interface().([]byte)))
+ if err := writeString(w, string(v.Interface().([]byte))); err != nil {
+ return err
+ }
case reflect.String:
- writeString(w, v.String())
+ if err := writeString(w, v.String()); err != nil {
+ return err
+ }
case reflect.Struct:
// Required/optional group/message.
var bra, ket byte = '<', '>'
if props != nil && props.Wire == "group" {
bra, ket = '{', '}'
}
- w.WriteByte(bra)
+ if err := w.WriteByte(bra); err != nil {
+ return err
+ }
if !w.compact {
- w.WriteByte('\n')
+ if err := w.WriteByte('\n'); err != nil {
+ return err
+ }
}
w.indent()
- writeStruct(w, v)
+ if err := writeStruct(w, v); err != nil {
+ return err
+ }
w.unindent()
- w.WriteByte(ket)
+ if err := w.WriteByte(ket); err != nil {
+ return err
+ }
default:
- fmt.Fprint(w, v.Interface())
+ _, err := fmt.Fprint(w, v.Interface())
+ return err
}
+ return nil
}
// equivalent to C's isprint.
@@ -295,117 +369,154 @@
// we treat the string as a byte sequence, and we use octal escapes.
// These differences are to maintain interoperability with the other
// languages' implementations of the text format.
-func writeString(w *textWriter, s string) {
- w.WriteByte('"') // use WriteByte here to get any needed indent
+func writeString(w *textWriter, s string) error {
+ // use WriteByte here to get any needed indent
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
// Loop over the bytes, not the runes.
for i := 0; i < len(s); i++ {
+ var err error
// Divergence from C++: we don't escape apostrophes.
// There's no need to escape them, and the C++ parser
// copes with a naked apostrophe.
switch c := s[i]; c {
case '\n':
- w.writer.Write(backslashN)
+ _, err = w.w.Write(backslashN)
case '\r':
- w.writer.Write(backslashR)
+ _, err = w.w.Write(backslashR)
case '\t':
- w.writer.Write(backslashT)
+ _, err = w.w.Write(backslashT)
case '"':
- w.writer.Write(backslashDQ)
+ _, err = w.w.Write(backslashDQ)
case '\\':
- w.writer.Write(backslashBS)
+ _, err = w.w.Write(backslashBS)
default:
if isprint(c) {
- w.writeByte(c)
+ err = w.w.WriteByte(c)
} else {
- fmt.Fprintf(w.writer, "\\%03o", c)
+ _, err = fmt.Fprintf(w.w, "\\%03o", c)
}
}
+ if err != nil {
+ return err
+ }
}
- w.WriteByte('"')
+ return w.WriteByte('"')
}
-func writeMessageSet(w *textWriter, ms *MessageSet) {
+func writeMessageSet(w *textWriter, ms *MessageSet) error {
for _, item := range ms.Item {
id := *item.TypeId
if msd, ok := messageSetMap[id]; ok {
// Known message set type.
- fmt.Fprintf(w, "[%s]: <\n", msd.name)
+ if _, err := fmt.Fprintf(w, "[%s]: <\n", msd.name); err != nil {
+ return err
+ }
w.indent()
pb := reflect.New(msd.t.Elem())
if err := Unmarshal(item.Message, pb.Interface().(Message)); err != nil {
- fmt.Fprintf(w, "/* bad message: %v */\n", err)
+ if _, err := fmt.Fprintf(w, "/* bad message: %v */\n", err); err != nil {
+ return err
+ }
} else {
- writeStruct(w, pb.Elem())
+ if err := writeStruct(w, pb.Elem()); err != nil {
+ return err
+ }
}
} else {
// Unknown type.
- fmt.Fprintf(w, "[%d]: <\n", id)
+ if _, err := fmt.Fprintf(w, "[%d]: <\n", id); err != nil {
+ return err
+ }
w.indent()
- writeUnknownStruct(w, item.Message)
+ if err := writeUnknownStruct(w, item.Message); err != nil {
+ return err
+ }
}
w.unindent()
- w.Write(gtNewline)
+ if _, err := w.Write(gtNewline); err != nil {
+ return err
+ }
}
+ return nil
}
-func writeUnknownStruct(w *textWriter, data []byte) {
+func writeUnknownStruct(w *textWriter, data []byte) (err error) {
if !w.compact {
- fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data))
+ if _, err := fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data)); err != nil {
+ return err
+ }
}
b := NewBuffer(data)
for b.index < len(b.buf) {
x, err := b.DecodeVarint()
if err != nil {
- fmt.Fprintf(w, "/* %v */\n", err)
- return
+ _, err := fmt.Fprintf(w, "/* %v */\n", err)
+ return err
}
wire, tag := x&7, x>>3
if wire == WireEndGroup {
w.unindent()
- w.Write(endBraceNewline)
+ if _, err := w.Write(endBraceNewline); err != nil {
+ return err
+ }
continue
}
- fmt.Fprintf(w, "tag%d", tag)
+ if _, err := fmt.Fprintf(w, "tag%d", tag); err != nil {
+ return err
+ }
if wire != WireStartGroup {
- w.WriteByte(':')
+ if err := w.WriteByte(':'); err != nil {
+ return err
+ }
}
if !w.compact || wire == WireStartGroup {
- w.WriteByte(' ')
+ if err := w.WriteByte(' '); err != nil {
+ return err
+ }
}
switch wire {
case WireBytes:
- buf, err := b.DecodeRawBytes(false)
+ buf, e := b.DecodeRawBytes(false)
- if err == nil {
+ if e == nil { // test the decode error; err here is the (nil) tag-decode error
- fmt.Fprintf(w, "%q", buf)
+ _, err = fmt.Fprintf(w, "%q", buf)
} else {
- fmt.Fprintf(w, "/* %v */", err)
+ _, err = fmt.Fprintf(w, "/* %v */", e)
}
case WireFixed32:
- x, err := b.DecodeFixed32()
- writeUnknownInt(w, x, err)
+ x, err = b.DecodeFixed32()
+ err = writeUnknownInt(w, x, err)
case WireFixed64:
- x, err := b.DecodeFixed64()
- writeUnknownInt(w, x, err)
+ x, err = b.DecodeFixed64()
+ err = writeUnknownInt(w, x, err)
case WireStartGroup:
- w.WriteByte('{')
+ err = w.WriteByte('{')
w.indent()
case WireVarint:
- x, err := b.DecodeVarint()
- writeUnknownInt(w, x, err)
+ x, err = b.DecodeVarint()
+ err = writeUnknownInt(w, x, err)
default:
- fmt.Fprintf(w, "/* unknown wire type %d */", wire)
+ _, err = fmt.Fprintf(w, "/* unknown wire type %d */", wire)
}
- w.WriteByte('\n')
+ if err != nil {
+ return err
+ }
+ if err = w.WriteByte('\n'); err != nil {
+ return err
+ }
}
+ return nil
}
-func writeUnknownInt(w *textWriter, x uint64, err error) {
+func writeUnknownInt(w *textWriter, x uint64, err error) error {
if err == nil {
- fmt.Fprint(w, x)
+ _, err = fmt.Fprint(w, x)
} else {
- fmt.Fprintf(w, "/* %v */", err)
+ _, err = fmt.Fprintf(w, "/* %v */", err)
}
+ return err
}
type int32Slice []int32
@@ -416,7 +527,7 @@
// writeExtensions writes all the extensions in pv.
// pv is assumed to be a pointer to a protocol message struct that is extendable.
-func writeExtensions(w *textWriter, pv reflect.Value) {
+func writeExtensions(w *textWriter, pv reflect.Value) error {
emap := extensionMaps[pv.Type().Elem()]
ep := pv.Interface().(extendableProto)
@@ -438,35 +549,53 @@
}
if desc == nil {
// Unknown extension.
- writeUnknownStruct(w, ext.enc)
+ if err := writeUnknownStruct(w, ext.enc); err != nil {
+ return err
+ }
continue
}
pb, err := GetExtension(ep, desc)
if err != nil {
- fmt.Fprintln(os.Stderr, "proto: failed getting extension: ", err)
+ if _, err := fmt.Fprintln(os.Stderr, "proto: failed getting extension: ", err); err != nil {
+ return err
+ }
continue
}
// Repeated extensions will appear as a slice.
if !desc.repeated() {
- writeExtension(w, desc.Name, pb)
+ if err := writeExtension(w, desc.Name, pb); err != nil {
+ return err
+ }
} else {
v := reflect.ValueOf(pb)
for i := 0; i < v.Len(); i++ {
- writeExtension(w, desc.Name, v.Index(i).Interface())
+ if err := writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil {
+ return err
+ }
}
}
}
+ return nil
}
-func writeExtension(w *textWriter, name string, pb interface{}) {
- fmt.Fprintf(w, "[%s]:", name)
- if !w.compact {
- w.WriteByte(' ')
+func writeExtension(w *textWriter, name string, pb interface{}) error {
+ if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil {
+ return err
}
- writeAny(w, reflect.ValueOf(pb), nil)
- w.WriteByte('\n')
+ if !w.compact {
+ if err := w.WriteByte(' '); err != nil {
+ return err
+ }
+ }
+ if err := writeAny(w, reflect.ValueOf(pb), nil); err != nil {
+ return err
+ }
+ if err := w.WriteByte('\n'); err != nil {
+ return err
+ }
+ return nil
}
func (w *textWriter) writeIndent() {
@@ -479,46 +608,43 @@
if n > len(spaces) {
n = len(spaces)
}
- w.writer.Write(spaces[:n])
+ w.w.Write(spaces[:n])
remain -= n
}
w.complete = false
}
-type byteWriter interface {
- WriteByte(byte) error
-}
-
-func marshalText(w io.Writer, pb Message, compact bool) {
+func marshalText(w io.Writer, pb Message, compact bool) error {
if pb == nil {
- w.Write([]byte("<nil>"))
- return
+ _, err := w.Write([]byte("<nil>"))
+ return err // surface w's error, as MarshalText documents
}
- aw := new(textWriter)
- aw.writer = w
- aw.complete = true
- aw.compact = compact
-
- if bw, ok := w.(byteWriter); ok {
- aw.writeByte = func(c byte) error {
- return bw.WriteByte(c)
- }
- } else {
- var scratch [1]byte
- aw.writeByte = func(c byte) error {
- scratch[0] = c
- _, err := w.Write(scratch[:])
- return err
- }
+ var bw *bufio.Writer
+ ww, ok := w.(writer)
+ if !ok {
+ bw = bufio.NewWriter(w)
+ ww = bw
+ }
+ aw := &textWriter{
+ w: ww,
+ complete: true,
+ compact: compact,
}
// Dereference the received pointer so we don't have outer < and >.
v := reflect.Indirect(reflect.ValueOf(pb))
- writeStruct(aw, v)
+ if err := writeStruct(aw, v); err != nil {
+ return err
+ }
+ if bw != nil {
+ return bw.Flush()
+ }
+ return nil
}
// MarshalText writes a given protocol buffer in text format.
-func MarshalText(w io.Writer, pb Message) { marshalText(w, pb, false) }
+// The only errors returned are from w.
+func MarshalText(w io.Writer, pb Message) error { return marshalText(w, pb, false) }
// MarshalTextString is the same as MarshalText, but returns the string directly.
func MarshalTextString(pb Message) string {
@@ -528,7 +654,7 @@
}
// CompactText writes a given protocol buffer in compact text format (one line).
-func CompactText(w io.Writer, pb Message) { marshalText(w, pb, true) }
+func CompactText(w io.Writer, pb Message) error { return marshalText(w, pb, true) }
// CompactTextString is the same as CompactText, but returns the string directly.
func CompactTextString(pb Message) string {