internal/encoding/json: rewrite to a token-based encoder and decoder

The previous decoder decodes every JSON number into a float64, which lacks
64-bit integer precision.

I attempted to retrofit it by storing the raw bytes and the parsed-out
number parts, see golang.org/cl/164377.  While that is possible, the
encoding logic for Value is not symmetrical with the decoding logic and
can be confusing since both utilize the same Value struct.

Joe and I decided that it would be better to rewrite the JSON encoder
and decoder to be token-based instead, removing the need for sharing a
model type plus making it more efficient.

Change-Id: Ic0601428a824be4e20141623409ab4d92b6167c7
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/165677
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/encoding/json/decode.go b/internal/encoding/json/decode.go
index 769619b..543abbc 100644
--- a/internal/encoding/json/decode.go
+++ b/internal/encoding/json/decode.go
@@ -6,189 +6,389 @@
 
 import (
 	"bytes"
+	"fmt"
 	"io"
 	"regexp"
+	"strconv"
 	"unicode/utf8"
 
 	"github.com/golang/protobuf/v2/internal/errors"
 )
 
-type syntaxError struct{ error }
+// Decoder is a token-based JSON decoder that reads from an in-memory []byte.
+type Decoder struct {
+	lastType Type // last token type accepted by ReadNext, commas included; zero initially
 
-func newSyntaxError(f string, x ...interface{}) error {
-	return syntaxError{errors.New(f, x...)}
+	// startStack is a stack containing StartObject and StartArray types. The
+	// top of stack represents the object or the array the current value is
+	// directly located in.
+	startStack []Type
+
+	// orig is the complete original input; used in reporting line and column.
+	orig []byte
+	// in contains the unconsumed input; it is always a suffix of orig.
+	in []byte
 }
 
-// Unmarshal parses b as the JSON format.
-// It returns a Value, which represents the input as an AST.
-func Unmarshal(b []byte) (Value, error) {
-	p := decoder{in: b}
-	p.consume(0) // trim leading spaces
-	v, err := p.unmarshalValue()
-	if !p.nerr.Merge(err) {
-		if e, ok := err.(syntaxError); ok {
-			b = b[:len(b)-len(p.in)] // consumed input
-			line := bytes.Count(b, []byte("\n")) + 1
-			if i := bytes.LastIndexByte(b, '\n'); i >= 0 {
-				b = b[i+1:]
-			}
-			column := utf8.RuneCount(b) + 1 // ignore multi-rune characters
-			err = errors.New("syntax error (line %d:%d): %v", line, column, e.error)
-		}
+// NewDecoder returns a Decoder that reads JSON tokens from b; b is not copied.
+func NewDecoder(b []byte) *Decoder {
+	return &Decoder{orig: b, in: b}
+}
+
+// ReadNext returns the next JSON token, or an error if the input is malformed
+// or truncated.  For String types containing invalid UTF8 characters, a
+// non-fatal error is returned and caller can call ReadNext for the next value.
+func (d *Decoder) ReadNext() (Value, error) {
+	var nerr errors.NonFatal
+	value, n, err := d.parseNext()
+	if !nerr.Merge(err) {
 		return Value{}, err
 	}
-	if len(p.in) > 0 {
-		return Value{}, errors.New("%d bytes of unconsumed input", len(p.in))
+
+	switch value.typ {
+	case EOF:
+		if len(d.startStack) != 0 ||
+			d.lastType&(Null|Bool|Number|String|EndObject|EndArray) == 0 {
+			return Value{}, io.ErrUnexpectedEOF
+		}
+
+	case Null:
+		if !d.isValueNext() {
+			return Value{}, d.newSyntaxError("unexpected value null")
+		}
+
+	case Bool, Number:
+		if !d.isValueNext() {
+			return Value{}, d.newSyntaxError("unexpected value %v", value)
+		}
+
+	case String:
+		if d.isValueNext() {
+			break
+		}
+		// Not in a value position, so this must be an object name followed by ':'.
+		if d.lastType&(StartObject|comma) == 0 {
+			return Value{}, d.newSyntaxError("unexpected value %q", value)
+		}
+		d.consume(n) // skip the name and any whitespace before the ':'
+		if len(d.in) == 0 {
+			return Value{}, io.ErrUnexpectedEOF
+		}
+		if c := d.in[0]; c != ':' {
+			return Value{}, d.newSyntaxError(`unexpected character %v, missing ":" after object name`, string(c))
+		}
+		n = 1
+		value.typ = Name
+
+	case StartObject, StartArray:
+		if !d.isValueNext() {
+			return Value{}, d.newSyntaxError("unexpected character %v", value)
+		}
+		d.startStack = append(d.startStack, value.typ)
+
+	case EndObject:
+		if len(d.startStack) == 0 ||
+			d.lastType == comma ||
+			d.startStack[len(d.startStack)-1] != StartObject {
+			return Value{}, d.newSyntaxError("unexpected character }")
+		}
+		d.startStack = d.startStack[:len(d.startStack)-1]
+
+	case EndArray:
+		if len(d.startStack) == 0 ||
+			d.lastType == comma ||
+			d.startStack[len(d.startStack)-1] != StartArray {
+			return Value{}, d.newSyntaxError("unexpected character ]")
+		}
+		d.startStack = d.startStack[:len(d.startStack)-1]
+
+	case comma:
+		if len(d.startStack) == 0 ||
+			d.lastType&(Null|Bool|Number|String|EndObject|EndArray) == 0 {
+			return Value{}, d.newSyntaxError("unexpected character ,")
+		}
 	}
-	return v, p.nerr.E
+
+	// Commit the token only after it has been validated to be in sequence.
+	d.lastType = value.typ
+	d.in = d.in[n:]
+	if d.lastType == comma {
+		return d.ReadNext()
+	}
+	return value, nerr.E
 }
 
-type decoder struct {
-	nerr errors.NonFatal
-	in   []byte
-}
+var (
+	literalRegexp = regexp.MustCompile(`^(null|true|false)`)
+	// Up to 32 non-delimiter bytes, else one non-newline character (for error reporting).
+	errRegexp = regexp.MustCompile(`^([-+._a-zA-Z0-9]{1,32}|.)`)
+)
 
-var literalRegexp = regexp.MustCompile("^(null|true|false)")
+// parseNext skips leading whitespace and parses, without consuming, the next
+// JSON token. It returns the token as a Value (never Name: object names come
+// back as String) plus its length n in bytes. It checks only that the token
+// itself is well-formed, not that it is allowed at this point in the input.
+func (d *Decoder) parseNext() (value Value, n int, err error) {
+	// Trim leading spaces.
+	d.consume(0)
 
-func (p *decoder) unmarshalValue() (Value, error) {
-	if len(p.in) == 0 {
-		return Value{}, io.ErrUnexpectedEOF
+	in := d.in
+	if len(in) == 0 {
+		return d.newValue(EOF, nil, nil), 0, nil
 	}
-	switch p.in[0] {
+
+	switch in[0] {
 	case 'n', 't', 'f':
-		if n := matchWithDelim(literalRegexp, p.in); n > 0 {
-			var v Value
-			switch p.in[0] {
-			case 'n':
-				v = rawValueOf(nil, p.in[:n:n])
-			case 't':
-				v = rawValueOf(true, p.in[:n:n])
-			case 'f':
-				v = rawValueOf(false, p.in[:n:n])
-			}
-			p.consume(n)
-			return v, nil
+		n := matchWithDelim(literalRegexp, in)
+		if n == 0 {
+			return Value{}, 0, d.newSyntaxError("invalid value %s", errRegexp.Find(in))
 		}
-		return Value{}, newSyntaxError("invalid %q as literal", errRegexp.Find(p.in))
+		switch in[0] {
+		case 'n':
+			return d.newValue(Null, in[:n], nil), n, nil
+		case 't':
+			return d.newValue(Bool, in[:n], true), n, nil
+		case 'f':
+			return d.newValue(Bool, in[:n], false), n, nil
+		}
+
 	case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
-		return p.unmarshalNumber()
+		num, n := parseNumber(in)
+		if num == nil {
+			return Value{}, 0, d.newSyntaxError("invalid number %s", errRegexp.Find(in))
+		}
+		return d.newValue(Number, in[:n], num), n, nil
+
 	case '"':
-		return p.unmarshalString()
-	case '[':
-		return p.unmarshalArray()
+		var nerr errors.NonFatal
+		s, n, err := d.parseString(in)
+		if !nerr.Merge(err) {
+			return Value{}, 0, err
+		}
+		return d.newValue(String, in[:n], s), n, nerr.E
+
 	case '{':
-		return p.unmarshalObject()
-	default:
-		return Value{}, newSyntaxError("invalid %q as value", errRegexp.Find(p.in))
+		return d.newValue(StartObject, in[:1], nil), 1, nil
+
+	case '}':
+		return d.newValue(EndObject, in[:1], nil), 1, nil
+
+	case '[':
+		return d.newValue(StartArray, in[:1], nil), 1, nil
+
+	case ']':
+		return d.newValue(EndArray, in[:1], nil), 1, nil
+
+	case ',':
+		return d.newValue(comma, in[:1], nil), 1, nil
 	}
+	return Value{}, 0, d.newSyntaxError("invalid value %s", errRegexp.Find(in))
 }
 
-func (p *decoder) unmarshalArray() (Value, error) {
-	b := p.in
-	var elems []Value
-	if err := p.consumeChar('[', "at start of array"); err != nil {
-		return Value{}, err
+// position returns the 1-based line and column of the next unconsumed byte.
+func (d *Decoder) position() (int, int) {
+	// The consumed prefix of orig determines the line and column.
+	b := d.orig[:len(d.orig)-len(d.in)]
+	line := bytes.Count(b, []byte("\n")) + 1
+	if i := bytes.LastIndexByte(b, '\n'); i >= 0 {
+		b = b[i+1:]
 	}
-	if len(p.in) > 0 && p.in[0] != ']' {
-		for len(p.in) > 0 {
-			v, err := p.unmarshalValue()
-			if !p.nerr.Merge(err) {
-				return Value{}, err
-			}
-			elems = append(elems, v)
-			if !p.tryConsumeChar(',') {
-				break
-			}
-		}
-	}
-	if err := p.consumeChar(']', "at end of array"); err != nil {
-		return Value{}, err
-	}
-	b = b[:len(b)-len(p.in)]
-	return rawValueOf(elems, b[:len(b):len(b)]), nil
+	column := utf8.RuneCount(b) + 1 // counts runes; multi-rune characters are not merged
+	return line, column
 }
 
-func (p *decoder) unmarshalObject() (Value, error) {
-	b := p.in
-	var items [][2]Value
-	if err := p.consumeChar('{', "at start of object"); err != nil {
-		return Value{}, err
-	}
-	if len(p.in) > 0 && p.in[0] != '}' {
-		for len(p.in) > 0 {
-			k, err := p.unmarshalString()
-			if !p.nerr.Merge(err) {
-				return Value{}, err
-			}
-			if err := p.consumeChar(':', "in object"); err != nil {
-				return Value{}, err
-			}
-			v, err := p.unmarshalValue()
-			if !p.nerr.Merge(err) {
-				return Value{}, err
-			}
-			items = append(items, [2]Value{k, v})
-			if !p.tryConsumeChar(',') {
-				break
-			}
-		}
-	}
-	if err := p.consumeChar('}', "at end of object"); err != nil {
-		return Value{}, err
-	}
-	b = b[:len(b)-len(p.in)]
-	return rawValueOf(items, b[:len(b):len(b)]), nil
+// newSyntaxError returns a syntax error whose message is prefixed with the
+// line and column of the decoder's current position.
+func (d *Decoder) newSyntaxError(f string, x ...interface{}) error {
+	e := errors.New(f, x...)
+	line, column := d.position()
+	return errors.New("syntax error (line %d:%d): %v", line, column, e)
 }
 
-func (p *decoder) consumeChar(c byte, msg string) error {
-	if p.tryConsumeChar(c) {
-		return nil
-	}
-	if len(p.in) == 0 {
-		return io.ErrUnexpectedEOF
-	}
-	return newSyntaxError("invalid character %q, expected %q %s", p.in[0], c, msg)
-}
-
-func (p *decoder) tryConsumeChar(c byte) bool {
-	if len(p.in) > 0 && p.in[0] == c {
-		p.consume(1)
-		return true
-	}
-	return false
-}
-
-// consume consumes n bytes of input and any subsequent whitespace.
-func (p *decoder) consume(n int) {
-	p.in = p.in[n:]
-	for len(p.in) > 0 {
-		switch p.in[0] {
-		case ' ', '\n', '\r', '\t':
-			p.in = p.in[1:]
-		default:
-			return
-		}
-	}
-}
-
-// Any sequence that looks like a non-delimiter (for error reporting).
-var errRegexp = regexp.MustCompile("^([-+._a-zA-Z0-9]{1,32}|.)")
-
 // matchWithDelim matches r with the input b and verifies that the match
 // terminates with a delimiter of some form (e.g., r"[^-+_.a-zA-Z0-9]").
 // As a special case, EOF is considered a delimiter.
 func matchWithDelim(r *regexp.Regexp, b []byte) int {
 	n := len(r.Find(b))
 	if n < len(b) {
-		// Check that that the next character is a delimiter.
-		c := b[n]
-		notDelim := (c == '-' || c == '+' || c == '.' || c == '_' ||
-			('a' <= c && c <= 'z') ||
-			('A' <= c && c <= 'Z') ||
-			('0' <= c && c <= '9'))
-		if notDelim {
+		// Reject the match unless the byte that follows it is a delimiter.
+		if isNotDelim(b[n]) {
 			return 0
 		}
 	}
 	return n
 }
+
+// isNotDelim reports whether c is not a delimiter: '-', '+', '.', '_' or [a-zA-Z0-9].
+func isNotDelim(c byte) bool {
+	return (c == '-' || c == '+' || c == '.' || c == '_' ||
+		('a' <= c && c <= 'z') ||
+		('A' <= c && c <= 'Z') ||
+		('0' <= c && c <= '9'))
+}
+
+// consume discards n bytes of input and then any whitespace (space, \n, \r, \t).
+func (d *Decoder) consume(n int) {
+	d.in = d.in[n:]
+	for len(d.in) > 0 {
+		switch d.in[0] {
+		case ' ', '\n', '\r', '\t':
+			d.in = d.in[1:]
+		default:
+			return
+		}
+	}
+}
+
+// isValueNext reports whether a value (scalar, StartObject or StartArray) may
+// come next, judging by lastType and the innermost open object or array.
+func (d *Decoder) isValueNext() bool {
+	if len(d.startStack) == 0 {
+		return d.lastType == 0
+	}
+
+	start := d.startStack[len(d.startStack)-1]
+	switch start {
+	case StartObject:
+		return d.lastType&Name != 0
+	case StartArray:
+		return d.lastType&(StartArray|comma) != 0
+	}
+	panic(fmt.Sprintf(
+		"unreachable logic in Decoder.isValueNext, lastType: %v, startStack: %v",
+		d.lastType, start))
+}
+
+// newValue constructs a Value of the given type, stamped with the current position.
+func (d *Decoder) newValue(typ Type, input []byte, value interface{}) Value {
+	line, column := d.position()
+	return Value{
+		input:  input,
+		line:   line,
+		column: column,
+		typ:    typ,
+		value:  value,
+	}
+}
+
+// Value contains a JSON type and value parsed from calling Decoder.ReadNext.
+type Value struct {
+	input  []byte // raw token bytes, sliced from the decoder input (nil for EOF)
+	line   int    // 1-based line at which the token starts
+	column int    // 1-based column, counted in runes, at which the token starts
+	typ    Type   // JSON token type
+	// value will be set to the following Go type based on the type field:
+	//    Bool         => bool
+	//    Number       => *numberParts
+	//    String, Name => string
+	// It will be nil if none of the above.
+	value interface{}
+}
+
+// newError returns an error prefixed with the line and column of v.
+func (v Value) newError(f string, x ...interface{}) error {
+	e := errors.New(f, x...)
+	return errors.New("error (line %d:%d): %v", v.line, v.column, e)
+}
+
+// Type returns the JSON token type of the value.
+func (v Value) Type() Type {
+	return v.typ
+}
+
+// Position returns the 1-based line and column where the value starts.
+func (v Value) Position() (int, int) {
+	return v.line, v.column
+}
+
+// Bool returns the bool value of a Bool token; other token types yield an error.
+func (v Value) Bool() (bool, error) {
+	if v.typ != Bool {
+		return false, v.newError("%s is not a bool", v.input)
+	}
+	return v.value.(bool), nil
+}
+
+// String returns the decoded string for a String token; for any other token,
+// including Name, it returns the raw input bytes as a string.
+func (v Value) String() string {
+	if v.typ != String {
+		return string(v.input)
+	}
+	return v.value.(string)
+}
+
+// Name returns the object name of a Name token; other token types yield an error.
+func (v Value) Name() (string, error) {
+	if v.typ != Name {
+		return "", v.newError("%s is not an object name", v.input)
+	}
+	return v.value.(string), nil
+}
+
+// Float parses the raw Number token with strconv.ParseFloat; it returns an
+// error if the token is not a Number.
+//
+// The floating-point precision is specified by the bitSize parameter: 32 for
+// float32 or 64 for float64. If bitSize=32, the result still has type float64,
+// but it will be convertible to float32 without changing its value. It will
+// return an error if the number exceeds the floating-point limits for the
+// given bitSize.
+func (v Value) Float(bitSize int) (float64, error) {
+	if v.typ != Number {
+		return 0, v.newError("%s is not a number", v.input)
+	}
+	f, err := strconv.ParseFloat(string(v.input), bitSize)
+	if err != nil {
+		return 0, v.newError("%v", err)
+	}
+	return f, nil
+}
+
+// Int returns the signed integer if the token is a Number, else it will
+// return an error.
+//
+// The given bitSize specifies the integer type that the result must fit into.
+// It returns an error if the number is not an integer value or if the result
+// exceeds the limits for the given bitSize.
+func (v Value) Int(bitSize int) (int64, error) {
+	s, err := v.getIntStr()
+	if err != nil {
+		return 0, err
+	}
+	n, err := strconv.ParseInt(s, 10, bitSize)
+	if err != nil {
+		return 0, v.newError("%v", err)
+	}
+	return n, nil
+}
+
+// Uint returns the unsigned integer if the token is a Number, else an error.
+//
+// bitSize specifies the unsigned integer type that the result must fit into.
+// It returns an error if the number is not an unsigned integer value or if the
+// result exceeds the limits for the given bitSize.
+func (v Value) Uint(bitSize int) (uint64, error) {
+	s, err := v.getIntStr()
+	if err != nil {
+		return 0, err
+	}
+	n, err := strconv.ParseUint(s, 10, bitSize)
+	if err != nil {
+		return 0, v.newError("%v", err)
+	}
+	return n, nil
+}
+
+// getIntStr returns the Number token as an integer string, or an error.
+func (v Value) getIntStr() (string, error) {
+	if v.typ != Number {
+		return "", v.newError("%s is not a number", v.input)
+	}
+	pnum := v.value.(*numberParts)
+	num, ok := normalizeToIntString(pnum)
+	if !ok {
+		return "", v.newError("cannot convert %s to integer", v.input)
+	}
+	return num, nil
+}