proto: wire decoding support

Add proto.Unmarshal.

All test cases produce results identical to those of the v1 unmarshaler.

Change-Id: I42259266018a14e88a650c5d83a043cb17a3a15d
Reviewed-on: https://go-review.googlesource.com/c/153918
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/proto/decode.go b/proto/decode.go
new file mode 100644
index 0000000..4af5f16
--- /dev/null
+++ b/proto/decode.go
@@ -0,0 +1,191 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+	"errors"
+
+	"github.com/golang/protobuf/v2/internal/encoding/wire"
+	"github.com/golang/protobuf/v2/internal/pragma"
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// UnmarshalOptions configures the unmarshaler. The zero value selects the
+// default behavior used by the package-level Unmarshal function.
+//
+// Example usage:
+//   err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
+type UnmarshalOptions struct {
+	// DiscardUnknown, when true, causes unknown fields to be ignored
+	// instead of being retained in the message's unknown field set.
+	DiscardUnknown bool
+
+	// NOTE(review): presumably forces keyed literals so that options can be
+	// added later without breaking callers; see internal/pragma.
+	pragma.NoUnkeyedLiterals
+}
+
+// Unmarshal parses the wire-format message in b and places the result in m.
+func Unmarshal(b []byte, m Message) error {
+	return UnmarshalOptions{}.Unmarshal(b, m)
+}
+
+// Unmarshal decodes the wire-format message in b into m using the options
+// configured in o.
+//
+// The existing contents of m are not cleared first (see the TODO below).
+func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
+	// TODO: Reset m?
+	target := m.ProtoReflect()
+	return o.unmarshalMessage(b, target)
+}
+
+// unmarshalMessage decodes the wire-format bytes in b into m, field by field,
+// dispatching on the field's cardinality: singular fields, lists and maps.
+//
+// A field is treated as unknown when its number is not described by m's type,
+// or when its decoder reports errUnknown (for example, unmarshalMap does so for
+// a non-length-delimited wire type). Unknown fields are skipped and, unless
+// o.DiscardUnknown is set, their raw tag and value bytes are appended to m's
+// unknown field set. Any other decoding error aborts and is returned.
+func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
+	fieldTypes := m.Type().Fields()
+	knownFields := m.KnownFields()
+	unknownFields := m.UnknownFields()
+	for len(b) > 0 {
+		// Parse the tag (field number and wire type).
+		num, wtyp, tagLen := wire.ConsumeTag(b)
+		if tagLen < 0 {
+			return wire.ParseError(tagLen)
+		}
+
+		// Parse the field value.
+		fieldType := fieldTypes.ByNumber(num)
+		var err error
+		var valLen int
+		switch {
+		case fieldType == nil:
+			err = errUnknown
+		case fieldType.Cardinality() != protoreflect.Repeated:
+			valLen, err = o.unmarshalScalarField(b[tagLen:], wtyp, num, knownFields, fieldType)
+		case !fieldType.IsMap():
+			valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType.Kind())
+		default:
+			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, num, knownFields.Get(num).Map(), fieldType)
+		}
+		if err == errUnknown {
+			// Skip over the value; a negative length reports malformed input.
+			valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
+			if valLen < 0 {
+				return wire.ParseError(valLen)
+			}
+			// Honor DiscardUnknown: only retain the raw bytes when asked to.
+			if !o.DiscardUnknown {
+				unknownFields.Set(num, append(unknownFields.Get(num), b[:tagLen+valLen]...))
+			}
+		} else if err != nil {
+			return err
+		}
+		b = b[tagLen+valLen:]
+	}
+	// TODO: required field checks
+	return nil
+}
+
+// unmarshalScalarField decodes one value for the non-repeated field described
+// by field from b (the bytes following the tag) and stores it in knownFields.
+// It returns the number of bytes of b consumed.
+//
+// Errors from unmarshalScalar, including errUnknown, are returned unchanged
+// so that the caller decides how to treat the field.
+func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
+	kind := field.Kind()
+	v, n, err := o.unmarshalScalar(b, wtyp, num, kind)
+	if err != nil {
+		return 0, err
+	}
+	if kind != protoreflect.GroupKind && kind != protoreflect.MessageKind {
+		// Non-message scalars replace the previous value.
+		knownFields.Set(num, v)
+		return n, nil
+	}
+
+	// Messages are merged with any existing message value,
+	// unless the message is part of a oneof.
+	//
+	// TODO: C++ merges into oneofs, while v1 does not.
+	// Evaluate which behavior to pick.
+	var dst protoreflect.Message
+	if knownFields.Has(num) && field.OneofType() == nil {
+		dst = knownFields.Get(num).Message()
+	} else {
+		dst = knownFields.NewMessage(num).ProtoReflect()
+		knownFields.Set(num, protoreflect.ValueOf(dst))
+	}
+	if err = o.unmarshalMessage(v.Bytes(), dst); err != nil {
+		return 0, err
+	}
+	return n, nil
+}
+
+// unmarshalMap decodes a single map entry from b (the bytes following the
+// field tag) and stores it in mapv. It returns the number of bytes of b
+// consumed. If wtyp is not wire.BytesType it returns errUnknown so the caller
+// can treat the field as unknown.
+//
+// Map entries are represented as a two-element message whose field 1 holds
+// the key and field 2 holds the value. Other fields inside the entry are
+// skipped. A missing key or value falls back to the field's default.
+func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
+	if wtyp != wire.BytesType {
+		return 0, errUnknown
+	}
+	var entry []byte
+	entry, n = wire.ConsumeBytes(b)
+	if n < 0 {
+		return 0, wire.ParseError(n)
+	}
+	var (
+		entryFields = field.MessageType().Fields()
+		keyField    = entryFields.ByNumber(1)
+		valField    = entryFields.ByNumber(2)
+		key         protoreflect.Value
+		val         protoreflect.Value
+		haveKey     bool
+		haveVal     bool
+	)
+	switch valField.Kind() {
+	case protoreflect.GroupKind, protoreflect.MessageKind:
+		// Message values are decoded (and merged) into one new message.
+		val = protoreflect.ValueOf(mapv.NewMessage().ProtoReflect())
+	}
+	for len(entry) > 0 {
+		// Distinct names keep the entry's inner tag from shadowing the outer
+		// num/wtyp parameters and the n result (the total entry length).
+		entryNum, entryWtyp, tagLen := wire.ConsumeTag(entry)
+		if tagLen < 0 {
+			return 0, wire.ParseError(tagLen)
+		}
+		entry = entry[tagLen:]
+		valLen := 0
+		err = errUnknown
+		switch entryNum {
+		case 1:
+			// Decode into a temporary so a rejected repeat of the key field
+			// cannot clobber a key already decoded successfully.
+			var k protoreflect.Value
+			k, valLen, err = o.unmarshalScalar(entry, entryWtyp, entryNum, keyField.Kind())
+			if err != nil {
+				break
+			}
+			key = k
+			haveKey = true
+		case 2:
+			var v protoreflect.Value
+			v, valLen, err = o.unmarshalScalar(entry, entryWtyp, entryNum, valField.Kind())
+			if err != nil {
+				break
+			}
+			switch valField.Kind() {
+			case protoreflect.GroupKind, protoreflect.MessageKind:
+				if err = o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
+					return 0, err
+				}
+			default:
+				val = v
+			}
+			haveVal = true
+		}
+		if err == errUnknown {
+			valLen = wire.ConsumeFieldValue(entryNum, entryWtyp, entry)
+			if valLen < 0 {
+				return 0, wire.ParseError(valLen)
+			}
+		} else if err != nil {
+			return 0, err
+		}
+		entry = entry[valLen:]
+	}
+	// Every map entry should have entries for key and value, but this is not strictly required.
+	if !haveKey {
+		key = keyField.Default()
+	}
+	if !haveVal {
+		switch valField.Kind() {
+		case protoreflect.GroupKind, protoreflect.MessageKind:
+			// Unmarshal an empty message so any required field checks apply
+			// to an absent message value as well.
+			if err = o.unmarshalMessage(nil, val.Message()); err != nil {
+				return 0, err
+			}
+		default:
+			val = valField.Default()
+		}
+	}
+	mapv.Set(key.MapKey(), val)
+	return n, nil
+}
+
+// errUnknown is used internally to indicate fields which should be added
+// to the unknown field set of a message. It is never returned from an exported
+// function.
+var errUnknown = errors.New("unknown")