all: improve extension validation
Changes made:
* Ensure protoreflect.ExtensionType.IsValidInterface never panics,
in particular when given a nil interface value.
* Have protoreflect.ExtensionType.IsValid{Interface,Value} only
perform type checks. They do not check whether the value itself is
valid; value validity is deferred until an actual
protoreflect.Message.Set operation is performed.
* Add special-casing to proto.SetExtension to treat an invalid
message or list as functionally equivalent to Clear. This is to
be more consistent with the legacy SetExtension implementation,
which never panicked when given such values.
* Add special-casing to proto.HasExtension to treat a mismatched
extension descriptor as simply not being present in the message.
This is also to be more consistent with the legacy HasExtension
implementation, which did the same thing.
Change-Id: Idf0419abf27b9f85d9b92bd2ff8088e25b7990cc
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/229558
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/convert.go b/internal/impl/convert.go
index 9fc384a..36a90df 100644
--- a/internal/impl/convert.go
+++ b/internal/impl/convert.go
@@ -162,7 +162,7 @@
return ok
}
func (c *boolConverter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *boolConverter) New() pref.Value { return c.def }
func (c *boolConverter) Zero() pref.Value { return c.def }
@@ -186,7 +186,7 @@
return ok
}
func (c *int32Converter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *int32Converter) New() pref.Value { return c.def }
func (c *int32Converter) Zero() pref.Value { return c.def }
@@ -210,7 +210,7 @@
return ok
}
func (c *int64Converter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *int64Converter) New() pref.Value { return c.def }
func (c *int64Converter) Zero() pref.Value { return c.def }
@@ -234,7 +234,7 @@
return ok
}
func (c *uint32Converter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *uint32Converter) New() pref.Value { return c.def }
func (c *uint32Converter) Zero() pref.Value { return c.def }
@@ -258,7 +258,7 @@
return ok
}
func (c *uint64Converter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *uint64Converter) New() pref.Value { return c.def }
func (c *uint64Converter) Zero() pref.Value { return c.def }
@@ -282,7 +282,7 @@
return ok
}
func (c *float32Converter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *float32Converter) New() pref.Value { return c.def }
func (c *float32Converter) Zero() pref.Value { return c.def }
@@ -306,7 +306,7 @@
return ok
}
func (c *float64Converter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *float64Converter) New() pref.Value { return c.def }
func (c *float64Converter) Zero() pref.Value { return c.def }
@@ -336,7 +336,7 @@
return ok
}
func (c *stringConverter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *stringConverter) New() pref.Value { return c.def }
func (c *stringConverter) Zero() pref.Value { return c.def }
@@ -363,7 +363,7 @@
return ok
}
func (c *bytesConverter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *bytesConverter) New() pref.Value { return c.def }
func (c *bytesConverter) Zero() pref.Value { return c.def }
@@ -400,7 +400,7 @@
}
func (c *enumConverter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *enumConverter) New() pref.Value {
@@ -455,7 +455,7 @@
}
func (c *messageConverter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *messageConverter) New() pref.Value {
diff --git a/internal/impl/convert_list.go b/internal/impl/convert_list.go
index fe9384a..6fccab5 100644
--- a/internal/impl/convert_list.go
+++ b/internal/impl/convert_list.go
@@ -22,7 +22,7 @@
}
type listConverter struct {
- goType reflect.Type
+ goType reflect.Type // []T
c Converter
}
@@ -48,11 +48,11 @@
if !ok {
return false
}
- return list.v.Type().Elem() == c.goType && list.IsValid()
+ return list.v.Type().Elem() == c.goType
}
func (c *listConverter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *listConverter) New() pref.Value {
@@ -64,7 +64,7 @@
}
type listPtrConverter struct {
- goType reflect.Type
+ goType reflect.Type // *[]T
c Converter
}
@@ -88,7 +88,7 @@
}
func (c *listPtrConverter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *listPtrConverter) New() pref.Value {
diff --git a/internal/impl/convert_map.go b/internal/impl/convert_map.go
index 3ef36d3..de06b25 100644
--- a/internal/impl/convert_map.go
+++ b/internal/impl/convert_map.go
@@ -12,7 +12,7 @@
)
type mapConverter struct {
- goType reflect.Type
+ goType reflect.Type // map[K]V
keyConv, valConv Converter
}
@@ -43,11 +43,11 @@
if !ok {
return false
}
- return mapv.v.Type() == c.goType && mapv.IsValid()
+ return mapv.v.Type() == c.goType
}
func (c *mapConverter) IsValidGo(v reflect.Value) bool {
- return v.Type() == c.goType
+ return v.IsValid() && v.Type() == c.goType
}
func (c *mapConverter) New() pref.Value {
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
index aac55ee..28114ff 100644
--- a/internal/impl/message_reflect.go
+++ b/internal/impl/message_reflect.go
@@ -170,6 +170,8 @@
return x.Value().List().Len() > 0
case xd.IsMap():
return x.Value().Map().Len() > 0
+ case xd.Message() != nil:
+ return x.Value().Message().IsValid()
}
return true
}
@@ -186,15 +188,33 @@
return xt.Zero()
}
+// Set stores v as the value of the extension field xt, replacing any
+// previously stored value. It panics if v does not have the Go type
+// expected by xt, or if v is an invalid list, map, or message value.
func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
- if !xt.IsValidValue(v) {
+ xd := xt.TypeDescriptor()
+ // IsValidValue only type-checks v; composite values must additionally
+ // report themselves as valid before they may be stored.
+ isValid := true
+ switch {
+ case !xt.IsValidValue(v):
+ isValid = false
+ case xd.IsList():
+ isValid = v.List().IsValid()
+ case xd.IsMap():
+ isValid = v.Map().IsValid()
+ case xd.Message() != nil:
+ isValid = v.Message().IsValid()
+ }
+ if !isValid {
- panic(fmt.Sprintf("%v: assigning invalid value", xt.TypeDescriptor().FullName()))
+ panic(fmt.Sprintf("%v: assigning invalid value", xd.FullName()))
}
+
if *m == nil {
*m = make(map[int32]ExtensionField)
}
var x ExtensionField
x.Set(xt, v)
- (*m)[int32(xt.TypeDescriptor().Number())] = x
+ (*m)[int32(xd.Number())] = x
}
func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
xd := xt.TypeDescriptor()