scripts: Add extension unwrapping support routines

Change-Id: I86f30aed8a35d91d4c08585210cc0571a27c7430
diff --git a/scripts/unique_objects_generator.py b/scripts/unique_objects_generator.py
index 1a0b0ec..51c8130 100644
--- a/scripts/unique_objects_generator.py
+++ b/scripts/unique_objects_generator.py
@@ -170,7 +170,8 @@
         self.structMembers = []        # List of StructMemberData records for all Vulkan structs
         self.extension_structs = []    # List of all structs or sister-structs containing handles
                                        # A sister-struct may contain no handles but shares <validextensionstructs> with one that does
-        self.structTypes = dict()    # Map of Vulkan struct typename to required VkStructureType
+        self.structTypes = dict()      # Map of Vulkan struct typename to required VkStructureType
+        self.struct_member_dict = dict()  # Map of struct typename to its member list; filled from structMembers in endFile()
         # Named tuples to store struct and command data
         self.StructType = namedtuple('StructType', ['name', 'value'])
         self.CmdMemberData = namedtuple('CmdMemberData', ['name', 'members'])
@@ -208,12 +209,15 @@
         # Namespace
         self.newline()
         write('namespace unique_objects {', file = self.outFile)
-    #
+    # All struct/command data is now collected: build the lookup tables and extension list, then emit the wrapping/unwrapping routines
     def endFile(self):
 
+        self.struct_member_dict = dict(self.structMembers)  # struct typename -> member list, built once for the helpers below
+
+        # Populate self.extension_structs (extension structs holding handles, or sharing <validextensionstructs> with one that does)
+        self.GenerateCommandWrapExtensionList()
         # Write out wrapping/unwrapping functions
         self.WrapCommands()
-
         # Build and write out pNext processing function
         extension_proc = self.build_extension_processing_func()
         self.newline()
@@ -450,11 +454,10 @@
                 ndo_list.add(item)
         return ndo_list
     #
-    # Generate pNext handling function
-    def build_extension_processing_func(self):
-
-        # Construct list of extension structs containing handles, or extension structs that share a <validextensionstructs>
-        # tag WITH an extension struct containing handles. All extension structs in any pNext chain will have to be copied.
+    # Populate self.extension_structs with every extension struct containing handles, plus every extension struct sharing a
+    # <validextensionstructs> list WITH one containing handles. All extension structs in any pNext chain will have to be copied.
+    # TODO: make this recursive -- structs buried three or more levels deep are not searched for extensions
+    def GenerateCommandWrapExtensionList(self):
         for struct in self.structMembers:
             if (len(struct.members) > 1) and struct.members[1].extstructs is not None:
                 found = False;
@@ -465,6 +468,19 @@
                     for item in struct.members[1].extstructs.split(','):
                         if item != '' and item not in self.extension_structs:
                             self.extension_structs.append(item)
+    #
+    # Returns True if struct_type's pNext lists any of self.extension_structs as a valid extension, i.e. its chain may carry NDOs
+    def StructWithExtensions(self, struct_type):
+        if struct_type in self.struct_member_dict:
+            param_info = self.struct_member_dict[struct_type]
+            if (len(param_info) > 1) and param_info[1].extstructs is not None:  # member [1] is expected to be pNext; only pNext records extstructs
+                for item in param_info[1].extstructs.split(','):
+                    if item in self.extension_structs:
+                        return True
+        return False
+    #
+    # Generate pNext handling function
+    def build_extension_processing_func(self):
         # Construct helper functions to build and free pNext extension chains
         pnext_proc = ''
         pnext_proc += 'void *CreateUnwrappedExtensionStructs(layer_data *dev_data, const void *pNext) {\n'
@@ -476,8 +492,7 @@
         pnext_proc += '        GenericHeader *header = reinterpret_cast<GenericHeader *>(cur_pnext);\n\n'
         pnext_proc += '        switch (header->sType) {\n'
         for item in self.extension_structs:
