internal/impl: add fast-path for IsInitialized
The fast path currently returns uninformative errors; the slow,
reflection-based path is then consulted only when the fast path detects
an error. It may be worth the effort to produce better errors directly
on the fast path.
Change-Id: I68536e9438010dbd97dbaff4f47b78430221d94b
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171462
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 38bd392..05bc202 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -15,12 +15,14 @@
type pointerCoderFuncs struct {
size func(p pointer, tagsize int, opts marshalOptions) int
marshal func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error)
+ isInit func(p pointer) error
}
// ifaceCoderFuncs is a set of interface{} encoding functions.
type ifaceCoderFuncs struct {
size func(ival interface{}, tagsize int, opts marshalOptions) int
marshal func(b []byte, ival interface{}, wiretag uint64, opts marshalOptions) ([]byte, error)
+ isInit func(ival interface{}) error
}
// fieldCoder returns pointer functions for a field, used for operating on
diff --git a/internal/impl/encode_field.go b/internal/impl/encode_field.go
index b8afb99..15124c0 100644
--- a/internal/impl/encode_field.go
+++ b/internal/impl/encode_field.go
@@ -38,32 +38,40 @@
}
}
ft := fs.Type
+ getInfo := func(p pointer) (pointer, oneofFieldInfo) {
+ v := p.AsValueOf(ft).Elem()
+ if v.IsNil() {
+ return pointer{}, oneofFieldInfo{}
+ }
+ v = v.Elem() // interface -> *struct
+ telem := v.Elem().Type()
+ info, ok := oneofFieldInfos[telem]
+ if !ok {
+ panic(fmt.Errorf("invalid oneof type %v", telem))
+ }
+ return pointerOfValue(v).Apply(zeroOffset), info
+ }
return pointerCoderFuncs{
size: func(p pointer, _ int, opts marshalOptions) int {
- v := p.AsValueOf(ft).Elem()
- if v.IsNil() {
+ v, info := getInfo(p)
+ if info.funcs.size == nil {
return 0
}
- v = v.Elem() // interface -> *struct
- telem := v.Elem().Type()
- info, ok := oneofFieldInfos[telem]
- if !ok {
- panic(fmt.Errorf("invalid oneof type %v", telem))
- }
- return info.funcs.size(pointerOfValue(v).Apply(zeroOffset), info.tagsize, opts)
+ return info.funcs.size(v, info.tagsize, opts)
},
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
- v := p.AsValueOf(ft).Elem()
- if v.IsNil() {
+ v, info := getInfo(p)
+ if info.funcs.marshal == nil {
return b, nil
}
- v = v.Elem() // interface -> *struct
- telem := v.Elem().Type()
- info, ok := oneofFieldInfos[telem]
- if !ok {
- panic(fmt.Errorf("invalid oneof type %v", telem))
+ return info.funcs.marshal(b, v, info.wiretag, opts)
+ },
+ isInit: func(p pointer) error {
+ v, info := getInfo(p)
+ if info.funcs.isInit == nil {
+ return nil
}
- return info.funcs.marshal(b, pointerOfValue(v).Apply(zeroOffset), info.wiretag, opts)
+ return info.funcs.isInit(v)
},
}
}
@@ -77,6 +85,9 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageInfo(b, p, wiretag, fi, opts)
},
+ isInit: func(p pointer) error {
+ return fi.isInitializedPointer(p.Elem())
+ },
}
} else {
return pointerCoderFuncs{
@@ -88,6 +99,10 @@
m := asMessage(p.AsValueOf(ft).Elem())
return appendMessage(b, m, wiretag, opts)
},
+ isInit: func(p pointer) error {
+ m := asMessage(p.AsValueOf(ft).Elem())
+ return proto.IsInitialized(m)
+ },
}
}
}
@@ -122,9 +137,15 @@
return appendMessage(b, m, wiretag, opts)
}
+// isInitMessageIface checks the message held in ival by converting it to a
+// proto.Message with Export{}.MessageOf and calling proto.IsInitialized.
+func isInitMessageIface(ival interface{}) error {
+ m := Export{}.MessageOf(ival).Interface()
+ return proto.IsInitialized(m)
+}
+
var coderMessageIface = ifaceCoderFuncs{
size: sizeMessageIface,
marshal: appendMessageIface,
+ isInit: isInitMessageIface,
}
func makeGroupFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@@ -136,6 +157,9 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupType(b, p, wiretag, fi, opts)
},
+ isInit: func(p pointer) error {
+ return fi.isInitializedPointer(p.Elem())
+ },
}
} else {
return pointerCoderFuncs{
@@ -147,6 +171,10 @@
m := asMessage(p.AsValueOf(ft).Elem())
return appendGroup(b, m, wiretag, opts)
},
+ isInit: func(p pointer) error {
+ m := asMessage(p.AsValueOf(ft).Elem())
+ return proto.IsInitialized(m)
+ },
}
}
}
@@ -186,6 +214,7 @@
var coderGroupIface = ifaceCoderFuncs{
size: sizeGroupIface,
marshal: appendGroupIface,
+ isInit: isInitMessageIface,
}
func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@@ -197,6 +226,9 @@
size: func(p pointer, tagsize int, opts marshalOptions) int {
return sizeMessageSliceInfo(p, fi, tagsize, opts)
},
+ isInit: func(p pointer) error {
+ return isInitMessageSliceInfo(p, fi)
+ },
}
}
return pointerCoderFuncs{
@@ -206,6 +238,9 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageSlice(b, p, wiretag, ft, opts)
},
+ isInit: func(p pointer) error {
+ return isInitMessageSlice(p, ft)
+ },
}
}
@@ -233,6 +268,16 @@
return b, nil
}
+// isInitMessageSliceInfo checks each element of the slice at p with mi's
+// pointer-based fast path and returns the first error encountered.
+func isInitMessageSliceInfo(p pointer, mi *MessageInfo) error {
+ s := p.PointerSlice()
+ for _, v := range s {
+ if err := mi.isInitializedPointer(v); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, _ marshalOptions) int {
s := p.PointerSlice()
n := 0
@@ -259,6 +304,17 @@
return b, nil
}
+// isInitMessageSlice is the counterpart of isInitMessageSliceInfo that does
+// not use a *MessageInfo (presumably the variant for coders built without
+// one): each element of the slice at p is converted to a proto.Message and
+// checked with proto.IsInitialized, returning the first error. goType is the
+// slice's element type; its Elem() is what each element pointer is viewed as.
+func isInitMessageSlice(p pointer, goType reflect.Type) error {
+ s := p.PointerSlice()
+ for _, v := range s {
+ m := Export{}.MessageOf(v.AsValueOf(goType.Elem()).Interface()).Interface()
+ if err := proto.IsInitialized(m); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
// Slices of messages
func sizeMessageSliceIface(ival interface{}, tagsize int, opts marshalOptions) int {
@@ -271,9 +327,15 @@
return appendMessageSlice(b, p, wiretag, reflect.TypeOf(ival).Elem().Elem(), opts)
}
+// isInitMessageSliceIface is the ifaceCoderFuncs form of isInitMessageSlice.
+// ival is expected to be a pointer to a message slice, so
+// reflect.TypeOf(ival).Elem().Elem() is the slice's element type.
+func isInitMessageSliceIface(ival interface{}) error {
+ p := pointerOfIface(ival)
+ return isInitMessageSlice(p, reflect.TypeOf(ival).Elem().Elem())
+}
+
var coderMessageSliceIface = ifaceCoderFuncs{
size: sizeMessageSliceIface,
marshal: appendMessageSliceIface,
+ isInit: isInitMessageSliceIface,
}
func makeGroupSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@@ -285,6 +347,9 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupSliceInfo(b, p, wiretag, fi, opts)
},
+ isInit: func(p pointer) error {
+ return isInitMessageSliceInfo(p, fi)
+ },
}
}
return pointerCoderFuncs{
@@ -294,6 +359,9 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupSlice(b, p, wiretag, ft, opts)
},
+ isInit: func(p pointer) error {
+ return isInitMessageSlice(p, ft)
+ },
}
}
@@ -358,6 +426,7 @@
var coderGroupSliceIface = ifaceCoderFuncs{
size: sizeGroupSliceIface,
marshal: appendGroupSliceIface,
+ isInit: isInitMessageSliceIface,
}
// Enums
diff --git a/internal/impl/encode_map.go b/internal/impl/encode_map.go
index d140b27..5a02c34 100644
--- a/internal/impl/encode_map.go
+++ b/internal/impl/encode_map.go
@@ -25,7 +25,7 @@
keyFuncs := encoderFuncsForValue(keyField, ft.Key())
valFuncs := encoderFuncsForValue(valField, ft.Elem())
- return pointerCoderFuncs{
+ funcs = pointerCoderFuncs{
size: func(p pointer, tagsize int, opts marshalOptions) int {
return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
},
@@ -33,6 +33,12 @@
return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
},
}
+ if valFuncs.isInit != nil {
+ funcs.isInit = func(p pointer) error {
+ return isInitMap(p, ft, valFuncs.isInit)
+ }
+ }
+ return funcs
}
const (
@@ -103,6 +109,20 @@
return b, nil
}
+// isInitMap checks every value of the map at p (whose Go type is goType) with
+// isInit and returns the first error. Only values are passed to isInit; keys
+// are not checked. An empty map is trivially initialized.
+func isInitMap(p pointer, goType reflect.Type, isInit func(interface{}) error) error {
+ m := p.AsValueOf(goType).Elem()
+ if m.Len() == 0 {
+ return nil
+ }
+ iter := mapRange(m)
+ for iter.Next() {
+ if err := isInit(iter.Value().Interface()); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
// mapKeys returns a sort.Interface to be used for sorting the map keys.
// Map fields may have key types of non-float scalars, strings and enums.
func mapKeys(vs []reflect.Value) sort.Interface {
diff --git a/internal/impl/encode_reflect.go b/internal/impl/encode_reflect.go
index 6e172d9..05ceb33 100644
--- a/internal/impl/encode_reflect.go
+++ b/internal/impl/encode_reflect.go
@@ -22,7 +22,10 @@
return b, nil
}
-var coderEnum = pointerCoderFuncs{sizeEnum, appendEnum}
+var coderEnum = pointerCoderFuncs{
+ size: sizeEnum,
+ marshal: appendEnum,
+}
func sizeEnumNoZero(p pointer, tagsize int, opts marshalOptions) (size int) {
if p.v.Elem().Int() == 0 {
@@ -38,7 +41,10 @@
return appendEnum(b, p, wiretag, opts)
}
-var coderEnumNoZero = pointerCoderFuncs{sizeEnumNoZero, appendEnumNoZero}
+var coderEnumNoZero = pointerCoderFuncs{
+ size: sizeEnumNoZero,
+ marshal: appendEnumNoZero,
+}
func sizeEnumPtr(p pointer, tagsize int, opts marshalOptions) (size int) {
return sizeEnum(pointer{p.v.Elem()}, tagsize, opts)
@@ -48,7 +54,10 @@
return appendEnum(b, pointer{p.v.Elem()}, wiretag, opts)
}
-var coderEnumPtr = pointerCoderFuncs{sizeEnumPtr, appendEnumPtr}
+var coderEnumPtr = pointerCoderFuncs{
+ size: sizeEnumPtr,
+ marshal: appendEnumPtr,
+}
func sizeEnumSlice(p pointer, tagsize int, opts marshalOptions) (size int) {
return sizeEnumSliceReflect(p.v.Elem(), tagsize, opts)
@@ -58,7 +67,10 @@
return appendEnumSliceReflect(b, p.v.Elem(), wiretag, opts)
}
-var coderEnumSlice = pointerCoderFuncs{sizeEnumSlice, appendEnumSlice}
+var coderEnumSlice = pointerCoderFuncs{
+ size: sizeEnumSlice,
+ marshal: appendEnumSlice,
+}
func sizeEnumPackedSlice(p pointer, tagsize int, _ marshalOptions) (size int) {
s := p.v.Elem()
@@ -91,4 +103,7 @@
return b, nil
}
-var coderEnumPackedSlice = pointerCoderFuncs{sizeEnumPackedSlice, appendEnumPackedSlice}
+var coderEnumPackedSlice = pointerCoderFuncs{
+ size: sizeEnumPackedSlice,
+ marshal: appendEnumPackedSlice,
+}
diff --git a/internal/impl/isinit.go b/internal/impl/isinit.go
new file mode 100644
index 0000000..972eb85
--- /dev/null
+++ b/internal/impl/isinit.go
@@ -0,0 +1,177 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+ "sync"
+
+ "google.golang.org/protobuf/proto"
+ pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+// errRequiredNotSet is the error returned by the fast-path initialization
+// check when a required field is unset. It does not identify which field.
+type errRequiredNotSet struct{}
+
+func (errRequiredNotSet) Error() string { return "proto: required field not set" }
+func (errRequiredNotSet) RequiredNotSet() bool { return true }
+
+// isInitialized runs the pointer-based fast-path check on msg, which is
+// assumed to be of the message type that mi describes.
+func (mi *MessageInfo) isInitialized(msg proto.Message) error {
+ return mi.isInitializedPointer(pointerOfIface(msg))
+}
+
+// isInitializedPointer checks the message at p, descending into fields and
+// extensions whose coders provide an isInit function. It returns
+// errRequiredNotSet if a required field is unset, and nil otherwise,
+// including when mi.needsInitCheck reports that no check is needed.
+func (mi *MessageInfo) isInitializedPointer(p pointer) error {
+ mi.init()
+ if !mi.needsInitCheck {
+ return nil
+ }
+ if p.IsNil() {
+ // A nil message has no fields or extensions set, so it is
+ // uninitialized only if it declares required fields of its own.
+ for _, f := range mi.fieldsOrdered {
+ if f.isRequired {
+ return errRequiredNotSet{}
+ }
+ }
+ return nil
+ }
+ if mi.extensionOffset.IsValid() {
+ e := p.Apply(mi.extensionOffset).Extensions()
+ if err := mi.isInitExtensions(e); err != nil {
+ return err
+ }
+ }
+ for _, f := range mi.fieldsOrdered {
+ if !f.isRequired && f.funcs.isInit == nil {
+ continue
+ }
+ fptr := p.Apply(f.offset)
+ if f.isPointer && fptr.Elem().IsNil() {
+ if f.isRequired {
+ return errRequiredNotSet{}
+ }
+ continue
+ }
+ if f.funcs.isInit == nil {
+ continue
+ }
+ if err := f.funcs.isInit(fptr); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// isInitExtensions checks each extension in ext that has a non-nil value
+// and whose coder provides an isInit function. A nil ext is a no-op.
+func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
+ if ext == nil {
+ return nil
+ }
+ for _, x := range *ext {
+ ei := mi.extensionFieldInfo(x.GetType())
+ if ei.funcs.isInit == nil {
+ continue
+ }
+ v := x.GetValue()
+ if v == nil {
+ continue
+ }
+ if err := ei.funcs.isInit(v); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// needsInitCheckMap caches needsInitCheck results keyed by message
+// descriptor and only ever holds bool values. needsInitCheckMu serializes
+// the traversals that populate it.
+var (
+ needsInitCheckMu sync.Mutex
+ needsInitCheckMap sync.Map
+)
+
+// needsInitCheck reports whether a message needs to be checked for partial initialization.
+//
+// It returns true if the message transitively includes any required fields
+// or extension ranges.
+func needsInitCheck(md pref.MessageDescriptor) bool {
+ if v, ok := needsInitCheckMap.Load(md); ok {
+ if has, ok := v.(bool); ok {
+ return has
+ }
+ }
+ needsInitCheckMu.Lock()
+ defer needsInitCheckMu.Unlock()
+ return needsInitCheckLocked(md)
+}
+
+// needsInitCheckLocked computes needsInitCheck for md. The caller must hold
+// needsInitCheckMu.
+//
+// A true result is always definitive. A false result computed for a message
+// partway through the traversal is not: cycles are cut by treating a message
+// that is still being visited as false, even if it later turns out to need
+// checks. For example, given A{b B; c C}, B{a A} and C{required x}, starting
+// at A visits B before C, and B's only path to C runs through A, so B would
+// wrongly look check-free. False results are therefore cached only when the
+// whole traversal from md found nothing, in which case no message it reached
+// can need checks.
+func needsInitCheckLocked(md pref.MessageDescriptor) bool {
+ if v, ok := needsInitCheckMap.Load(md); ok {
+ if has, ok := v.(bool); ok {
+ return has
+ }
+ }
+ visited := make(map[pref.MessageDescriptor]struct{})
+ has := needsInitCheckVisit(md, visited)
+ if !has {
+ for d := range visited {
+ needsInitCheckMap.Store(d, false)
+ }
+ }
+ needsInitCheckMap.Store(md, has)
+ return has
+}
+
+// needsInitCheckVisit reports whether md, or any message reachable from it
+// through message fields or map values, declares required fields or
+// extension ranges. Messages already in visited are reported as false to
+// break cycles; true results are cached as soon as they are found.
+func needsInitCheckVisit(md pref.MessageDescriptor, visited map[pref.MessageDescriptor]struct{}) bool {
+ if v, ok := needsInitCheckMap.Load(md); ok {
+ if has, ok := v.(bool); ok {
+ return has
+ }
+ }
+ if _, ok := visited[md]; ok {
+ return false
+ }
+ visited[md] = struct{}{}
+ if md.RequiredNumbers().Len() > 0 || md.ExtensionRanges().Len() > 0 {
+ needsInitCheckMap.Store(md, true)
+ return true
+ }
+ for i := 0; i < md.Fields().Len(); i++ {
+ fd := md.Fields().Get(i)
+ // Map keys are never messages, so just consider the map value.
+ if fd.IsMap() {
+ fd = fd.MapValue()
+ }
+ fmd := fd.Message()
+ if fmd != nil && needsInitCheckVisit(fmd, visited) {
+ needsInitCheckMap.Store(md, true)
+ return true
+ }
+ }
+ return false
+}
diff --git a/internal/impl/message.go b/internal/impl/message.go
index 161f194..7c4873a 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -47,6 +47,7 @@
methods piface.Methods
+ needsInitCheck bool
sizecacheOffset offset
extensionOffset offset
unknownOffset offset
@@ -101,6 +102,7 @@
}
si := mi.makeStructInfo(t.Elem())
+ mi.needsInitCheck = needsInitCheck(mi.PBType)
mi.makeKnownFieldsFunc(si)
mi.makeUnknownFieldsFunc(t.Elem())
mi.makeExtensionFieldsFunc(t.Elem())
@@ -139,6 +141,7 @@
mi.methods.Flags = piface.MethodFlagDeterministicMarshal
mi.methods.MarshalAppend = mi.marshalAppend
mi.methods.Size = mi.size
+ mi.methods.IsInitialized = mi.isInitialized
}
type structInfo struct {
diff --git a/internal/impl/message_field.go b/internal/impl/message_field.go
index 4f13256..8aacd25 100644
--- a/internal/impl/message_field.go
+++ b/internal/impl/message_field.go
@@ -27,12 +27,13 @@
newMessage func() pref.Message
// These fields are used for fast-path functions.
- funcs pointerCoderFuncs // fast-path per-field functions
- num pref.FieldNumber // field number
- offset offset // struct field offset
- wiretag uint64 // field tag (number + wire type)
- tagsize int // size of the varint-encoded tag
- isPointer bool // true if IsNil may be called on the struct field
+ funcs pointerCoderFuncs // fast-path per-field functions
+ num pref.FieldNumber // field number
+ offset offset // struct field offset
+ wiretag uint64 // field tag (number + wire type)
+ tagsize int // size of the varint-encoded tag
+ isPointer bool // true if IsNil may be called on the struct field
+ isRequired bool // true if field is required
}
func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, ot reflect.Type) fieldInfo {
@@ -308,11 +309,12 @@
rv.Set(emptyBytes)
}
},
- funcs: funcs,
- offset: fieldOffset,
- isPointer: nullable,
- wiretag: wiretag,
- tagsize: wire.SizeVarint(wiretag),
+ funcs: funcs,
+ offset: fieldOffset,
+ isPointer: nullable,
+ isRequired: fd.Cardinality() == pref.Required,
+ wiretag: wiretag,
+ tagsize: wire.SizeVarint(wiretag),
}
}
@@ -365,6 +367,7 @@
funcs: fieldCoder(fd, ft),
offset: fieldOffset,
isPointer: true,
+ isRequired: fd.Cardinality() == pref.Required,
wiretag: wiretag,
tagsize: wire.SizeVarint(wiretag),
}