internal/impl: add fast-path for IsInitialized

The fast path currently returns an uninformative error that does not name
the missing field. Only when it detects such an error is the slower,
reflection-based path consulted. It may be worth the effort to produce
informative errors directly on the fast path instead.

Change-Id: I68536e9438010dbd97dbaff4f47b78430221d94b
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171462
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 38bd392..05bc202 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -15,12 +15,14 @@
 type pointerCoderFuncs struct {
 	size    func(p pointer, tagsize int, opts marshalOptions) int
 	marshal func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error)
+	isInit  func(p pointer) error
 }
 
 // ifaceCoderFuncs is a set of interface{} encoding functions.
 type ifaceCoderFuncs struct {
 	size    func(ival interface{}, tagsize int, opts marshalOptions) int
 	marshal func(b []byte, ival interface{}, wiretag uint64, opts marshalOptions) ([]byte, error)
+	isInit  func(ival interface{}) error
 }
 
 // fieldCoder returns pointer functions for a field, used for operating on
diff --git a/internal/impl/encode_field.go b/internal/impl/encode_field.go
index b8afb99..15124c0 100644
--- a/internal/impl/encode_field.go
+++ b/internal/impl/encode_field.go
@@ -38,32 +38,40 @@
 		}
 	}
 	ft := fs.Type
+	getInfo := func(p pointer) (pointer, oneofFieldInfo) {
+		v := p.AsValueOf(ft).Elem()
+		if v.IsNil() {
+			return pointer{}, oneofFieldInfo{}
+		}
+		v = v.Elem() // interface -> *struct
+		telem := v.Elem().Type()
+		info, ok := oneofFieldInfos[telem]
+		if !ok {
+			panic(fmt.Errorf("invalid oneof type %v", telem))
+		}
+		return pointerOfValue(v).Apply(zeroOffset), info
+	}
 	return pointerCoderFuncs{
 		size: func(p pointer, _ int, opts marshalOptions) int {
-			v := p.AsValueOf(ft).Elem()
-			if v.IsNil() {
+			v, info := getInfo(p)
+			if info.funcs.size == nil {
 				return 0
 			}
-			v = v.Elem() // interface -> *struct
-			telem := v.Elem().Type()
-			info, ok := oneofFieldInfos[telem]
-			if !ok {
-				panic(fmt.Errorf("invalid oneof type %v", telem))
-			}
-			return info.funcs.size(pointerOfValue(v).Apply(zeroOffset), info.tagsize, opts)
+			return info.funcs.size(v, info.tagsize, opts)
 		},
 		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
-			v := p.AsValueOf(ft).Elem()
-			if v.IsNil() {
+			v, info := getInfo(p)
+			if info.funcs.marshal == nil {
 				return b, nil
 			}
-			v = v.Elem() // interface -> *struct
-			telem := v.Elem().Type()
-			info, ok := oneofFieldInfos[telem]
-			if !ok {
-				panic(fmt.Errorf("invalid oneof type %v", telem))
+			return info.funcs.marshal(b, v, info.wiretag, opts)
+		},
+		isInit: func(p pointer) error {
+			v, info := getInfo(p)
+			if info.funcs.isInit == nil {
+				return nil
 			}
-			return info.funcs.marshal(b, pointerOfValue(v).Apply(zeroOffset), info.wiretag, opts)
+			return info.funcs.isInit(v)
 		},
 	}
 }
@@ -77,6 +85,9 @@
 			marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 				return appendMessageInfo(b, p, wiretag, fi, opts)
 			},
+			isInit: func(p pointer) error {
+				return fi.isInitializedPointer(p.Elem())
+			},
 		}
 	} else {
 		return pointerCoderFuncs{
@@ -88,6 +99,10 @@
 				m := asMessage(p.AsValueOf(ft).Elem())
 				return appendMessage(b, m, wiretag, opts)
 			},
+			isInit: func(p pointer) error {
+				m := asMessage(p.AsValueOf(ft).Elem())
+				return proto.IsInitialized(m)
+			},
 		}
 	}
 }
