blob: 079afe07529d0a4eb835fd825ffb9d634ba062de [file] [log] [blame]
Damien Neil5322bdb2019-04-09 15:57:05 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "sync"
9
Damien Neil3d0706a2019-07-09 11:40:49 -070010 "google.golang.org/protobuf/internal/errors"
Damien Neil5322bdb2019-04-09 15:57:05 -070011 pref "google.golang.org/protobuf/reflect/protoreflect"
12)
13
Joe Tsai0f81b382019-07-10 23:14:31 -070014func (mi *MessageInfo) isInitialized(m pref.Message) error {
15 var p pointer
16 if ms, ok := m.(*messageState); ok {
17 p = ms.pointer()
18 } else {
19 p = m.(*messageReflectWrapper).pointer()
20 }
21 return mi.isInitializedPointer(p)
Damien Neil5322bdb2019-04-09 15:57:05 -070022}
23
24func (mi *MessageInfo) isInitializedPointer(p pointer) error {
25 mi.init()
26 if !mi.needsInitCheck {
27 return nil
28 }
29 if p.IsNil() {
Damien Neil3d0706a2019-07-09 11:40:49 -070030 for _, f := range mi.orderedCoderFields {
31 if f.isRequired {
Damien Neil92f76182019-08-02 16:58:08 -070032 return errors.RequiredNotSet(string(mi.PBType.Descriptor().Fields().ByNumber(f.num).FullName()))
Damien Neil3d0706a2019-07-09 11:40:49 -070033 }
34 }
35 return nil
Damien Neil5322bdb2019-04-09 15:57:05 -070036 }
37 if mi.extensionOffset.IsValid() {
38 e := p.Apply(mi.extensionOffset).Extensions()
39 if err := mi.isInitExtensions(e); err != nil {
40 return err
41 }
42 }
Damien Neil4ae30bb2019-06-20 10:12:23 -070043 for _, f := range mi.orderedCoderFields {
Damien Neil5322bdb2019-04-09 15:57:05 -070044 if !f.isRequired && f.funcs.isInit == nil {
45 continue
46 }
47 fptr := p.Apply(f.offset)
48 if f.isPointer && fptr.Elem().IsNil() {
49 if f.isRequired {
Damien Neil92f76182019-08-02 16:58:08 -070050 return errors.RequiredNotSet(string(mi.PBType.Descriptor().Fields().ByNumber(f.num).FullName()))
Damien Neil5322bdb2019-04-09 15:57:05 -070051 }
52 continue
53 }
54 if f.funcs.isInit == nil {
55 continue
56 }
57 if err := f.funcs.isInit(fptr); err != nil {
58 return err
59 }
60 }
61 return nil
62}
63
64func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
65 if ext == nil {
66 return nil
67 }
68 for _, x := range *ext {
69 ei := mi.extensionFieldInfo(x.GetType())
70 if ei.funcs.isInit == nil {
71 continue
72 }
73 v := x.GetValue()
74 if v == nil {
75 continue
76 }
77 if err := ei.funcs.isInit(v); err != nil {
78 return err
79 }
80 }
81 return nil
82}
83
var (
	// needsInitCheckMap caches needsInitCheck results, keyed by message
	// descriptor. Entries holding a bool are completed results and are read
	// without holding needsInitCheckMu.
	needsInitCheckMap sync.Map

	// needsInitCheckMu serializes the computation of new
	// needsInitCheckMap entries.
	needsInitCheckMu sync.Mutex
)
88
89// needsInitCheck reports whether a message needs to be checked for partial initialization.
90//
91// It returns true if the message transitively includes any required or extension fields.
92func needsInitCheck(md pref.MessageDescriptor) bool {
93 if v, ok := needsInitCheckMap.Load(md); ok {
94 if has, ok := v.(bool); ok {
95 return has
96 }
97 }
98 needsInitCheckMu.Lock()
99 defer needsInitCheckMu.Unlock()
100 return needsInitCheckLocked(md)
101}
102
103func needsInitCheckLocked(md pref.MessageDescriptor) (has bool) {
104 if v, ok := needsInitCheckMap.Load(md); ok {
105 // If has is true, we've previously determined that this message
106 // needs init checks.
107 //
108 // If has is false, we've previously determined that it can never
109 // be uninitialized.
110 //
111 // If has is not a bool, we've just encountered a cycle in the
112 // message graph. In this case, it is safe to return false: If
113 // the message does have required fields, we'll detect them later
114 // in the graph traversal.
115 has, ok := v.(bool)
116 return ok && has
117 }
118 needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message
119 defer func() {
120 needsInitCheckMap.Store(md, has)
121 }()
122 if md.RequiredNumbers().Len() > 0 {
123 return true
124 }
125 if md.ExtensionRanges().Len() > 0 {
126 return true
127 }
128 for i := 0; i < md.Fields().Len(); i++ {
129 fd := md.Fields().Get(i)
130 // Map keys are never messages, so just consider the map value.
131 if fd.IsMap() {
132 fd = fd.MapValue()
133 }
134 fmd := fd.Message()
135 if fmd != nil && needsInitCheckLocked(fmd) {
136 return true
137 }
138 }
139 return false
140}