proto: add MarshalState, UnmarshalState
Add functions to the proto package that plumb the fast-path state (the
protoiface input and output structs) through to callers.
As an example use case, a follow-up CL adds an Initialized field to
protoiface.UnmarshalOutput, permitting the unmarshaler to report back
when it can confirm that a message is fully initialized. We want to
preserve that information when an unmarshal operation passes through
the proto package (such as when unmarshaling extensions).
To allow these functions to be added as methods of MarshalOptions and
UnmarshalOptions rather than as top-level functions, separate the options
from the input structs: the fast-path Marshal and Unmarshal methods now
take their options as a separate argument.
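Also update the options passed to fast-path methods to set AllowPartial
and Merge to reflect the expected behavior of those methods: always allow
partial messages, and always merge. The proto package itself resets the
message beforehand when merging was not requested, and checks that the
message is initialized afterward when partial messages are not allowed.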
Change-Id: I482477b0c9340793be533e75a86d0bb88708716a
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215877
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/proto/decode.go b/proto/decode.go
index f5b5808..83942ea 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -47,38 +47,56 @@
// Unmarshal parses the wire-format message in b and places the result in m.
func Unmarshal(b []byte, m Message) error {
- return UnmarshalOptions{}.Unmarshal(b, m)
+ _, err := UnmarshalOptions{}.unmarshal(b, m)
+ return err
}
// Unmarshal parses the wire-format message in b and places the result in m.
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
+ _, err := o.unmarshal(b, m)
+ return err
+}
+
+// UnmarshalState parses a wire-format message and places the result in m.
+//
+// This method permits fine-grained control over the unmarshaler.
+// Most users should use Unmarshal instead.
+func (o UnmarshalOptions) UnmarshalState(m Message, in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
+ return o.unmarshal(in.Buf, m)
+}
+
+func (o UnmarshalOptions) unmarshal(b []byte, message Message) (out protoiface.UnmarshalOutput, err error) {
if o.Resolver == nil {
o.Resolver = protoregistry.GlobalTypes
}
-
if !o.Merge {
- Reset(m)
+ Reset(message)
}
- err := o.unmarshalMessage(b, m.ProtoReflect())
+ allowPartial := o.AllowPartial
+ o.Merge = true
+ o.AllowPartial = true
+ m := message.ProtoReflect()
+ methods := protoMethods(m)
+ if methods != nil && methods.Unmarshal != nil &&
+ !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
+ out, err = methods.Unmarshal(m, protoiface.UnmarshalInput{
+ Buf: b,
+ }, protoiface.UnmarshalOptions(o))
+ } else {
+ err = o.unmarshalMessageSlow(b, m)
+ }
if err != nil {
- return err
+ return out, err
}
- if o.AllowPartial {
- return nil
+ if allowPartial {
+ return out, nil
}
- return IsInitialized(m)
+ return out, isInitialized(m)
}
func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
- if methods := protoMethods(m); methods != nil && methods.Unmarshal != nil &&
- !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
- _, err := methods.Unmarshal(m, protoiface.UnmarshalInput{
- Buf: b,
- Options: protoiface.UnmarshalOptions(o),
- })
- return err
- }
- return o.unmarshalMessageSlow(b, m)
+ _, err := o.unmarshal(b, m.Interface())
+ return err
}
func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
diff --git a/proto/encode.go b/proto/encode.go
index 18ffe6e..3afa331 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -76,28 +76,35 @@
// Marshal returns the wire-format encoding of m.
func Marshal(m Message) ([]byte, error) {
- return MarshalOptions{}.MarshalAppend(nil, m)
+ out, err := MarshalOptions{}.marshal(nil, m)
+ return out.Buf, err
}
// Marshal returns the wire-format encoding of m.
func (o MarshalOptions) Marshal(m Message) ([]byte, error) {
- return o.MarshalAppend(nil, m)
+ out, err := o.marshal(nil, m)
+ return out.Buf, err
}
// MarshalAppend appends the wire-format encoding of m to b,
// returning the result.
func (o MarshalOptions) MarshalAppend(b []byte, m Message) ([]byte, error) {
- out, err := o.marshalMessage(b, m.ProtoReflect())
- if err != nil {
- return out, err
- }
- if o.AllowPartial {
- return out, nil
- }
- return out, IsInitialized(m)
+ out, err := o.marshal(b, m)
+ return out.Buf, err
}
-func (o MarshalOptions) marshalMessage(b []byte, m protoreflect.Message) ([]byte, error) {
+// MarshalState returns the wire-format encoding of m.
+//
+// This method permits fine-grained control over the marshaler.
+// Most users should use Marshal instead.
+func (o MarshalOptions) MarshalState(m Message, in protoiface.MarshalInput) (protoiface.MarshalOutput, error) {
+ return o.marshal(in.Buf, m)
+}
+
+func (o MarshalOptions) marshal(b []byte, message Message) (out protoiface.MarshalOutput, err error) {
+ allowPartial := o.AllowPartial
+ o.AllowPartial = true
+ m := message.ProtoReflect()
if methods := protoMethods(m); methods != nil && methods.Marshal != nil &&
!(o.Deterministic && methods.Flags&protoiface.SupportMarshalDeterministic == 0) {
if methods.Size != nil {
@@ -109,13 +116,24 @@
}
o.UseCachedSize = true
}
- out, err := methods.Marshal(m, protoiface.MarshalInput{
- Buf: b,
- Options: protoiface.MarshalOptions(o),
- })
- return out.Buf, err
+ out, err = methods.Marshal(m, protoiface.MarshalInput{
+ Buf: b,
+ }, protoiface.MarshalOptions(o))
+ } else {
+ out.Buf, err = o.marshalMessageSlow(b, m)
}
- return o.marshalMessageSlow(b, m)
+ if err != nil {
+ return out, err
+ }
+ if allowPartial {
+ return out, nil
+ }
+ return out, isInitialized(m)
+}
+
+func (o MarshalOptions) marshalMessage(b []byte, m protoreflect.Message) ([]byte, error) {
+ out, err := o.marshal(b, m.Interface())
+ return out.Buf, err
}
// growcap scales up the capacity of a slice.
diff --git a/proto/size.go b/proto/size.go
index 619104c..5f26693 100644
--- a/proto/size.go
+++ b/proto/size.go
@@ -29,7 +29,7 @@
if methods != nil && methods.Marshal != nil {
// This is not efficient, but we don't have any choice.
// This case is mainly used for legacy types with a Marshal method.
- out, _ := methods.Marshal(m, protoiface.MarshalInput{})
+ out, _ := methods.Marshal(m, protoiface.MarshalInput{}, protoiface.MarshalOptions{})
return len(out.Buf)
}
return sizeMessageSlow(m)