@@ -122,9 +137,15 @@
 	return appendMessage(b, m, wiretag, opts)
 }
 
+// isInitMessageIface checks the message held in ival with proto.IsInitialized.
+func isInitMessageIface(ival interface{}) error {
+	return proto.IsInitialized(Export{}.MessageOf(ival).Interface())
+}
+
 var coderMessageIface = ifaceCoderFuncs{
 	size:    sizeMessageIface,
 	marshal: appendMessageIface,
+	isInit:  isInitMessageIface,
 }
 
 func makeGroupFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@@ -136,6 +157,9 @@
 			marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 				return appendGroupType(b, p, wiretag, fi, opts)
 			},
+			isInit: func(p pointer) error {
+				return fi.isInitializedPointer(p.Elem())
+			},
 		}
 	} else {
 		return pointerCoderFuncs{
@@ -147,6 +171,10 @@
 				m := asMessage(p.AsValueOf(ft).Elem())
 				return appendGroup(b, m, wiretag, opts)
 			},
+			isInit: func(p pointer) error {
+				m := asMessage(p.AsValueOf(ft).Elem())
+				return proto.IsInitialized(m)
+			},
 		}
 	}
 }
@@ -186,6 +214,7 @@
 var coderGroupIface = ifaceCoderFuncs{
 	size:    sizeGroupIface,
 	marshal: appendGroupIface,
+	isInit:  isInitMessageIface,
 }
 
 func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@@ -197,6 +226,9 @@
 			size: func(p pointer, tagsize int, opts marshalOptions) int {
 				return sizeMessageSliceInfo(p, fi, tagsize, opts)
 			},
+			isInit: func(p pointer) error {
+				return isInitMessageSliceInfo(p, fi)
+			},
 		}
 	}
 	return pointerCoderFuncs{
@@ -206,6 +238,9 @@
 		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 			return appendMessageSlice(b, p, wiretag, ft, opts)
 		},
+		isInit: func(p pointer) error {
+			return isInitMessageSlice(p, ft)
+		},
 	}
 }
 
@@ -233,6 +268,16 @@
 	return b, nil
 }
 
+// isInitMessageSliceInfo checks each element of the pointer slice at p using mi.
+func isInitMessageSliceInfo(p pointer, mi *MessageInfo) error {
+	for _, elem := range p.PointerSlice() {
+		if err := mi.isInitializedPointer(elem); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, _ marshalOptions) int {
 	s := p.PointerSlice()
 	n := 0
@@ -259,6 +304,17 @@
 	return b, nil
 }
 
+// isInitMessageSlice checks every element of a message slice with proto.IsInitialized.
+func isInitMessageSlice(p pointer, goType reflect.Type) error {
+	for _, elem := range p.PointerSlice() {
+		mv := elem.AsValueOf(goType.Elem()).Interface()
+		if err := proto.IsInitialized(Export{}.MessageOf(mv).Interface()); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // Slices of messages
 
 func sizeMessageSliceIface(ival interface{}, tagsize int, opts marshalOptions) int {
@@ -271,9 +327,15 @@
 	return appendMessageSlice(b, p, wiretag, reflect.TypeOf(ival).Elem().Elem(), opts)
 }
 
+// isInitMessageSliceIface is isInitMessageSlice for a slice reached through ival.
+func isInitMessageSliceIface(ival interface{}) error {
+	return isInitMessageSlice(pointerOfIface(ival), reflect.TypeOf(ival).Elem().Elem())
+}
+
 var coderMessageSliceIface = ifaceCoderFuncs{
 	size:    sizeMessageSliceIface,
 	marshal: appendMessageSliceIface,
+	isInit:  isInitMessageSliceIface,
 }
 
 func makeGroupSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@@ -285,6 +347,9 @@
 			marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 				return appendGroupSliceInfo(b, p, wiretag, fi, opts)
 			},
+			isInit: func(p pointer) error {
+				return isInitMessageSliceInfo(p, fi)
+			},
 		}
 	}
 	return pointerCoderFuncs{
@@ -294,6 +359,9 @@
 		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 			return appendGroupSlice(b, p, wiretag, ft, opts)
 		},
