internal/impl: add fast-path for IsInitialized

This change adds a fast path for IsInitialized. The fast path currently
returns uninformative errors, so when it detects an error we consult the
slow, reflection-based path to produce a descriptive one. It may be worth
the effort to produce better errors directly on the fast path.

Change-Id: I68536e9438010dbd97dbaff4f47b78430221d94b
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171462
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/encode_field.go b/internal/impl/encode_field.go
index b8afb99..15124c0 100644
--- a/internal/impl/encode_field.go
+++ b/internal/impl/encode_field.go
@@ -38,32 +38,40 @@
 		}
 	}
 	ft := fs.Type
+	getInfo := func(p pointer) (pointer, oneofFieldInfo) {
+		v := p.AsValueOf(ft).Elem()
+		if v.IsNil() {
+			return pointer{}, oneofFieldInfo{}
+		}
+		v = v.Elem() // interface -> *struct
+		telem := v.Elem().Type()
+		info, ok := oneofFieldInfos[telem]
+		if !ok {
+			panic(fmt.Errorf("invalid oneof type %v", telem))
+		}
+		return pointerOfValue(v).Apply(zeroOffset), info
+	}
 	return pointerCoderFuncs{
 		size: func(p pointer, _ int, opts marshalOptions) int {
-			v := p.AsValueOf(ft).Elem()
-			if v.IsNil() {
+			v, info := getInfo(p)
+			if info.funcs.size == nil {
 				return 0
 			}
-			v = v.Elem() // interface -> *struct
-			telem := v.Elem().Type()
-			info, ok := oneofFieldInfos[telem]
-			if !ok {
-				panic(fmt.Errorf("invalid oneof type %v", telem))
-			}
-			return info.funcs.size(pointerOfValue(v).Apply(zeroOffset), info.tagsize, opts)
+			return info.funcs.size(v, info.tagsize, opts)
 		},
 		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
-			v := p.AsValueOf(ft).Elem()
-			if v.IsNil() {
+			v, info := getInfo(p)
+			if info.funcs.marshal == nil {
 				return b, nil
 			}
-			v = v.Elem() // interface -> *struct
-			telem := v.Elem().Type()
-			info, ok := oneofFieldInfos[telem]
-			if !ok {
-				panic(fmt.Errorf("invalid oneof type %v", telem))
+			return info.funcs.marshal(b, v, info.wiretag, opts)
+		},
+		isInit: func(p pointer) error {
+			v, info := getInfo(p)
+			if info.funcs.isInit == nil {
+				return nil
 			}
-			return info.funcs.marshal(b, pointerOfValue(v).Apply(zeroOffset), info.wiretag, opts)
+			return info.funcs.isInit(v)
 		},
 	}
 }
@@ -77,6 +85,9 @@
 			marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 				return appendMessageInfo(b, p, wiretag, fi, opts)
 			},
+			isInit: func(p pointer) error {
+				return fi.isInitializedPointer(p.Elem())
+			},
 		}
 	} else {
 		return pointerCoderFuncs{
@@ -88,6 +99,10 @@
 				m := asMessage(p.AsValueOf(ft).Elem())
 				return appendMessage(b, m, wiretag, opts)
 			},
+			isInit: func(p pointer) error {
+				m := asMessage(p.AsValueOf(ft).Elem())
+				return proto.IsInitialized(m)
+			},
 		}
 	}
 }
@@ -122,9 +137,15 @@
 	return appendMessage(b, m, wiretag, opts)
 }
 
+// isInitMessageIface checks the message wrapped by ival with proto.IsInitialized.
+func isInitMessageIface(ival interface{}) error {
+	return proto.IsInitialized(Export{}.MessageOf(ival).Interface())
+}
+
 var coderMessageIface = ifaceCoderFuncs{
 	size:    sizeMessageIface,
 	marshal: appendMessageIface,
+	isInit:  isInitMessageIface, // checks the held message via proto.IsInitialized
 }
 
 func makeGroupFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@@ -136,6 +157,9 @@
 			marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 				return appendGroupType(b, p, wiretag, fi, opts)
 			},
