scripts: Make safe_structs deep-copy pNext chains

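Safe structs now make deep copies of their pNext extension chains;
structures with unrecognized sTypes are skipped. Because the chassis
dispatch code now works on these copies, it processes handles in place
in the copied chain instead of building its own copy of the chain.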

Change-Id: I34d6dfed9ce8222c197b83448ad7ee420b55df8c
diff --git a/scripts/helper_file_generator.py b/scripts/helper_file_generator.py
index e3225a6..8a02bb2 100644
--- a/scripts/helper_file_generator.py
+++ b/scripts/helper_file_generator.py
@@ -454,6 +454,7 @@
         safe_struct_helper_header += '#include <vulkan/vulkan.h>\n'
         safe_struct_helper_header += '\n'
         safe_struct_helper_header += 'void *SafePnextCopy(const void *pNext);\n'
+        safe_struct_helper_header += 'void FreePnextChain(const void *head);\n'
         safe_struct_helper_header += 'void FreePnextChain(void *head);\n'
         safe_struct_helper_header += '\n'
         safe_struct_helper_header += self.GenerateSafeStructHeader()
@@ -891,58 +892,67 @@
         build_pnext_proc = '\n\n'
         build_pnext_proc += 'void *SafePnextCopy(const void *pNext) {\n'
         build_pnext_proc += '    void *cur_pnext = const_cast<void *>(pNext);\n'
-        build_pnext_proc += '    void *cur_ext_struct = NULL;\n\n'
-        build_pnext_proc += '    if (cur_pnext == nullptr) {\n'
-        build_pnext_proc += '        return nullptr;\n'
-        build_pnext_proc += '    } else {\n'
-        build_pnext_proc += '        VkBaseOutStructure *header = reinterpret_cast<VkBaseOutStructure *>(cur_pnext);\n\n'
-        build_pnext_proc += '        switch (header->sType) {\n'
+        build_pnext_proc += '    void *cur_ext_struct = NULL;\n'
+        build_pnext_proc += '    bool unrecognized_stype = true;\n\n'
+        build_pnext_proc += '    while (unrecognized_stype) {\n'
+        build_pnext_proc += '        unrecognized_stype = false;\n'
+        build_pnext_proc += '        if (cur_pnext == nullptr) {\n'
+        build_pnext_proc += '            return nullptr;\n'
+        build_pnext_proc += '        } else {\n'
+        build_pnext_proc += '            VkBaseOutStructure *header = reinterpret_cast<VkBaseOutStructure *>(cur_pnext);\n\n'
+        build_pnext_proc += '            switch (header->sType) {\n'
 
         free_pnext_proc = '\n\n'
+        free_pnext_proc += '// Free a const pNext extension chain\n'
+        free_pnext_proc += 'void FreePnextChain(const void *head) {\n'
+        free_pnext_proc += '    FreePnextChain(const_cast<void *>(head));\n'
+        free_pnext_proc += '}\n\n'
+
         free_pnext_proc += '// Free a pNext extension chain\n'
         free_pnext_proc += 'void FreePnextChain(void *head) {\n'
-        free_pnext_proc += '    VkBaseOutStructure *curr_ptr = reinterpret_cast<VkBaseOutStructure *>(head);\n'
-        free_pnext_proc += '    while (curr_ptr) {\n'
-        free_pnext_proc += '        VkBaseOutStructure *header = curr_ptr;\n'
-        free_pnext_proc += '        curr_ptr = reinterpret_cast<VkBaseOutStructure *>(header->pNext);\n\n'
-        free_pnext_proc += '        switch (header->sType) {\n';
+        free_pnext_proc += '    if (nullptr == head) return;\n'
+        free_pnext_proc += '    VkBaseOutStructure *header = reinterpret_cast<VkBaseOutStructure *>(head);\n\n'
+        free_pnext_proc += '    switch (header->sType) {\n';
 
         for item in self.structextends_list:
-            member_index = next((i for i, v in enumerate(self.structMembers) if v[0] == item), None)
-            if member_index is None:
+
+            struct = next((v for v in self.structMembers if v.name == item), None)
+            if struct is None:
                 continue
-            struct_info = self.structMembers[member_index][1]
-            feature_protect = self.structMembers[member_index][2]
+            
+            if struct.ifdef_protect is not None:
+                build_pnext_proc += '#ifdef %s\n' % struct.ifdef_protect
+                free_pnext_proc += '#ifdef %s\n' % struct.ifdef_protect
+            build_pnext_proc += '                case %s: {\n' % self.structTypes[item].value
+            build_pnext_proc += '                        safe_%s *safe_struct = new safe_%s;\n' % (item, item)
+            build_pnext_proc += '                        safe_struct->initialize(reinterpret_cast<const %s *>(cur_pnext));\n' % item
+            build_pnext_proc += '                        cur_ext_struct = reinterpret_cast<void *>(safe_struct);\n'
+            build_pnext_proc += '                    } break;\n'
 
-            if feature_protect is not None:
-                build_pnext_proc += '#ifdef %s\n' % feature_protect
-                free_pnext_proc += '#ifdef %s\n' % feature_protect
-            build_pnext_proc += '            case %s: {\n' % self.structTypes[item].value
-            build_pnext_proc += '                    safe_%s *safe_struct = new safe_%s;\n' % (item, item)
-            build_pnext_proc += '                    safe_struct->initialize(reinterpret_cast<const %s *>(cur_pnext));\n' % item
-            build_pnext_proc += '                    cur_ext_struct = reinterpret_cast<void *>(safe_struct);\n'
-            build_pnext_proc += '                } break;\n'
+            free_pnext_proc += '        case %s:\n' % self.structTypes[item].value
+            free_pnext_proc += '            delete reinterpret_cast<safe_%s *>(header);\n' % item
+            free_pnext_proc += '            break;\n'
 
-            free_pnext_proc += '            case %s:\n' % self.structTypes[item].value
-            free_pnext_proc += '                delete reinterpret_cast<safe_%s *>(header);\n' % item
-            free_pnext_proc += '                break;\n'
-
-            if feature_protect is not None:
-                build_pnext_proc += '#endif // %s\n' % feature_protect
-                free_pnext_proc += '#endif // %s\n' % feature_protect
+            if struct.ifdef_protect is not None:
+                build_pnext_proc += '#endif // %s\n' % struct.ifdef_protect
+                free_pnext_proc += '#endif // %s\n' % struct.ifdef_protect
             build_pnext_proc += '\n'
             free_pnext_proc += '\n'
 
-        build_pnext_proc += '            default:\n'
-        build_pnext_proc += '                break;\n'
+        build_pnext_proc += '                default:\n'
+        build_pnext_proc += '                    // Encountered an unknown sType -- skip (do not copy) this entry in the chain\n'
+        build_pnext_proc += '                    unrecognized_stype = true;\n'
+        build_pnext_proc += '                    cur_pnext = header->pNext;\n'
+        build_pnext_proc += '                    break;\n'
+        build_pnext_proc += '            }\n'
         build_pnext_proc += '        }\n'
         build_pnext_proc += '    }\n'
         build_pnext_proc += '    return cur_ext_struct;\n'
         build_pnext_proc += '}\n\n'
 
-        free_pnext_proc += '            default:\n'
-        free_pnext_proc += '                assert(0);\n'
-        free_pnext_proc += '        }\n'
+        free_pnext_proc += '        default:\n'
+        free_pnext_proc += '            // Do nothing -- skip unrecognized sTypes\n'
+        free_pnext_proc += '            break;\n'
         free_pnext_proc += '    }\n'
         free_pnext_proc += '}\n'
 
@@ -1136,6 +1146,7 @@
             custom_copy_txt = {
                 # VkGraphicsPipelineCreateInfo is special case because it has custom construct parameters
                 'VkGraphicsPipelineCreateInfo' :
+                    '    pNext = SafePnextCopy(src.pNext);\n'
                     '    if (stageCount && src.pStages) {\n'
                     '        pStages = new safe_VkPipelineShaderStageCreateInfo[stageCount];\n'
                     '        for (uint32_t i=0; i<stageCount; ++i) {\n'
@@ -1186,6 +1197,7 @@
                     '        pDynamicState = NULL;\n',
                  # VkPipelineViewportStateCreateInfo is special case because it has custom construct parameters
                 'VkPipelineViewportStateCreateInfo' :
+                    '    pNext = SafePnextCopy(src.pNext);\n'
                     '    if (src.pViewports) {\n'
                     '        pViewports = new VkViewport[src.viewportCount];\n'
                     '        memcpy ((void *)pViewports, (void *)src.pViewports, sizeof(VkViewport)*src.viewportCount);\n'
@@ -1204,8 +1216,11 @@
                                    '    if (pCode)\n'
                                    '        delete[] reinterpret_cast<const uint8_t *>(pCode);\n' }
 
+            copy_pnext = ''
             for member in item.members:
                 m_type = member.type
+                if member.name == 'pNext':
+                    copy_pnext = '    pNext = SafePnextCopy(in_struct->pNext);\n'
                 if member.type in self.structNames:
                     member_index = next((i for i, v in enumerate(self.structMembers) if v[0] == member.type), None)
                     if member_index is not None and self.NeedSafeStruct(self.structMembers[member_index]) == True:
@@ -1213,9 +1228,10 @@
                 if member.ispointer and 'safe_' not in m_type and self.TypeContainsObjectHandle(member.type, False) == False:
                     # Ptr types w/o a safe_struct, for non-null case need to allocate new ptr and copy data in
                     if m_type in ['void', 'char']:
-                        # For these exceptions just copy initial value over for now
-                        init_list += '\n    %s(in_struct->%s),' % (member.name, member.name)
-                        init_func_txt += '    %s = in_struct->%s;\n' % (member.name, member.name)
+                        if member.name != 'pNext':
+                            # For these exceptions just copy initial value over for now
+                            init_list += '\n    %s(in_struct->%s),' % (member.name, member.name)
+                            init_func_txt += '    %s = in_struct->%s;\n' % (member.name, member.name)
                     else:
                         default_init_list += '\n    %s(nullptr),' % (member.name)
                         init_list += '\n    %s(nullptr),' % (member.name)
@@ -1223,20 +1239,19 @@
                             construct_txt += '    %s = in_struct->%s;\n' % (member.name, member.name)
                         else:
                             init_func_txt += '    %s = nullptr;\n' % (member.name)
-                            if 'pNext' != member.name and 'void' not in m_type:
-                                if not member.isstaticarray and (member.len is None or '/' in member.len):
-                                    construct_txt += '    if (in_struct->%s) {\n' % member.name
-                                    construct_txt += '        %s = new %s(*in_struct->%s);\n' % (member.name, m_type, member.name)
-                                    construct_txt += '    }\n'
-                                    destruct_txt += '    if (%s)\n' % member.name
-                                    destruct_txt += '        delete %s;\n' % member.name
-                                else:
-                                    construct_txt += '    if (in_struct->%s) {\n' % member.name
-                                    construct_txt += '        %s = new %s[in_struct->%s];\n' % (member.name, m_type, member.len)
-                                    construct_txt += '        memcpy ((void *)%s, (void *)in_struct->%s, sizeof(%s)*in_struct->%s);\n' % (member.name, member.name, m_type, member.len)
-                                    construct_txt += '    }\n'
-                                    destruct_txt += '    if (%s)\n' % member.name
-                                    destruct_txt += '        delete[] %s;\n' % member.name
+                            if not member.isstaticarray and (member.len is None or '/' in member.len):
+                                construct_txt += '    if (in_struct->%s) {\n' % member.name
+                                construct_txt += '        %s = new %s(*in_struct->%s);\n' % (member.name, m_type, member.name)
+                                construct_txt += '    }\n'
+                                destruct_txt += '    if (%s)\n' % member.name
+                                destruct_txt += '        delete %s;\n' % member.name
+                            else:
+                                construct_txt += '    if (in_struct->%s) {\n' % member.name
+                                construct_txt += '        %s = new %s[in_struct->%s];\n' % (member.name, m_type, member.len)
+                                construct_txt += '        memcpy ((void *)%s, (void *)in_struct->%s, sizeof(%s)*in_struct->%s);\n' % (member.name, member.name, m_type, member.len)
+                                construct_txt += '    }\n'
+                                destruct_txt += '    if (%s)\n' % member.name
+                                destruct_txt += '        delete[] %s;\n' % member.name
                 elif member.isstaticarray or member.len is not None:
                     if member.len is None:
                         # Extract length of static array by grabbing val between []
@@ -1280,19 +1295,29 @@
                     init_func_txt += '    %s = in_struct->%s;\n' % (member.name, member.name)
             if '' != init_list:
                 init_list = init_list[:-1] # hack off final comma
+            if copy_pnext:  # The destructor now frees pNext, so the default constructor must null it
+                default_init_list = '\n    pNext(nullptr),' + default_init_list  # prepend: pNext precedes all other pointer members
             if item.name in custom_construct_txt:
                 construct_txt = custom_construct_txt[item.name]
+
+            construct_txt = copy_pnext + construct_txt
+
             if item.name in custom_destruct_txt:
                 destruct_txt = custom_destruct_txt[item.name]
+
+            if copy_pnext:
+                destruct_txt += '    if (pNext)\n        FreePnextChain(pNext);\n'
+
             safe_struct_body.append("\n%s::%s(const %s* in_struct%s) :%s\n{\n%s}" % (ss_name, ss_name, item.name, self.custom_construct_params.get(item.name, ''), init_list, construct_txt))
             if '' != default_init_list:
                 default_init_list = " :%s" % (default_init_list[:-1])
             safe_struct_body.append("\n%s::%s()%s\n{}" % (ss_name, ss_name, default_init_list))
             # Create slight variation of init and construct txt for copy constructor that takes a src object reference vs. struct ptr
             copy_construct_init = init_func_txt.replace('in_struct->', 'src.')
-            copy_construct_txt = construct_txt.replace(' (in_struct->', ' (src.')     # Exclude 'if' blocks from next line
-            copy_construct_txt = copy_construct_txt.replace('(in_struct->', '(*src.') # Pass object to copy constructors
-            copy_construct_txt = copy_construct_txt.replace('in_struct->', 'src.')    # Modify remaining struct refs for src object
+            # Pass objects to copy constructors ('new T(in_struct->x' -> 'new T(*src.x'), leaving plain calls such as
+            # SafePnextCopy(in_struct->pNext) alone; then point every remaining struct ref (incl. 'if' tests) at src.
+            copy_construct_txt = re.sub('(new \\w+)\\(in_struct->', '\\1(*src.', construct_txt)
+            copy_construct_txt = copy_construct_txt.replace('in_struct->', 'src.')
             if item.name in custom_copy_txt:
                 copy_construct_txt = custom_copy_txt[item.name]
             copy_assign_txt = '    if (&src == this) return *this;\n\n' + destruct_txt + '\n' + copy_construct_init + copy_construct_txt + '\n    return *this;'
diff --git a/scripts/layer_chassis_dispatch_generator.py b/scripts/layer_chassis_dispatch_generator.py
index 897c005..9bdc69a 100644
--- a/scripts/layer_chassis_dispatch_generator.py
+++ b/scripts/layer_chassis_dispatch_generator.py
@@ -1342,26 +1342,24 @@
     def build_extension_processing_func(self):
         # Construct helper functions to build and free pNext extension chains
         pnext_proc = ''
-        pnext_proc += 'void *CreateUnwrappedExtensionStructs(ValidationObject *layer_data, const void *pNext) {\n'
+        pnext_proc += 'void WrapPnextChainHandles(ValidationObject *layer_data, const void *pNext) {\n'
         pnext_proc += '    void *cur_pnext = const_cast<void *>(pNext);\n'
-        pnext_proc += '    void *head_pnext = NULL;\n'
-        pnext_proc += '    void *prev_ext_struct = NULL;\n'
-        pnext_proc += '    void *cur_ext_struct = NULL;\n\n'
         pnext_proc += '    while (cur_pnext != NULL) {\n'
         pnext_proc += '        VkBaseOutStructure *header = reinterpret_cast<VkBaseOutStructure *>(cur_pnext);\n\n'
         pnext_proc += '        switch (header->sType) {\n'
         for item in self.pnext_extension_structs:
             struct_info = self.struct_member_dict[item]
+            indent = '                '
+            (tmp_decl, tmp_pre, tmp_post) = self.uniquify_members(struct_info, indent, 'safe_struct->', 0, False, False, False, False)
+            # Only process extension structs containing handles
+            if not tmp_pre:
+                continue
             if struct_info[0].feature_protect is not None:
                 pnext_proc += '#ifdef %s \n' % struct_info[0].feature_protect
             pnext_proc += '            case %s: {\n' % self.structTypes[item].value
-            pnext_proc += '                    safe_%s *safe_struct = new safe_%s;\n' % (item, item)
-            pnext_proc += '                    safe_struct->initialize(reinterpret_cast<const %s *>(cur_pnext));\n' % item
+            pnext_proc += '                    safe_%s *safe_struct = reinterpret_cast<safe_%s *>(cur_pnext);\n' % (item, item)
             # Generate code to unwrap the handles
-            indent = '                '
-            (tmp_decl, tmp_pre, tmp_post) = self.uniquify_members(struct_info, indent, 'safe_struct->', 0, False, False, False, False)
             pnext_proc += tmp_pre
-            pnext_proc += '                    cur_ext_struct = reinterpret_cast<void *>(safe_struct);\n'
             pnext_proc += '                } break;\n'
             if struct_info[0].feature_protect is not None:
                 pnext_proc += '#endif // %s \n' % struct_info[0].feature_protect
@@ -1369,17 +1367,9 @@
         pnext_proc += '            default:\n'
         pnext_proc += '                break;\n'
         pnext_proc += '        }\n\n'
-        pnext_proc += '        // Save pointer to the first structure in the pNext chain\n'
-        pnext_proc += '        head_pnext = (head_pnext ? head_pnext : cur_ext_struct);\n\n'
-        pnext_proc += '        // For any extension structure but the first, link the last struct\'s pNext to the current ext struct\n'
-        pnext_proc += '        if (prev_ext_struct) {\n'
-        pnext_proc += '                reinterpret_cast<VkBaseOutStructure *>(prev_ext_struct)->pNext = reinterpret_cast<VkBaseOutStructure *>(cur_ext_struct);\n'
-        pnext_proc += '        }\n'
-        pnext_proc += '        prev_ext_struct = cur_ext_struct;\n\n'
         pnext_proc += '        // Process the next structure in the chain\n'
         pnext_proc += '        cur_pnext = header->pNext;\n'
         pnext_proc += '    }\n'
-        pnext_proc += '    return head_pnext;\n'
         pnext_proc += '}\n\n'
         pnext_proc += '// Free a pNext extension chain\n'
         pnext_proc += 'void FreeUnwrappedExtensionStructs(void *head) {\n'
@@ -1469,17 +1459,11 @@
 
     #
     # Clean up local declarations
-    def cleanUpLocalDeclarations(self, indent, prefix, name, len, index, process_pnext):
+    def cleanUpLocalDeclarations(self, indent, prefix, name, len, index):
         cleanup = '%sif (local_%s%s) {\n' % (indent, prefix, name)
         if len is not None:
-            if process_pnext:
-                cleanup += '%s    for (uint32_t %s = 0; %s < %s%s; ++%s) {\n' % (indent, index, index, prefix, len, index)
-                cleanup += '%s        FreeUnwrappedExtensionStructs(const_cast<void *>(local_%s%s[%s].pNext));\n' % (indent, prefix, name, index)
-                cleanup += '%s    }\n' % indent
             cleanup += '%s    delete[] local_%s%s;\n' % (indent, prefix, name)
         else:
-            if process_pnext:
-                cleanup += '%s    FreeUnwrappedExtensionStructs(const_cast<void *>(local_%s%s->pNext));\n' % (indent, prefix, name)
             cleanup += '%s    delete local_%s%s;\n' % (indent, prefix, name)
         cleanup += "%s}\n" % (indent)
         return cleanup
@@ -1576,7 +1560,7 @@
                         if first_level_param == True:
                             pre_code += '%s    %s[%s].initialize(&%s[%s]);\n' % (indent, new_prefix, index, member.name, index)
                             if process_pnext:
-                                pre_code += '%s    %s[%s].pNext = CreateUnwrappedExtensionStructs(layer_data, %s[%s].pNext);\n' % (indent, new_prefix, index, new_prefix, index)
+                                pre_code += '%s    WrapPnextChainHandles(layer_data, %s[%s].pNext);\n' % (indent, new_prefix, index)
                         local_prefix = '%s[%s].' % (new_prefix, index)
                         # Process sub-structs in this struct
                         (tmp_decl, tmp_pre, tmp_post) = self.uniquify_members(struct_info, indent, local_prefix, array_index, create_func, destroy_func, destroy_array, False)
@@ -1588,7 +1572,7 @@
                         indent = self.decIndent(indent)
                         pre_code += '%s    }\n' % indent
                         if first_level_param == True:
-                            post_code += self.cleanUpLocalDeclarations(indent, prefix, member.name, member.len, index, process_pnext)
+                            post_code += self.cleanUpLocalDeclarations(indent, prefix, member.name, member.len, index)
                     # Single Struct
                     elif ispointer:
                         # Update struct prefix
@@ -1608,11 +1592,11 @@
                         pre_code += tmp_pre
                         post_code += tmp_post
                         if process_pnext:
-                            pre_code += '%s    local_%s%s->pNext = CreateUnwrappedExtensionStructs(layer_data, local_%s%s->pNext);\n' % (indent, prefix, member.name, prefix, member.name)
+                            pre_code += '%s    WrapPnextChainHandles(layer_data, local_%s%s->pNext);\n' % (indent, prefix, member.name)
                         indent = self.decIndent(indent)
                         pre_code += '%s    }\n' % indent
                         if first_level_param == True:
-                            post_code += self.cleanUpLocalDeclarations(indent, prefix, member.name, member.len, index, process_pnext)
+                            post_code += self.cleanUpLocalDeclarations(indent, prefix, member.name, member.len, index)
                     else:
                         # Update struct prefix
                         if first_level_param == True:
@@ -1625,7 +1609,7 @@
                         pre_code += tmp_pre
                         post_code += tmp_post
                         if process_pnext:
-                            pre_code += '%s    local_%s%s.pNext = CreateUnwrappedExtensionStructs(layer_data, local_%s%s.pNext);\n' % (indent, prefix, member.name, prefix, member.name)
+                            pre_code += '%s    WrapPnextChainHandles(layer_data, local_%s%s.pNext);\n' % (indent, prefix, member.name)
         return decls, pre_code, post_code
     #
     # For a particular API, generate the non-dispatchable-object wrapping/unwrapping code