+		isInit: func(p pointer) error {
+			return isInitMessageSlice(p, ft)
+		},
 	}
 }
 
@@ -358,6 +426,7 @@
 var coderGroupSliceIface = ifaceCoderFuncs{
 	size:    sizeGroupSliceIface,
 	marshal: appendGroupSliceIface,
+	isInit:  isInitMessageSliceIface,
 }
 
 // Enums
diff --git a/internal/impl/encode_map.go b/internal/impl/encode_map.go
index d140b27..5a02c34 100644
--- a/internal/impl/encode_map.go
+++ b/internal/impl/encode_map.go
@@ -25,7 +25,7 @@
 	keyFuncs := encoderFuncsForValue(keyField, ft.Key())
 	valFuncs := encoderFuncsForValue(valField, ft.Elem())
 
-	return pointerCoderFuncs{
+	funcs = pointerCoderFuncs{
 		size: func(p pointer, tagsize int, opts marshalOptions) int {
 			return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
 		},
@@ -33,6 +33,12 @@
 			return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
 		},
 	}
+	if valFuncs.isInit != nil {
+		funcs.isInit = func(p pointer) error {
+			return isInitMap(p, ft, valFuncs.isInit)
+		}
+	}
+	return funcs
 }
 
 const (
@@ -103,6 +109,20 @@
 	return b, nil
 }
 
