internal/detectknown: add helper package to identify well-known types

Change-Id: Id54621b4b44522a350e6994074962852690b5d66
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/225257
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/internal/detectknown/detect.go b/internal/detectknown/detect.go
new file mode 100644
index 0000000..091c423
--- /dev/null
+++ b/internal/detectknown/detect.go
@@ -0,0 +1,85 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package detectknown provides functionality for detecting well-known types
+// and identifying them by name.
+package detectknown
+
+import (
+	"strconv"
+
+	"google.golang.org/protobuf/reflect/protoreflect"
+)
+
+// ProtoFile identifies one of the proto files in the google.protobuf package
+// that declare the well-known types.
+type ProtoFile int
+
+// ProtoFile values. Each constant other than Unknown denotes the
+// google/protobuf/*.proto file it is named after (for example, FieldMaskProto
+// denotes google/protobuf/field_mask.proto). Unknown is the zero value.
+const (
+	Unknown ProtoFile = iota
+	AnyProto
+	TimestampProto
+	DurationProto
+	WrappersProto
+	StructProto
+	FieldMaskProto
+	EmptyProto
+)
+
+// String returns the identifier of the constant f, or "ProtoFile(N)" for a
+// value outside the declared set, so that %v output is human-readable.
+func (f ProtoFile) String() string {
+	switch f {
+	case Unknown:
+		return "Unknown"
+	case AnyProto:
+		return "AnyProto"
+	case TimestampProto:
+		return "TimestampProto"
+	case DurationProto:
+		return "DurationProto"
+	case WrappersProto:
+		return "WrappersProto"
+	case StructProto:
+		return "StructProto"
+	case FieldMaskProto:
+		return "FieldMaskProto"
+	case EmptyProto:
+		return "EmptyProto"
+	default:
+		return "ProtoFile(" + strconv.Itoa(int(f)) + ")"
+	}
+}
+
+// wellKnownTypes maps the full name of each message type declared in the
+// well-known proto files to the file that declares it.
+var wellKnownTypes = map[protoreflect.FullName]ProtoFile{
+	"google.protobuf.Any":         AnyProto,
+	"google.protobuf.Timestamp":   TimestampProto,
+	"google.protobuf.Duration":    DurationProto,
+	"google.protobuf.BoolValue":   WrappersProto,
+	"google.protobuf.Int32Value":  WrappersProto,
+	"google.protobuf.Int64Value":  WrappersProto,
+	"google.protobuf.UInt32Value": WrappersProto,
+	"google.protobuf.UInt64Value": WrappersProto,
+	"google.protobuf.FloatValue":  WrappersProto,
+	"google.protobuf.DoubleValue": WrappersProto,
+	"google.protobuf.BytesValue":  WrappersProto,
+	"google.protobuf.StringValue": WrappersProto,
+	"google.protobuf.Struct":      StructProto,
+	"google.protobuf.ListValue":   StructProto,
+	"google.protobuf.Value":       StructProto,
+	"google.protobuf.FieldMask":   FieldMaskProto,
+	"google.protobuf.Empty":       EmptyProto,
+}
+
+// Which identifies the proto file that a well-known type belongs to.
+// It returns Unknown if s is not the full name of a well-known message type;
+// the lookup is exact, so s must be fully qualified (e.g. "google.protobuf.Any").
+func Which(s protoreflect.FullName) ProtoFile {
+	return wellKnownTypes[s]
+}
diff --git a/internal/detectknown/detect_test.go b/internal/detectknown/detect_test.go
new file mode 100644
index 0000000..c9a31c9
--- /dev/null
+++ b/internal/detectknown/detect_test.go
@@ -0,0 +1,86 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package detectknown_test
+
+import (
+	"testing"
+
+	"google.golang.org/protobuf/internal/detectknown"
+	"google.golang.org/protobuf/reflect/protoreflect"
+
+	fieldmaskpb "google.golang.org/protobuf/internal/testprotos/fieldmaskpb"
+	"google.golang.org/protobuf/types/descriptorpb"
+	"google.golang.org/protobuf/types/known/anypb"
+	"google.golang.org/protobuf/types/known/durationpb"
+	"google.golang.org/protobuf/types/known/emptypb"
+	"google.golang.org/protobuf/types/known/structpb"
+	"google.golang.org/protobuf/types/known/timestamppb"
+	"google.golang.org/protobuf/types/known/wrapperspb"
+	"google.golang.org/protobuf/types/pluginpb"
+)
+
+// TestWhich checks every message declared in each file, including nested
+// messages but not map entries, against the result expected for that file.
+func TestWhich(t *testing.T) {
+	tests := []struct {
+		in   protoreflect.FileDescriptor
+		want detectknown.ProtoFile
+	}{
+		{descriptorpb.File_google_protobuf_descriptor_proto, detectknown.Unknown},
+		{pluginpb.File_google_protobuf_compiler_plugin_proto, detectknown.Unknown},
+		{anypb.File_google_protobuf_any_proto, detectknown.AnyProto},
+		{timestamppb.File_google_protobuf_timestamp_proto, detectknown.TimestampProto},
+		{durationpb.File_google_protobuf_duration_proto, detectknown.DurationProto},
+		{wrapperspb.File_google_protobuf_wrappers_proto, detectknown.WrappersProto},
+		{structpb.File_google_protobuf_struct_proto, detectknown.StructProto},
+		{fieldmaskpb.File_google_protobuf_field_mask_proto, detectknown.FieldMaskProto},
+		{emptypb.File_google_protobuf_empty_proto, detectknown.EmptyProto},
+	}
+
+	for _, tt := range tests {
+		var n int
+		rangeMessages(tt.in.Messages(), func(md protoreflect.MessageDescriptor) {
+			n++
+			got := detectknown.Which(md.FullName())
+			if got != tt.want {
+				t.Errorf("Which(%s) = %v, want %v", md.FullName(), got, tt.want)
+			}
+		})
+		// Every file listed above declares messages; visiting none would
+		// mean the loop checked nothing and passed vacuously.
+		if n == 0 {
+			t.Errorf("%s: no messages visited", tt.in.Path())
+		}
+	}
+}
+
+// TestWhichUnknown checks that names outside the well-known set, including
+// near misses of real well-known names, report Unknown.
+func TestWhichUnknown(t *testing.T) {
+	for _, name := range []protoreflect.FullName{
+		"",
+		"Any",
+		"google.protobuf",
+		"google.protobuf.any",
+		"google.protobuf.Any.Nested",
+		"foo.google.protobuf.Any",
+	} {
+		if got := detectknown.Which(name); got != detectknown.Unknown {
+			t.Errorf("Which(%q) = %v, want %v", name, got, detectknown.Unknown)
+		}
+	}
+}
+
+// rangeMessages calls f for every message in mds and, recursively, for the
+// messages nested within them, skipping map-entry messages.
+func rangeMessages(mds protoreflect.MessageDescriptors, f func(protoreflect.MessageDescriptor)) {
+	for i := 0; i < mds.Len(); i++ {
+		md := mds.Get(i)
+		if !md.IsMapEntry() {
+			f(md)
+		}
+		rangeMessages(md.Messages(), f)
+	}
+}
diff --git a/internal/msgfmt/format.go b/internal/msgfmt/format.go
index 21023e5..c2c856f 100644
--- a/internal/msgfmt/format.go
+++ b/internal/msgfmt/format.go
@@ -18,6 +18,7 @@
 	"time"
 
 	"google.golang.org/protobuf/encoding/protowire"
