internal/impl: validate messagesets

Change-Id: Id90bb386e7481bb9dee5a07889f308f1e1810825
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/218438
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index 06acc78..bb00cd0 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -11,6 +11,7 @@
 	"reflect"
 	"unicode/utf8"
 
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/internal/strs"
@@ -93,6 +94,7 @@
 	validationTypeFixed64
 	validationTypeBytes
 	validationTypeUTF8String
+	validationTypeMessageSetItem
 )
 
 func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
@@ -237,11 +239,6 @@
 State:
 	for len(states) > 0 {
 		st := &states[len(states)-1]
-		if st.mi != nil {
-			if flags.ProtoLegacy && st.mi.isMessageSet {
-				return out, ValidationUnknown
-			}
-		}
 		for len(b) > 0 {
 			// Parse the tag (field number and wire type).
 			var tag uint64
@@ -274,8 +271,8 @@
 				return out, ValidationInvalid
 			}
 			var vi validationInfo
-			switch st.typ {
-			case validationTypeMap:
+			switch {
+			case st.typ == validationTypeMap:
 				switch num {
 				case 1:
 					vi.typ = st.keyType
@@ -284,6 +281,11 @@
 					vi.mi = st.mi
 					vi.requiredBit = 1
 				}
+			case flags.ProtoLegacy && st.mi.isMessageSet:
+				switch num {
+				case messageset.FieldItem:
+					vi.typ = validationTypeMessageSetItem
+				}
 			default:
 				var f *coderFieldInfo
 				if int(num) < len(st.mi.denseCoderFields) {
@@ -483,8 +485,8 @@
 				}
 				b = b[8:]
 			case wire.StartGroupType:
-				switch vi.typ {
-				case validationTypeGroup:
+				switch {
+				case vi.typ == validationTypeGroup:
 					if vi.mi == nil {
 						return out, ValidationUnknown
 					}
@@ -495,6 +497,27 @@
 						endGroup: num,
 					})
 					continue State
+				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
+					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
+					if err != nil {
+						return out, ValidationInvalid
+					}
+					xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
+					switch {
+					case err == preg.NotFound:
+						b = b[n:]
+					case err != nil:
+						return out, ValidationUnknown
+					default:
+						xvi := getExtensionFieldInfo(xt).validation
+						states = append(states, validationState{
+							typ:  xvi.typ,
+							mi:   xvi.mi,
+							tail: b[n:],
+						})
+						b = v
+						continue State
+					}
 				default:
 					n := wire.ConsumeFieldValue(num, wtyp, b)
 					if n < 0 {