internal/impl: validate messagesets
Change-Id: Id90bb386e7481bb9dee5a07889f308f1e1810825
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/218438
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index 06acc78..bb00cd0 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -11,6 +11,7 @@
"reflect"
"unicode/utf8"
+ "google.golang.org/protobuf/internal/encoding/messageset"
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/internal/flags"
"google.golang.org/protobuf/internal/strs"
@@ -93,6 +94,7 @@
validationTypeFixed64
validationTypeBytes
validationTypeUTF8String
+ validationTypeMessageSetItem
)
func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
@@ -237,11 +239,6 @@
State:
for len(states) > 0 {
st := &states[len(states)-1]
- if st.mi != nil {
- if flags.ProtoLegacy && st.mi.isMessageSet {
- return out, ValidationUnknown
- }
- }
for len(b) > 0 {
// Parse the tag (field number and wire type).
var tag uint64
@@ -274,8 +271,8 @@
return out, ValidationInvalid
}
var vi validationInfo
- switch st.typ {
- case validationTypeMap:
+ switch {
+ case st.typ == validationTypeMap:
switch num {
case 1:
vi.typ = st.keyType
@@ -284,6 +281,11 @@
vi.mi = st.mi
vi.requiredBit = 1
}
+ case flags.ProtoLegacy && st.mi.isMessageSet:
+ switch num {
+ case messageset.FieldItem:
+ vi.typ = validationTypeMessageSetItem
+ }
default:
var f *coderFieldInfo
if int(num) < len(st.mi.denseCoderFields) {
@@ -483,8 +485,8 @@
}
b = b[8:]
case wire.StartGroupType:
- switch vi.typ {
- case validationTypeGroup:
+ switch {
+ case vi.typ == validationTypeGroup:
if vi.mi == nil {
return out, ValidationUnknown
}
@@ -495,6 +497,27 @@
endGroup: num,
})
continue State
+ case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
+ typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
+ if err != nil {
+ return out, ValidationInvalid
+ }
+ xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
+ switch {
+ case err == preg.NotFound:
+ b = b[n:]
+ case err != nil:
+ return out, ValidationUnknown
+ default:
+ xvi := getExtensionFieldInfo(xt).validation
+ states = append(states, validationState{
+ typ: xvi.typ,
+ mi: xvi.mi,
+ tail: b[n:],
+ })
+ b = v
+ continue State
+ }
default:
n := wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {