proto: wire decoding support
Add proto.Unmarshal.
All test cases produce results identical to the v1 unmarshaler.
Change-Id: I42259266018a14e88a650c5d83a043cb17a3a15d
Reviewed-on: https://go-review.googlesource.com/c/153918
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/proto/decode.go b/proto/decode.go
new file mode 100644
index 0000000..4af5f16
--- /dev/null
+++ b/proto/decode.go
@@ -0,0 +1,191 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+ "errors"
+
+ "github.com/golang/protobuf/v2/internal/encoding/wire"
+ "github.com/golang/protobuf/v2/internal/pragma"
+ "github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// UnmarshalOptions holds the settings that control how wire-format data
+// is decoded into a message.
+//
+// Example usage:
+//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
+type UnmarshalOptions struct {
+	// DiscardUnknown, when true, requests that fields not recognized by
+	// the message type be dropped instead of kept as unknown fields.
+	DiscardUnknown bool
+
+	// Embedded so the options are constructed with keyed literals only.
+	pragma.NoUnkeyedLiterals
+}
+
+// Unmarshal parses the wire-format message in b and places the result in m.
+func Unmarshal(b []byte, m Message) error {
+ return UnmarshalOptions{}.Unmarshal(b, m)
+}
+
+// Unmarshal decodes the wire-format bytes in b into m according to the
+// settings in o. m is not reset first; decoded fields are applied on top
+// of whatever m already contains.
+func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
+	// TODO: Reset m?
+	pm := m.ProtoReflect()
+	return o.unmarshalMessage(b, pm)
+}
+
+// unmarshalMessage decodes the wire-format fields in b into m.
+//
+// A field is handled as unknown when its number is not declared by m's
+// type, or when the field decoder reports errUnknown for it (for example,
+// a map field that is not length-delimited). Unknown fields are skipped.
+// Unless o.DiscardUnknown is set, their raw bytes, tag included, are also
+// appended to m's unknown fields.
+func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
+	fieldTypes := m.Type().Fields()
+	knownFields := m.KnownFields()
+	unknownFields := m.UnknownFields()
+	for len(b) > 0 {
+		// Parse the tag (field number and wire type).
+		num, wtyp, tagLen := wire.ConsumeTag(b)
+		if tagLen < 0 {
+			return wire.ParseError(tagLen)
+		}
+
+		// Parse the field value.
+		fieldType := fieldTypes.ByNumber(num)
+		var err error
+		var valLen int
+		switch {
+		case fieldType == nil:
+			err = errUnknown
+		case fieldType.Cardinality() != protoreflect.Repeated:
+			valLen, err = o.unmarshalScalarField(b[tagLen:], wtyp, num, knownFields, fieldType)
+		case !fieldType.IsMap():
+			valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType.Kind())
+		default:
+			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, num, knownFields.Get(num).Map(), fieldType)
+		}
+		if err == errUnknown {
+			// The value must be consumed even when it is discarded so that
+			// parsing can resume at the next tag.
+			valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
+			if valLen < 0 {
+				return wire.ParseError(valLen)
+			}
+			if !o.DiscardUnknown {
+				unknownFields.Set(num, append(unknownFields.Get(num), b[:tagLen+valLen]...))
+			}
+		} else if err != nil {
+			return err
+		}
+		b = b[tagLen+valLen:]
+	}
+	// TODO: required field checks
+	return nil
+}
+
+// unmarshalScalarField decodes one value for the non-repeated field
+// described by field and stores it in knownFields under num. It returns
+// the number of bytes consumed from b.
+//
+// Group and message values are merged into the message already present,
+// unless the field belongs to a oneof (or is not yet set), in which case a
+// new message is created and stored. Every other kind overwrites the
+// previous value.
+func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
+	v, n, err := o.unmarshalScalar(b, wtyp, num, field.Kind())
+	if err != nil {
+		return 0, err
+	}
+	if k := field.Kind(); k != protoreflect.GroupKind && k != protoreflect.MessageKind {
+		// Non-message scalars replace the previous value.
+		knownFields.Set(num, v)
+		return n, nil
+	}
+
+	// TODO: C++ merges into oneofs, while v1 does not.
+	// Evaluate which behavior to pick.
+	var dst protoreflect.Message
+	if knownFields.Has(num) && field.OneofType() == nil {
+		dst = knownFields.Get(num).Message()
+	} else {
+		dst = knownFields.NewMessage(num).ProtoReflect()
+		knownFields.Set(num, protoreflect.ValueOf(dst))
+	}
+	if err = o.unmarshalMessage(v.Bytes(), dst); err != nil {
+		return 0, err
+	}
+	return n, nil
+}
+
+// unmarshalMap decodes a single map entry from b and stores it in mapv.
+// It returns the number of bytes consumed from b, which spans the whole
+// length-delimited entry. If wtyp is not wire.BytesType it returns
+// errUnknown so that the caller handles the field as unknown.
+//
+// An entry is encoded as a two-field message: field 1 is the key and
+// field 2 is the value. A missing key or scalar value takes the field's
+// default; a missing message value is an empty message. Entry fields with
+// other numbers, or that the decoder rejects with errUnknown, are skipped.
+// When a field repeats, the last scalar wins and message values are merged.
+func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
+	if wtyp != wire.BytesType {
+		return 0, errUnknown
+	}
+	b, n = wire.ConsumeBytes(b)
+	if n < 0 {
+		return 0, wire.ParseError(n)
+	}
+	var (
+		keyField = field.MessageType().Fields().ByNumber(1)
+		valField = field.MessageType().Fields().ByNumber(2)
+		key      protoreflect.Value
+		val      protoreflect.Value
+		haveKey  bool
+		haveVal  bool
+	)
+	switch valField.Kind() {
+	case protoreflect.GroupKind, protoreflect.MessageKind:
+		val = protoreflect.ValueOf(mapv.NewMessage().ProtoReflect())
+	}
+	// Map entries are represented as a two-element message with fields
+	// containing the key and value.
+	for len(b) > 0 {
+		// Distinct names keep the entry's tag and lengths from shadowing
+		// this function's num/wtyp and, crucially, the result n, which must
+		// remain the total length of the entry.
+		entryNum, entryWtyp, tagLen := wire.ConsumeTag(b)
+		if tagLen < 0 {
+			return 0, wire.ParseError(tagLen)
+		}
+		b = b[tagLen:]
+		var valLen int
+		err = errUnknown
+		switch entryNum {
+		case 1:
+			// Decode into a temporary so that a rejected occurrence (e.g. a
+			// mismatched wire type) cannot clobber a key decoded earlier.
+			var k protoreflect.Value
+			k, valLen, err = o.unmarshalScalar(b, entryWtyp, entryNum, keyField.Kind())
+			if err != nil {
+				break // leaves the switch; errUnknown is handled below
+			}
+			key = k
+			haveKey = true
+		case 2:
+			var v protoreflect.Value
+			v, valLen, err = o.unmarshalScalar(b, entryWtyp, entryNum, valField.Kind())
+			if err != nil {
+				break // leaves the switch; errUnknown is handled below
+			}
+			switch valField.Kind() {
+			case protoreflect.GroupKind, protoreflect.MessageKind:
+				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
+					return 0, err
+				}
+			default:
+				val = v
+			}
+			haveVal = true
+		}
+		if err == errUnknown {
+			valLen = wire.ConsumeFieldValue(entryNum, entryWtyp, b)
+			if valLen < 0 {
+				return 0, wire.ParseError(valLen)
+			}
+		} else if err != nil {
+			return 0, err
+		}
+		b = b[valLen:]
+	}
+	// Every map entry should have entries for key and value, but this is not strictly required.
+	if !haveKey {
+		key = keyField.Default()
+	}
+	if !haveVal {
+		switch valField.Kind() {
+		case protoreflect.GroupKind, protoreflect.MessageKind:
+			// Trigger required field checks by unmarshaling an empty message.
+			if err := o.unmarshalMessage(nil, val.Message()); err != nil {
+				return 0, err
+			}
+		default:
+			val = valField.Default()
+		}
+	}
+	mapv.Set(key.MapKey(), val)
+	return n, nil
+}
+
+// errUnknown is an internal sentinel reported by the field decoders to say
+// that the field just read should be handled as an unknown field rather
+// than as a known one. It never escapes from an exported function.
+var errUnknown = errors.New("unknown")