+	"google.golang.org/protobuf/internal/detectknown"
 	"google.golang.org/protobuf/internal/detrand"
 	"google.golang.org/protobuf/internal/mapsort"
 	"google.golang.org/protobuf/proto"
@@ -102,13 +103,9 @@
 
 func appendKnownMessage(b []byte, m protoreflect.Message) []byte {
 	md := m.Descriptor()
-	if md.FullName().Parent() != "google.protobuf" {
-		return nil
-	}
-
 	fds := md.Fields()
-	switch md.Name() {
-	case "Any":
+	switch detectknown.Which(md.FullName()) {
+	case detectknown.AnyProto:
 		var msgVal protoreflect.Message
 		url := m.Get(fds.ByName("type_url")).String()
 		if v := reflect.ValueOf(m); v.Type().ConvertibleTo(protocmpMessageType) {
@@ -140,7 +137,7 @@
 		b = append(b, '}')
 		return b
 
-	case "Timestamp":
+	case detectknown.TimestampProto:
 		secs := m.Get(fds.ByName("seconds")).Int()
 		nanos := m.Get(fds.ByName("nanos")).Int()
 		if nanos < 0 || nanos >= 1e9 {
@@ -153,7 +150,7 @@
 		x = strings.TrimSuffix(x, ".000")
 		return append(b, x+"Z"...)
 
-	case "Duration":
+	case detectknown.DurationProto:
 		secs := m.Get(fds.ByName("seconds")).Int()
 		nanos := m.Get(fds.ByName("nanos")).Int()
 		if nanos <= -1e9 || nanos >= 1e9 || (secs > 0 && nanos < 0) || (secs < 0 && nanos > 0) {
@@ -165,7 +162,7 @@
 		x = strings.TrimSuffix(x, ".000")
 		return append(b, x+"s"...)
 
-	case "BoolValue", "Int32Value", "Int64Value", "UInt32Value", "UInt64Value", "FloatValue", "DoubleValue", "StringValue", "BytesValue":
+	case detectknown.WrappersProto:
 		fd := fds.ByName("value")
 		return appendValue(b, m.Get(fd), fd)
 	}