+			isInit: func(p pointer) error {
+				return fi.isInitializedPointer(p.Elem())
+			},
 		}
 	} else {
 		return pointerCoderFuncs{
@@ -147,6 +171,10 @@
 				m := asMessage(p.AsValueOf(ft).Elem())
 				return appendGroup(b, m, wiretag, opts)
 			},
+			isInit: func(p pointer) error {
+				m := asMessage(p.AsValueOf(ft).Elem())
+				return proto.IsInitialized(m)
+			},
 		}
 	}
 }
@@ -186,6 +214,7 @@
 var coderGroupIface = ifaceCoderFuncs{
 	size:    sizeGroupIface,
 	marshal: appendGroupIface,
+	isInit:  isInitMessageIface, // groups reuse the message check; only size/marshal are group-specific
 }
 
 func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@@ -197,6 +226,9 @@
 			size: func(p pointer, tagsize int, opts marshalOptions) int {
 				return sizeMessageSliceInfo(p, fi, tagsize, opts)
 			},
+			isInit: func(p pointer) error {
+				return isInitMessageSliceInfo(p, fi)
+			},
 		}
 	}
 	return pointerCoderFuncs{
@@ -206,6 +238,9 @@
 		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 			return appendMessageSlice(b, p, wiretag, ft, opts)
 		},
+		isInit: func(p pointer) error {
+			return isInitMessageSlice(p, ft)
+		},
 	}
 }
 
@@ -233,6 +268,16 @@
 	return b, nil
 }
 
+// isInitMessageSliceInfo checks each element of a message slice with mi.
+func isInitMessageSliceInfo(p pointer, mi *MessageInfo) error {
+	for _, elem := range p.PointerSlice() {
+		if err := mi.isInitializedPointer(elem); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func sizeMessageSlice(p pointer, goType reflect.Type, tagsize int, _ marshalOptions) int {
 	s := p.PointerSlice()
 	n := 0
@@ -259,6 +304,17 @@
 	return b, nil
 }
 
+// isInitMessageSlice runs proto.IsInitialized on each element of a message slice.
+func isInitMessageSlice(p pointer, goType reflect.Type) error {
+	for _, elem := range p.PointerSlice() {
+		msg := Export{}.MessageOf(elem.AsValueOf(goType.Elem()).Interface()).Interface()
+		if err := proto.IsInitialized(msg); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // Slices of messages
 
 func sizeMessageSliceIface(ival interface{}, tagsize int, opts marshalOptions) int {
@@ -271,9 +327,15 @@
 	return appendMessageSlice(b, p, wiretag, reflect.TypeOf(ival).Elem().Elem(), opts)
 }
 
+// isInitMessageSliceIface applies isInitMessageSlice to the slice ival points to.
+func isInitMessageSliceIface(ival interface{}) error {
+	return isInitMessageSlice(pointerOfIface(ival), reflect.TypeOf(ival).Elem().Elem())
+}
+
 var coderMessageSliceIface = ifaceCoderFuncs{
 	size:    sizeMessageSliceIface,
 	marshal: appendMessageSliceIface,
+	isInit:  isInitMessageSliceIface, // returns the first element's IsInitialized error, if any
 }
 
 func makeGroupSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@@ -285,6 +347,9 @@
 			marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 				return appendGroupSliceInfo(b, p, wiretag, fi, opts)
 			},
+			isInit: func(p pointer) error {
+				return isInitMessageSliceInfo(p, fi)
+			},
 		}
 	}
 	return pointerCoderFuncs{
@@ -294,6 +359,9 @@
 		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 			return appendGroupSlice(b, p, wiretag, ft, opts)
 		},
+		isInit: func(p pointer) error {
+			return isInitMessageSlice(p, ft)
+		},
 	}
 }
 
@@ -358,6 +426,7 @@
 var coderGroupSliceIface = ifaceCoderFuncs{
 	size:    sizeGroupSliceIface,
 	marshal: appendGroupSliceIface,
+	isInit:  isInitMessageSliceIface, // group slices reuse the message-slice check
 }
 
 // Enums