+// isInitMap applies isInit to every value of the map at p; keys are not checked.
+func isInitMap(p pointer, goType reflect.Type, isInit func(interface{}) error) error {
+	mv := p.AsValueOf(goType).Elem()
+	if mv.Len() == 0 {
+		return nil // nothing to iterate
+	}
+	for it := mapRange(mv); it.Next(); {
+		if err := isInit(it.Value().Interface()); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // mapKeys returns a sort.Interface to be used for sorting the map keys.
 // Map fields may have key types of non-float scalars, strings and enums.
 func mapKeys(vs []reflect.Value) sort.Interface {
diff --git a/internal/impl/encode_reflect.go b/internal/impl/encode_reflect.go
index 6e172d9..05ceb33 100644
--- a/internal/impl/encode_reflect.go
+++ b/internal/impl/encode_reflect.go
@@ -22,7 +22,10 @@
 	return b, nil
 }
 
-var coderEnum = pointerCoderFuncs{sizeEnum, appendEnum}
+var coderEnum = pointerCoderFuncs{
+	size:    sizeEnum,
+	marshal: appendEnum,
+}
 
 func sizeEnumNoZero(p pointer, tagsize int, opts marshalOptions) (size int) {
 	if p.v.Elem().Int() == 0 {
@@ -38,7 +41,10 @@
 	return appendEnum(b, p, wiretag, opts)
 }
 
-var coderEnumNoZero = pointerCoderFuncs{sizeEnumNoZero, appendEnumNoZero}
+var coderEnumNoZero = pointerCoderFuncs{
+	size:    sizeEnumNoZero,
+	marshal: appendEnumNoZero,
+}
 
 func sizeEnumPtr(p pointer, tagsize int, opts marshalOptions) (size int) {
 	return sizeEnum(pointer{p.v.Elem()}, tagsize, opts)
@@ -48,7 +54,10 @@
 	return appendEnum(b, pointer{p.v.Elem()}, wiretag, opts)
 }
 
-var coderEnumPtr = pointerCoderFuncs{sizeEnumPtr, appendEnumPtr}
+var coderEnumPtr = pointerCoderFuncs{
+	size:    sizeEnumPtr,
+	marshal: appendEnumPtr,
+}
 
 func sizeEnumSlice(p pointer, tagsize int, opts marshalOptions) (size int) {
 	return sizeEnumSliceReflect(p.v.Elem(), tagsize, opts)
@@ -58,7 +67,10 @@
 	return appendEnumSliceReflect(b, p.v.Elem(), wiretag, opts)
 }
 
-var coderEnumSlice = pointerCoderFuncs{sizeEnumSlice, appendEnumSlice}
+var coderEnumSlice = pointerCoderFuncs{
+	size:    sizeEnumSlice,
+	marshal: appendEnumSlice,
+}
 
 func sizeEnumPackedSlice(p pointer, tagsize int, _ marshalOptions) (size int) {
 	s := p.v.Elem()
@@ -91,4 +103,7 @@
 	return b, nil
 }
 
-var coderEnumPackedSlice = pointerCoderFuncs{sizeEnumPackedSlice, appendEnumPackedSlice}
+var coderEnumPackedSlice = pointerCoderFuncs{
+	size:    sizeEnumPackedSlice,
+	marshal: appendEnumPackedSlice,
+}
diff --git a/internal/impl/isinit.go b/internal/impl/isinit.go
new file mode 100644
index 0000000..972eb85
--- /dev/null
+++ b/internal/impl/isinit.go
@@ -0,0 +1,134 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+	"sync"
+
+	"google.golang.org/protobuf/proto"
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+type errRequiredNotSet struct{}
+
+func (errRequiredNotSet) Error() string        { return "proto: required field not set" }
+func (errRequiredNotSet) RequiredNotSet() bool { return true }
+
+func (mi *MessageInfo) isInitialized(msg proto.Message) error {
+	return mi.isInitializedPointer(pointerOfIface(msg))
+}
+
+func (mi *MessageInfo) isInitializedPointer(p pointer) error {
+	mi.init()
+	if !mi.needsInitCheck {
+		return nil
+	}
+	if p.IsNil() {
+		return errRequiredNotSet{}
+	}
+	if mi.extensionOffset.IsValid() {
+		e := p.Apply(mi.extensionOffset).Extensions()
+		if err := mi.isInitExtensions(e); err != nil {
+			return err
+		}
+	}
+	for _, f := range mi.fieldsOrdered {
+		if !f.isRequired && f.funcs.isInit == nil {
+			continue
+		}
+		fptr := p.Apply(f.offset)
+		if f.isPointer && fptr.Elem().IsNil() {
+			if f.isRequired {
+				return errRequiredNotSet{}
+			}
+			continue
+		}
+		if f.funcs.isInit == nil {
+			continue
+		}
+		if err := f.funcs.isInit(fptr); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
+	if ext == nil {
+		return nil
+	}
+	for _, x := range *ext {
+		ei := mi.extensionFieldInfo(x.GetType())
+		if ei.funcs.isInit == nil {
+			continue
+		}
+		v := x.GetValue()
+		if v == nil {
+			continue
+		}
+		if err := ei.funcs.isInit(v); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+var (
+	needsInitCheckMu  sync.Mutex
+	needsInitCheckMap sync.Map
+)
+
+// needsInitCheck reports whether a message needs to be checked for partial initialization.
+//
+// It returns true if the message transitively includes any required or extension fields.
+func needsInitCheck(md pref.MessageDescriptor) bool {
+	if v, ok := needsInitCheckMap.Load(md); ok {
+		if has, ok := v.(bool); ok {
+			return has
+		}
+	}
+	needsInitCheckMu.Lock()
+	defer needsInitCheckMu.Unlock()
+	return needsInitCheckLocked(md)
+}
+
+func needsInitCheckLocked(md pref.MessageDescriptor) (has bool) {
+	if v, ok := needsInitCheckMap.Load(md); ok {
+		// If has is true, we've previously determined that this message
+		// needs init checks.
+		//
+		// If has is false, we've previously determined that it can never
+		// be uninitialized.
+		//
+		// If has is not a bool, we've just encountered a cycle in the
+		// message graph. In this case, it is safe to return false: If
+		// the message does have required fields, we'll detect them later
+		// in the graph traversal.
+		has, ok := v.(bool)
+		return ok && has
+	}
+	needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message
+	defer func() {
+		needsInitCheckMap.Store(md, has)
+	}()
+	if md.RequiredNumbers().Len() > 0 {
+		return true
+	}
+	if md.ExtensionRanges().Len() > 0 {
+		return true
+	}
+	for i := 0; i < md.Fields().Len(); i++ {
+		fd := md.Fields().Get(i)
+		// Map keys are never messages, so just consider the map value.
+		if fd.IsMap() {
+			fd = fd.MapValue()
+		}
+		fmd := fd.Message()
+		if fmd != nil && needsInitCheckLocked(fmd) {
+			return true
+		}
+	}
+	return false
+}
diff --git a/internal/impl/message.go b/internal/impl/message.go
index 161f194..7c4873a 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -47,6 +47,7 @@
 
 	methods piface.Methods
 
+	needsInitCheck        bool
 	sizecacheOffset       offset
 	extensionOffset       offset
 	unknownOffset         offset
@@ -101,6 +102,7 @@
 	}
 
 	si := mi.makeStructInfo(t.Elem())
+	mi.needsInitCheck = needsInitCheck(mi.PBType)
 	mi.makeKnownFieldsFunc(si)
 	mi.makeUnknownFieldsFunc(t.Elem())
 	mi.makeExtensionFieldsFunc(t.Elem())
@@ -139,6 +141,7 @@
 	mi.methods.Flags = piface.MethodFlagDeterministicMarshal
 	mi.methods.MarshalAppend = mi.marshalAppend
 	mi.methods.Size = mi.size
+	mi.methods.IsInitialized = mi.isInitialized
 }
 
 type structInfo struct {
diff --git a/internal/impl/message_field.go b/internal/impl/message_field.go
index 4f13256..8aacd25 100644
--- a/internal/impl/message_field.go
+++ b/internal/impl/message_field.go
@@ -27,12 +27,13 @@
 	newMessage func() pref.Message
 
 	// These fields are used for fast-path functions.
-	funcs     pointerCoderFuncs // fast-path per-field functions
-	num       pref.FieldNumber  // field number
-	offset    offset            // struct field offset
-	wiretag   uint64            // field tag (number + wire type)
-	tagsize   int               // size of the varint-encoded tag
-	isPointer bool              // true if IsNil may be called on the struct field
+	funcs      pointerCoderFuncs // fast-path per-field functions
+	num        pref.FieldNumber  // field number
+	offset     offset            // struct field offset
+	wiretag    uint64            // field tag (number + wire type)
+	tagsize    int               // size of the varint-encoded tag
+	isPointer  bool              // true if IsNil may be called on the struct field
+	isRequired bool              // true if field is required
 }
 
 func fieldInfoForOneof(fd pref.FieldDescriptor, fs reflect.StructField, ot reflect.Type) fieldInfo {
@@ -308,11 +309,12 @@
 				rv.Set(emptyBytes)
 			}
 		},
-		funcs:     funcs,
-		offset:    fieldOffset,
-		isPointer: nullable,
-		wiretag:   wiretag,
-		tagsize:   wire.SizeVarint(wiretag),
+		funcs:      funcs,
+		offset:     fieldOffset,
+		isPointer:  nullable,
+		isRequired: fd.Cardinality() == pref.Required,
+		wiretag:    wiretag,
+		tagsize:    wire.SizeVarint(wiretag),
 	}
 }
 
@@ -365,6 +367,7 @@
 		funcs:      fieldCoder(fd, ft),
 		offset:     fieldOffset,
 		isPointer:  true,
+		isRequired: fd.Cardinality() == pref.Required,
 		wiretag:    wiretag,
 		tagsize:    wire.SizeVarint(wiretag),
 	}