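internal/impl: inline most field decoding in the validator

Switch on the wire type in the validation loop and decode varints,
length prefixes, and fixed32/fixed64 values inline instead of calling
wire.ConsumeVarint, wire.ConsumeBytes, and wire.ConsumeFieldValue for
every field. One- and two-byte length prefixes take a fast path; only
start-group fields that are not validated as groups still go through
wire.ConsumeFieldValue.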
name                            old time/op  new time/op  delta
EmptyMessage/Wire/Validate-12   4.51ns ± 1%  4.57ns ± 0%   +1.19%  (p=0.045 n=8+8)
RepeatedInt32/Wire/Validate-12   910ns ± 0%   726ns ± 3%  -20.13%  (p=0.000 n=8+8)
Required/Wire/Validate-12       34.5ns ± 0%  29.6ns ± 5%  -13.99%  (p=0.000 n=7+8)
Change-Id: I8ac90ed3fc79dfef7f2500f13b33fd2593fc0fc1
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/216625
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/validate.go b/internal/impl/validate.go
index 40c6a7a..bb6d47d 100644
--- a/internal/impl/validate.go
+++ b/internal/impl/validate.go
@@ -249,7 +249,6 @@
return ValidationUnknown
}
}
- Field:
for len(b) > 0 {
// Parse the tag (field number and wire type).
var tag uint64
@@ -358,100 +357,152 @@
st.requiredMask |= vi.requiredBit
}
}
- switch vi.typ {
- case validationTypeMessage, validationTypeMap:
- if wtyp != wire.BytesType {
- break
+
+ switch wtyp {
+ case wire.VarintType:
+ if len(b) >= 10 {
+ switch {
+ case b[0] < 0x80:
+ b = b[1:]
+ case b[1] < 0x80:
+ b = b[2:]
+ case b[2] < 0x80:
+ b = b[3:]
+ case b[3] < 0x80:
+ b = b[4:]
+ case b[4] < 0x80:
+ b = b[5:]
+ case b[5] < 0x80:
+ b = b[6:]
+ case b[6] < 0x80:
+ b = b[7:]
+ case b[7] < 0x80:
+ b = b[8:]
+ case b[8] < 0x80:
+ b = b[9:]
+ case b[9] < 0x80:
+ b = b[10:]
+ default:
+ return ValidationInvalid
+ }
+ } else {
+ switch {
+ case len(b) > 0 && b[0] < 0x80:
+ b = b[1:]
+ case len(b) > 1 && b[1] < 0x80:
+ b = b[2:]
+ case len(b) > 2 && b[2] < 0x80:
+ b = b[3:]
+ case len(b) > 3 && b[3] < 0x80:
+ b = b[4:]
+ case len(b) > 4 && b[4] < 0x80:
+ b = b[5:]
+ case len(b) > 5 && b[5] < 0x80:
+ b = b[6:]
+ case len(b) > 6 && b[6] < 0x80:
+ b = b[7:]
+ case len(b) > 7 && b[7] < 0x80:
+ b = b[8:]
+ case len(b) > 8 && b[8] < 0x80:
+ b = b[9:]
+ case len(b) > 9 && b[9] < 0x80:
+ b = b[10:]
+ default:
+ return ValidationInvalid
+ }
}
- if vi.mi == nil && vi.typ == validationTypeMessage {
- return ValidationUnknown
- }
- size, n := wire.ConsumeVarint(b)
- if n < 0 {
- return ValidationInvalid
- }
- b = b[n:]
- if uint64(len(b)) < size {
- return ValidationInvalid
- }
- states = append(states, validationState{
- typ: vi.typ,
- keyType: vi.keyType,
- valType: vi.valType,
- mi: vi.mi,
- tail: b[size:],
- })
- b = b[:size]
continue State
- case validationTypeGroup:
- if wtyp != wire.StartGroupType {
- break
- }
- if vi.mi == nil {
- return ValidationUnknown
- }
- states = append(states, validationState{
- typ: validationTypeGroup,
- mi: vi.mi,
- endGroup: num,
- })
- continue State
- case validationTypeRepeatedVarint:
- if wtyp != wire.BytesType {
- break
- }
- // Packed field.
- v, n := wire.ConsumeBytes(b)
- if n < 0 {
- return ValidationInvalid
- }
- b = b[n:]
- for len(v) > 0 {
- _, n := wire.ConsumeVarint(v)
+ case wire.BytesType:
+ var size uint64
+ if b[0] < 0x80 {
+ size = uint64(b[0])
+ b = b[1:]
+ } else if len(b) >= 2 && b[1] < 128 {
+ size = uint64(b[0]&0x7f) + uint64(b[1])<<7
+ b = b[2:]
+ } else {
+ var n int
+ size, n = wire.ConsumeVarint(b)
if n < 0 {
return ValidationInvalid
}
- v = v[n:]
+ b = b[n:]
}
- continue Field
- case validationTypeRepeatedFixed32:
- if wtyp != wire.BytesType {
- break
- }
- // Packed field.
- v, n := wire.ConsumeBytes(b)
- if n < 0 || len(v)%4 != 0 {
+ if size > uint64(len(b)) {
return ValidationInvalid
}
- b = b[n:]
- continue Field
- case validationTypeRepeatedFixed64:
- if wtyp != wire.BytesType {
- break
+ v := b[:size]
+ b = b[size:]
+ switch vi.typ {
+ case validationTypeMessage, validationTypeMap:
+ if vi.mi == nil && vi.typ == validationTypeMessage {
+ return ValidationUnknown
+ }
+ states = append(states, validationState{
+ typ: vi.typ,
+ keyType: vi.keyType,
+ valType: vi.valType,
+ mi: vi.mi,
+ tail: b,
+ })
+ b = v
+ continue State
+ case validationTypeRepeatedVarint:
+ // Packed field.
+ for len(v) > 0 {
+ _, n := wire.ConsumeVarint(v)
+ if n < 0 {
+ return ValidationInvalid
+ }
+ v = v[n:]
+ }
+ case validationTypeRepeatedFixed32:
+ // Packed field.
+ if len(v)%4 != 0 {
+ return ValidationInvalid
+ }
+ case validationTypeRepeatedFixed64:
+ // Packed field.
+ if len(v)%8 != 0 {
+ return ValidationInvalid
+ }
+ case validationTypeUTF8String:
+ if !utf8.Valid(v) {
+ return ValidationInvalid
+ }
}
- // Packed field.
- v, n := wire.ConsumeBytes(b)
- if n < 0 || len(v)%8 != 0 {
+ case wire.Fixed32Type:
+ if len(b) < 4 {
return ValidationInvalid
}
- b = b[n:]
- continue Field
- case validationTypeUTF8String:
- if wtyp != wire.BytesType {
- break
- }
- v, n := wire.ConsumeBytes(b)
- if n < 0 || !utf8.Valid(v) {
+ b = b[4:]
+ case wire.Fixed64Type:
+ if len(b) < 8 {
return ValidationInvalid
}
- b = b[n:]
- continue Field
- }
- n := wire.ConsumeFieldValue(num, wtyp, b)
- if n < 0 {
+ b = b[8:]
+ case wire.StartGroupType:
+ switch vi.typ {
+ case validationTypeGroup:
+ if vi.mi == nil {
+ return ValidationUnknown
+ }
+ states = append(states, validationState{
+ typ: validationTypeGroup,
+ mi: vi.mi,
+ endGroup: num,
+ })
+ continue State
+ default:
+ n := wire.ConsumeFieldValue(num, wtyp, b)
+ if n < 0 {
+ return ValidationInvalid
+ }
+ b = b[n:]
+ }
+ default:
return ValidationInvalid
}
- b = b[n:]
}
if st.endGroup != 0 {
return ValidationInvalid