internal/impl: change unmarshal func return to unmarshalOutput

The fast-path unmarshal funcs return the number of bytes consumed.

Change these functions to return an unmarshalOutput struct instead, with
the byte count carried in its n field, so that further results can be
added later without changing every signature again. This is groundwork
for allowing the fast-path unmarshaler to indicate when the unmarshaled
message is known to be initialized.

Change-Id: Ia8c44731a88f5be969a55cd98ea26282f412c7ae
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215720
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/codec_map.go b/internal/impl/codec_map.go
index 05d1ecd..b69ee1a 100644
--- a/internal/impl/codec_map.go
+++ b/internal/impl/codec_map.go
@@ -56,7 +56,7 @@
 		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 			return appendMap(b, p.AsValueOf(ft).Elem(), wiretag, mapi, opts)
 		},
-		unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+		unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
 			mp := p.AsValueOf(ft)
 			if mp.Elem().IsNil() {
 				mp.Elem().Set(reflect.MakeMap(mapi.goType))
@@ -104,13 +104,13 @@
 	return n
 }
 
-func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
+func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 	if wtyp != wire.BytesType {
-		return 0, errUnknown
+		return out, errUnknown
 	}
 	b, n := wire.ConsumeBytes(b)
 	if n < 0 {
-		return 0, wire.ParseError(n)
+		return out, wire.ParseError(n)
 	}
 	var (
 		key = mapi.keyZero
@@ -119,50 +119,55 @@
 	for len(b) > 0 {
 		num, wtyp, n := wire.ConsumeTag(b)
 		if n < 0 {
-			return 0, wire.ParseError(n)
+			return out, wire.ParseError(n)
 		}
 		if num > wire.MaxValidNumber {
-			return 0, errors.New("invalid field number")
+			return out, errors.New("invalid field number")
 		}
 		b = b[n:]
 		err := errUnknown
 		switch num {
 		case 1:
 			var v pref.Value
-			v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
+			var o unmarshalOutput
+			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
 			if err != nil {
 				break
 			}
 			key = v
+			n = o.n
 		case 2:
 			var v pref.Value
-			v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
+			var o unmarshalOutput
+			v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
 			if err != nil {
 				break
 			}
 			val = v
+			n = o.n
 		}
 		if err == errUnknown {
 			n = wire.ConsumeFieldValue(num, wtyp, b)
 			if n < 0 {
-				return 0, wire.ParseError(n)
+				return out, wire.ParseError(n)
 			}
 		} else if err != nil {
-			return 0, err
+			return out, err
 		}
 		b = b[n:]
 	}
 	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
-	return n, nil
+	out.n = n
+	return out, nil
 }
 
-func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
+func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 	if wtyp != wire.BytesType {
-		return 0, errUnknown
+		return out, errUnknown
 	}
 	b, n := wire.ConsumeBytes(b)
 	if n < 0 {
-		return 0, wire.ParseError(n)
+		return out, wire.ParseError(n)
 	}
 	var (
 		key = mapi.keyZero
@@ -171,21 +176,23 @@
 	for len(b) > 0 {
 		num, wtyp, n := wire.ConsumeTag(b)
 		if n < 0 {
-			return 0, wire.ParseError(n)
+			return out, wire.ParseError(n)
 		}
 		if num > wire.MaxValidNumber {
-			return 0, errors.New("invalid field number")
+			return out, errors.New("invalid field number")
 		}
 		b = b[n:]
 		err := errUnknown
 		switch num {
 		case 1:
 			var v pref.Value
-			v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
+			var o unmarshalOutput
+			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
 			if err != nil {
 				break
 			}
 			key = v
+			n = o.n
 		case 2:
 			if wtyp != wire.BytesType {
 				break
@@ -193,22 +200,23 @@
 			var v []byte
 			v, n = wire.ConsumeBytes(b)
 			if n < 0 {
-				return 0, wire.ParseError(n)
+				return out, wire.ParseError(n)
 			}
 			_, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
 		}
 		if err == errUnknown {
 			n = wire.ConsumeFieldValue(num, wtyp, b)
 			if n < 0 {
-				return 0, wire.ParseError(n)
+				return out, wire.ParseError(n)
 			}
 		} else if err != nil {
-			return 0, err
+			return out, err
 		}
 		b = b[n:]
 	}
 	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
-	return n, nil
+	out.n = n
+	return out, nil
 }
 
 func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, opts marshalOptions) ([]byte, error) {