internal/impl: add fast-path unmarshal

Benchmarks run with:
  go test ./benchmarks/ -bench=Wire -benchtime=500ms -benchmem -count=8

Fast-path vs. parent commit:

  name                                      old time/op    new time/op    delta
  Wire/Unmarshal/google_message1_proto2-12    1.35µs ± 2%    0.45µs ± 4%  -67.01%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message1_proto3-12    1.07µs ± 1%    0.31µs ± 1%  -71.04%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message2-12            691µs ± 2%     188µs ± 2%  -72.78%  (p=0.000 n=7+8)

  name                                      old allocs/op  new allocs/op  delta
  Wire/Unmarshal/google_message1_proto2-12      60.0 ± 0%      25.0 ± 0%  -58.33%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message1_proto3-12      42.0 ± 0%       7.0 ± 0%  -83.33%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message2-12            28.6k ± 0%      8.5k ± 0%  -70.34%  (p=0.000 n=8+8)

Fast-path vs. the v1 implementation (benchmarks run with -v1):

  name                                      old time/op    new time/op    delta
  Wire/Unmarshal/google_message1_proto2-12     702ns ± 1%     445ns ± 4%   -36.58%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message1_proto3-12     604ns ± 1%     311ns ± 1%   -48.54%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message2-12            179µs ± 3%     188µs ± 2%    +5.30%  (p=0.000 n=7+8)

  name                                      old allocs/op  new allocs/op  delta
  Wire/Unmarshal/google_message1_proto2-12      26.0 ± 0%      25.0 ± 0%    -3.85%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message1_proto3-12      8.00 ± 0%      7.00 ± 0%   -12.50%  (p=0.000 n=8+8)
  Wire/Unmarshal/google_message2-12            8.49k ± 0%     8.49k ± 0%    -0.01%  (p=0.000 n=8+8)

Change-Id: I6247ac3fd66a63d9acb902cbd192094ee3d151c3
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185147
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_map.go b/internal/impl/codec_map.go
index 5a02c34..53a860b 100644
--- a/internal/impl/codec_map.go
+++ b/internal/impl/codec_map.go
@@ -16,6 +16,17 @@
 
 var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
 
+type mapInfo struct {
+	goType     reflect.Type
+	keyWiretag uint64
+	valWiretag uint64
+	keyFuncs   ifaceCoderFuncs
+	valFuncs   ifaceCoderFuncs
+	keyZero    interface{}
+	valZero    interface{}
+	newVal     func() interface{}
+}
+
 func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
 	// TODO: Consider generating specialized map coders.
 	keyField := fd.MapKey()
@@ -25,6 +36,22 @@
 	keyFuncs := encoderFuncsForValue(keyField, ft.Key())
 	valFuncs := encoderFuncsForValue(valField, ft.Elem())
 
+	mapi := &mapInfo{
+		goType:     ft,
+		keyWiretag: keyWiretag,
+		valWiretag: valWiretag,
+		keyFuncs:   keyFuncs,
+		valFuncs:   valFuncs,
+		keyZero:    reflect.Zero(ft.Key()).Interface(),
+		valZero:    reflect.Zero(ft.Elem()).Interface(),
+	}
+	switch valField.Kind() {
+	case pref.GroupKind, pref.MessageKind:
+		mapi.newVal = func() interface{} {
+			return reflect.New(ft.Elem().Elem()).Interface()
+		}
+	}
+
 	funcs = pointerCoderFuncs{
 		size: func(p pointer, tagsize int, opts marshalOptions) int {
 			return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
@@ -32,6 +59,9 @@
 		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
 			return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
 		},
+		unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+			return consumeMap(b, p, wtyp, mapi, opts)
+		},
 	}
 	if valFuncs.isInit != nil {
 		funcs.isInit = func(p pointer) error {
@@ -46,6 +76,64 @@
 	mapValTagSize = 1 // field 2, tag size 2.
 )
 
+func consumeMap(b []byte, p pointer, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
+	mp := p.AsValueOf(mapi.goType)
+	if mp.Elem().IsNil() {
+		mp.Elem().Set(reflect.MakeMap(mapi.goType))
+	}
+	m := mp.Elem()
+
+	if wtyp != wire.BytesType {
+		return 0, errUnknown
+	}
+	b, n := wire.ConsumeBytes(b)
+	if n < 0 {
+		return 0, wire.ParseError(n)
+	}
+	var (
+		key = mapi.keyZero
+		val = mapi.valZero
+	)
+	if mapi.newVal != nil {
+		val = mapi.newVal()
+	}
+	for len(b) > 0 {
+		num, wtyp, n := wire.ConsumeTag(b)
+		if n < 0 {
+			return 0, wire.ParseError(n)
+		}
+		b = b[n:]
+		err := errUnknown
+		switch num {
+		case 1:
+			var v interface{}
+			v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
+			if err != nil {
+				break
+			}
+			key = v
+		case 2:
+			var v interface{}
+			v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
+			if err != nil {
+				break
+			}
+			val = v
+		}
+		if err == errUnknown {
+			n = wire.ConsumeFieldValue(num, wtyp, b)
+			if n < 0 {
+				return 0, wire.ParseError(n)
+			}
+		} else if err != nil {
+			return 0, err
+		}
+		b = b[n:]
+	}
+	m.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(val))
+	return n, nil
+}
+
 func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
 	m := p.AsValueOf(goType).Elem()
 	n := 0