-            struct_member_dict = dict(self.structMembers)
-            struct_info = struct_member_dict[item]
+            struct_info = self.struct_member_dict[item]  # table built once in endFile() rather than rebuilt per item
             if struct_info[0].feature_protect is not None:
                 pnext_proc += '#ifdef %s \n' % struct_info[0].feature_protect
             pnext_proc += '            case %s: {\n' % self.structTypes[item].value
@@ -585,11 +600,17 @@
 
     #
-    # Clean up local declarations
-    def cleanUpLocalDeclarations(self, indent, prefix, name, len):
-        cleanup = '%sif (local_%s%s)\n' % (indent, prefix, name)
+    # Clean up local declarations: emit C++ that, inside a braced null check, frees any unwrapped pNext chain(s) (when process_pnext) and then the local copy
+    def cleanUpLocalDeclarations(self, indent, prefix, name, len, index, process_pnext):
+        cleanup = '%sif (local_%s%s) {\n' % (indent, prefix, name)  # braced so the null check guards every emitted statement; closed at the return
         if len is not None:
+            if process_pnext:
+                cleanup += '%s    for (uint32_t %s = 0; %s < %s%s; ++%s) {\n' % (indent, index, index, prefix, len, index)
+                cleanup += '%s        FreeUnwrappedExtensionStructs(const_cast<void *>(local_%s%s[%s].pNext));\n' % (indent, prefix, name, index)
+                cleanup += '%s    }\n' % indent
             cleanup += '%s    delete[] local_%s%s;\n' % (indent, prefix, name)
         else:
+            if process_pnext:
+                cleanup += '%s    FreeUnwrappedExtensionStructs(const_cast<void *>(local_%s%s->pNext));\n' % (indent, prefix, name)
             cleanup += '%s    delete local_%s%s;\n' % (indent, prefix, name)
-        return cleanup
+        return cleanup + '%s}\n' % indent  # close the null-check block opened above
     #
@@ -643,11 +664,11 @@
         decls = ''
         pre_code = ''
         post_code = ''
-        struct_member_dict = dict(self.structMembers)
         index = 'index%s' % str(array_index)
         array_index += 1
         # Process any NDOs in this structure and recurse for any sub-structs in this struct
         for member in members:
+            process_pnext = self.StructWithExtensions(member.type)  # True when member.type's pNext chain may hold NDOs to unwrap
             # Handle NDOs
             if self.isHandleTypeNonDispatchable(member.type) == True:
                 count_name = member.len
@@ -661,10 +682,10 @@
                     pre_code += tmp_pre
                     post_code += tmp_post
             # Handle Structs that contain NDOs at some level
-            elif member.type in struct_member_dict:
-                # All structs at first level will have an NDO
-                if self.struct_contains_ndo(member.type) == True:
-                    struct_info = struct_member_dict[member.type]
+            elif member.type in self.struct_member_dict:
+                # Handle structs that contain an NDO, and structs whose pNext chain may hold one (a local safe_struct copy is needed to unwrap it)
+                if self.struct_contains_ndo(member.type) == True or process_pnext:
+                    struct_info = self.struct_member_dict[member.type]
                     # Struct Array
                     if member.len is not None:
                         # Update struct prefix
@@ -682,6 +703,8 @@
                         indent = self.incIndent(indent)
                         if first_level_param == True:
                             pre_code += '%s    %s[%s].initialize(&%s[%s]);\n' % (indent, new_prefix, index, member.name, index)
+                            if process_pnext:  # swap in an unwrapped copy of this element's pNext chain; the first-level cleanup is expected to free it
+                                pre_code += '%s    %s[%s].pNext = CreateUnwrappedExtensionStructs(dev_data, %s[%s].pNext);\n' % (indent, new_prefix, index, new_prefix, index)
                         local_prefix = '%s[%s].' % (new_prefix, index)
                         # Process sub-structs in this struct
                         (tmp_decl, tmp_pre, tmp_post) = self.uniquify_members(struct_info, indent, local_prefix, array_index, create_func, destroy_func, destroy_array, False)
@@ -693,7 +716,7 @@
                         indent = self.decIndent(indent)
                         pre_code += '%s    }\n' % indent
                         if first_level_param == True:
-                            post_code += self.cleanUpLocalDeclarations(indent, prefix, member.name, member.len)
+                            post_code += self.cleanUpLocalDeclarations(indent, prefix, member.name, member.len, index, process_pnext)  # index/process_pnext let cleanup free each element's unwrapped pNext
                     # Single Struct
                     else:
                         # Update struct prefix
@@ -712,10 +735,12 @@
                         decls += tmp_decl
                         pre_code += tmp_pre
                         post_code += tmp_post
+                        if process_pnext:  # NOTE(review): unlike the array branch this is not gated on first_level_param, but the cleanup below is -- confirm nested structs cannot reach here, or their unwrapped chain leaks
+                            pre_code += '%s    local_%s%s->pNext = CreateUnwrappedExtensionStructs(dev_data, local_%s%s->pNext);\n' % (indent, prefix, member.name, prefix, member.name)
                         indent = self.decIndent(indent)
                         pre_code += '%s    }\n' % indent
                         if first_level_param == True:
-                            post_code += self.cleanUpLocalDeclarations(indent, prefix, member.name, member.len)
+                            post_code += self.cleanUpLocalDeclarations(indent, prefix, member.name, member.len, index, process_pnext)
         return decls, pre_code, post_code
     #
     # For a particular API, generate the non-dispatchable-object wrapping/unwrapping code
@@ -723,6 +748,7 @@
         indent = '    '
         proto = cmd.find('proto/name')
         params = cmd.findall('param')
+        # cmd_info below: this command's parameter records from self.cmdMembers, looked up by command name
         if proto.text is not None:
             cmd_member_dict = dict(self.cmdMembers)
             cmd_info = cmd_member_dict[proto.text]
@@ -769,6 +795,7 @@
         struct_member_dict = dict(self.structMembers)
         # Generate member info
         membersInfo = []
+        # One CommandParam per member; only pNext members carry an extstructs (<validextensionstructs>) list
         for member in members:
             # Get type and name of member
             info = self.getTypeNameTuple(member)
@@ -789,17 +816,16 @@
             elif type in struct_member_dict:
                 if self.struct_contains_ndo(type) == True:
                     islocal = True
-
             isdestroy = True if True in [destroy_txt in cmdname for destroy_txt in ['Destroy', 'Free']] else False
             iscreate = True if True in [create_txt in cmdname for create_txt in ['Create', 'Allocate', 'GetRandROutputDisplayEXT', 'RegisterDeviceEvent', 'RegisterDisplayEvent']] else False
-
+            extstructs = member.attrib.get('validextensionstructs') if name == 'pNext' else None  # comma-separated structs allowed in this pNext chain; None for any other member
             membersInfo.append(self.CommandParam(type=type,
                                                  name=name,
                                                  ispointer=ispointer,
                                                  isconst=isconst,
                                                  iscount=iscount,
                                                  len=len,
-                                                 extstructs=member.attrib.get('validextensionstructs') if name == 'pNext' else None,
+                                                 extstructs=extstructs,
                                                  cdecl=cdecl,
                                                  islocal=islocal,
                                                  iscreate=iscreate,
@@ -869,7 +895,7 @@
             # If any of these paramters has been replaced by a local var, fix up the list
             params = cmd_member_dict[cmdname]
             for param in params:
-                if param.islocal == True:
+                if param.islocal == True or self.StructWithExtensions(param.type):  # such structs also get a local safe_struct copy so their pNext chain can be unwrapped
                     if param.ispointer == True:
                         paramstext = paramstext.replace(param.name, '(%s %s*)local_%s' % ('const', param.type, param.name))
                     else: