scripts: Initial version of safe_struct cpp gen

Change-Id: I1f3c28b7737e58731206e93af50fb1d18cd4f937
diff --git a/scripts/helper_file_generator.py b/scripts/helper_file_generator.py
index 1c958a6..4f83727 100644
--- a/scripts/helper_file_generator.py
+++ b/scripts/helper_file_generator.py
@@ -18,6 +18,7 @@
 # limitations under the License.
 #
 # Author: Mark Lobodzinski <mark@lunarg.com>
+# Author: Tobin Ehlis <tobine@google.com>
 
 import os,re,sys
 import xml.etree.ElementTree as etree
@@ -81,14 +82,10 @@
         # Internal state - accumulators for different inner block text
         self.structNames = []                             # List of Vulkan struct typenames
         self.structTypes = dict()                         # Map of Vulkan struct typename to required VkStructureType
-        self.handleTypes = set()                          # Set of handle type names
-        self.commands = []                                # List of CommandData records for all Vulkan commands
         self.structMembers = []                           # List of StructMemberData records for all Vulkan structs
-        self.flags = set()                                # Map of flags typenames
         # Named tuples to store struct and command data
         self.StructType = namedtuple('StructType', ['name', 'value'])
-        self.CommandParam = namedtuple('CommandParam', ['type', 'name', 'ispointer', 'isconst', 'iscount', 'len', 'extstructs', 'cdecl', 'islocal', 'iscreate', 'isdestroy'])
-        self.CommandData = namedtuple('CommandData', ['name', 'return_type', 'params', 'cdecl'])
+        self.CommandParam = namedtuple('CommandParam', ['type', 'name', 'ispointer', 'isstaticarray', 'isconst', 'iscount', 'len', 'extstructs', 'cdecl'])
         self.StructMemberData = namedtuple('StructMemberData', ['name', 'members', 'ifdef_protect'])
     #
     # Called once at the beginning of each run
@@ -183,6 +180,14 @@
                 ispointer = True
         return ispointer
     #
+    # Return how many '[' follow the <name> element of param (e.g. 1 for "x[4]"), or 0 if none
+    def paramIsStaticArray(self, param):
+        name_tail = param.find('name').tail
+        if not name_tail:
+            # No text after <name> (None or empty), so there is no [] suffix
+            return 0
+        return name_tail.count('[')
+    #
     # Retrieve the type and name for a parameter
     def getTypeNameTuple(self, param):
         type = ''
@@ -231,6 +236,26 @@
             result = str(result).replace('::', '->')
         return result
     #
+    # Check if type handle_type is a dispatchable (dispatchable = True) or non-dispatchable
+    # (dispatchable = False) handle, or is a struct with a direct member of such a handle type
+    def TypeContainsObjectHandle(self, handle_type, dispatchable):
+        type_key = 'VK_DEFINE_HANDLE' if dispatchable else 'VK_DEFINE_NON_DISPATCHABLE_HANDLE'
+        def is_matching_handle(type_name):
+            handle = self.registry.tree.find("types/type/[name='" + type_name + "'][@category='handle']")
+            if handle is None:
+                return False
+            # find() returns None for a handle entry without a <type> child; treat that as no match
+            handle_define = handle.find('type')
+            return handle_define is not None and handle_define.text == type_key
+        if is_matching_handle(handle_type):
+            return True
+        # If handle_type is a struct, search its members (one level deep, not recursively)
+        if handle_type in self.structNames:
+            struct = next((s for s in self.structMembers if s.name == handle_type), None)
+            if struct is not None:
+                return any(is_matching_handle(item.type) for item in struct.members)
+        return False
+    #
     # Generate local ready-access data describing Vulkan structures and unions from the XML metadata
     def genStruct(self, typeinfo, typeName):
         OutputGenerator.genStruct(self, typeinfo, typeName)
@@ -248,7 +273,7 @@
             info = self.getTypeNameTuple(member)
             type = info[0]
             name = info[1]
-            cdecl = self.makeCParamDecl(member, 0)
+            cdecl = self.makeCParamDecl(member, 1)
             # Process VkStructureType
             if type == 'VkStructureType':
                 # Extract the required struct type value from the comments
@@ -262,17 +287,16 @@
                 # Store the required type value
                 self.structTypes[typeName] = self.StructType(name=name, value=value)
             # Store pointer/array/string info
+            isstaticarray = self.paramIsStaticArray(member)
             membersInfo.append(self.CommandParam(type=type,
                                                  name=name,
                                                  ispointer=self.paramIsPointer(member),
+                                                 isstaticarray=isstaticarray,
                                                  isconst=True if 'const' in cdecl else False,
                                                  iscount=True if name in lens else False,
                                                  len=self.getLen(member),
                                                  extstructs=member.attrib.get('validextensionstructs') if name == 'pNext' else None,
-                                                 cdecl=cdecl,
-                                                 islocal=False,
-                                                 iscreate=False,
-                                                 isdestroy=False))
+                                                 cdecl=cdecl))
         self.structMembers.append(self.StructMemberData(name=typeName, members=membersInfo, ifdef_protect=self.featureExtraProtect))
     #
     # Enum_string_header: Create a routine to convert an enumerated value into a string
@@ -448,14 +472,24 @@
     def GenerateSafeStructHeader(self):
         safe_struct_header = ''
         for item in self.structMembers:
-            if self.GenSafeStruct(item) == True:
+            if self.NeedSafeStruct(item) == True:
                 safe_struct_header += '\n'
                 if item.ifdef_protect != None:
                     safe_struct_header += '#ifdef %s\n' % item.ifdef_protect
                 safe_struct_header += 'struct safe_%s {\n' % (item.name)
                 for member in item.members:
-                    safe_struct_header += '%s;\n' % member.cdecl
-                # Boilerplate
+                    nested = None
+                    if member.type in self.structNames:
+                        nested = next((s for s in self.structMembers if s.name == member.type), None)
+                    if nested is not None and self.NeedSafeStruct(nested) == True:
+                        # Member struct has its own safe_ wrapper: embed the wrapper (or a pointer to one)
+                        decl_fmt = '    safe_%s* %s;\n' if member.ispointer else '    safe_%s %s;\n'
+                        safe_struct_header += decl_fmt % (member.type, member.name)
+                        continue
+                    if member.len is not None and (self.TypeContainsObjectHandle(member.type, True) or self.TypeContainsObjectHandle(member.type, False)):
+                        safe_struct_header += '    %s* %s;\n' % (member.type, member.name)  # owned, non-const copy of a handle array
+                    else:
+                        safe_struct_header += '%s;\n' % member.cdecl
                 safe_struct_header += '    safe_%s(const %s* in_struct);\n' % (item.name, item.name)
                 safe_struct_header += '    safe_%s(const safe_%s& src);\n' % (item.name, item.name)
                 safe_struct_header += '    safe_%s();\n' % item.name
@@ -471,7 +505,7 @@
     #
     # Determine if a structure needs a safe_struct helper function
     # That is, it has an sType or one of its members is a pointer
-    def GenSafeStruct(self, structure):
+    def NeedSafeStruct(self, structure):
         if 'sType' == structure.name:
             return True
         for member in structure.members:
@@ -490,47 +524,151 @@
     #
     # safe_struct source -- create bodies of safe struct helper functions
     def GenerateSafeStructSource(self):
-        safe_struct_body = ''
+        safe_struct_body = []
         for item in self.structMembers:
-            safe_struct_body += '\n'
+            if self.NeedSafeStruct(item) == False:
+                continue
             if item.ifdef_protect != None:
-                safe_struct_body += '#ifdef %s\n' % item.ifdef_protect
-            safe_struct_body += 'size_t vk_size_%s(const %s* struct_ptr) {\n' % (item.name.lower(), item.name)
-            safe_struct_body += '    size_t struct_size = 0;\n'
-            safe_struct_body += '    if (struct_ptr) {\n'
-            safe_struct_body += '        struct_size = sizeof(%s);\n' % item.name
-            counter_declared = False
+                safe_struct_body.append("#ifdef %s\n" % item.ifdef_protect)
+            ss_name = "safe_%s" % item.name
+            init_list = ''          # list of members in struct constructor initializer
+            default_init_list = ''  # Default constructor just inits ptrs to nullptr in initializer
+            init_func_txt = ''      # Txt for initialize() function that takes struct ptr and inits members
+            construct_txt = ''      # Body of constructor as well as body of initialize() func following init_func_txt
+            destruct_txt = ''       # Body of destructor: frees the members the constructor allocated
+            # VkWriteDescriptorSet is a special case: only the pointer matching descriptorType is copied; the others may be non-null but are ignored
+            custom_construct_txt = {'VkWriteDescriptorSet' :
+                                    '    switch (descriptorType) {\n'
+                                    '        case VK_DESCRIPTOR_TYPE_SAMPLER:\n'
+                                    '        case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:\n'
+                                    '        case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE:\n'
+                                    '        case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:\n'
+                                    '        case VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:\n'
+                                    '        if (descriptorCount && in_struct->pImageInfo) {\n'
+                                    '            pImageInfo = new VkDescriptorImageInfo[descriptorCount];\n'
+                                    '            for (uint32_t i=0; i<descriptorCount; ++i) {\n'
+                                    '                pImageInfo[i] = in_struct->pImageInfo[i];\n'
+                                    '            }\n'
+                                    '        }\n'
+                                    '        break;\n'
+                                    '        case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:\n'
+                                    '        case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:\n'
+                                    '        case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:\n'
+                                    '        case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC:\n'
+                                    '        if (descriptorCount && in_struct->pBufferInfo) {\n'
+                                    '            pBufferInfo = new VkDescriptorBufferInfo[descriptorCount];\n'
+                                    '            for (uint32_t i=0; i<descriptorCount; ++i) {\n'
+                                    '                pBufferInfo[i] = in_struct->pBufferInfo[i];\n'
+                                    '            }\n'
+                                    '        }\n'
+                                    '        break;\n'
+                                    '        case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:\n'
+                                    '        case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:\n'
+                                    '        if (descriptorCount && in_struct->pTexelBufferView) {\n'
+                                    '            pTexelBufferView = new VkBufferView[descriptorCount];\n'
+                                    '            for (uint32_t i=0; i<descriptorCount; ++i) {\n'
+                                    '                pTexelBufferView[i] = in_struct->pTexelBufferView[i];\n'
+                                    '            }\n'
+                                    '        }\n'
+                                    '        break;\n'
+                                    '        default:\n'
+                                    '        break;\n'
+                                    '    }\n'}
+
             for member in item.members:
-                vulkan_type = next((i for i, v in enumerate(self.structMembers) if v[0] == member.type), None)
-                if member.ispointer == True:
-                    if vulkan_type is not None:
-                        # If this is another Vulkan structure call generated size function
-                        if member.len is not None:
-                            safe_struct_body, counter_declared = self.DeclareCounter(safe_struct_body, counter_declared)
-                            safe_struct_body += '        for (i = 0; i < struct_ptr->%s; i++) {\n' % member.len
-                            safe_struct_body += '            struct_size += vk_size_%s(&struct_ptr->%s[i]);\n' % (member.type.lower(), member.name)
-                            safe_struct_body += '        }\n'
-                        else:
-                            safe_struct_body += '        struct_size += vk_size_%s(struct_ptr->%s);\n' % (member.type.lower(), member.name)
+                m_type = member.type
+                if member.type in self.structNames:
+                    member_index = next((i for i, v in enumerate(self.structMembers) if v[0] == member.type), None)
+                    if member_index is not None and self.NeedSafeStruct(self.structMembers[member_index]) == True:
+                        m_type = 'safe_%s' % member.type
+                if member.ispointer and 'safe_' not in m_type and self.TypeContainsObjectHandle(member.type, False) == False:
+                    # Ptr types w/o a safe_struct, for non-null case need to allocate new ptr and copy data in
+                    if 'KHR' in ss_name or m_type in ['void', 'char']:
+                        # For these exceptions just copy initial value over for now
+                        init_list += '\n    %s(in_struct->%s),' % (member.name, member.name)
+                        init_func_txt += '    %s = in_struct->%s;\n' % (member.name, member.name)
                     else:
-                        if member.type == 'char':
-                            # Deal with sizes of character strings
-                            if member.len is not None:
-                                safe_struct_body, counter_declared = self.DeclareCounter(safe_struct_body, counter_declared)
-                                safe_struct_body += '        for (i = 0; i < struct_ptr->%s; i++) {\n' % member.len
-                                safe_struct_body += '            struct_size += (sizeof(char*) + (sizeof(char) * (1 + strlen(struct_ptr->%s[i]))));\n' % (member.name)
-                                safe_struct_body += '        }\n'
+                        default_init_list += '\n    %s(nullptr),' % (member.name)
+                        init_list += '\n    %s(nullptr),' % (member.name)
+                        init_func_txt += '    %s = nullptr;\n' % (member.name)
+                        if 'pNext' != member.name and 'void' not in m_type:
+                            if not member.isstaticarray and member.len is None:
+                                construct_txt += '    if (in_struct->%s) {\n' % member.name
+                                construct_txt += '        %s = new %s(*in_struct->%s);\n' % (member.name, m_type, member.name)
+                                construct_txt += '    }\n'
+                                destruct_txt += '    if (%s)\n' % member.name
+                                destruct_txt += '        delete %s;\n' % member.name
                             else:
-                                safe_struct_body += '        struct_size += (struct_ptr->%s != NULL) ? sizeof(char)*(1+strlen(struct_ptr->%s)) : 0;\n' % (member.name, member.name)
+                                construct_txt += '    if (in_struct->%s) {\n' % member.name
+                                construct_txt += '        %s = new %s[in_struct->%s];\n' % (member.name, m_type, member.len)
+                                construct_txt += '        memcpy ((void *)%s, (void *)in_struct->%s, sizeof(%s)*in_struct->%s);\n' % (member.name, member.name, m_type, member.len)
+                                construct_txt += '    }\n'
+                                destruct_txt += '    if (%s)\n' % member.name
+                                destruct_txt += '        delete[] %s;\n' % member.name
+                elif member.isstaticarray or member.len is not None:
+                    if member.len is None:
+                        # Extract length of static array by grabbing val between []
+                        static_array_size = re.match(r"[^[]*\[([^]]*)\]", member.cdecl)
+                        construct_txt += '    for (uint32_t i=0; i<%s; ++i) {\n' % static_array_size.group(1)
+                        construct_txt += '        %s[i] = in_struct->%s[i];\n' % (member.name, member.name)
+                        construct_txt += '    }\n'
+                    else:
+                        # Init array ptr to NULL
+                        default_init_list += '\n    %s(nullptr),' % member.name
+                        init_list += '\n    %s(nullptr),' % member.name
+                        init_func_txt += '    %s = nullptr;\n' % member.name
+                        array_element = 'in_struct->%s[i]' % member.name
+                        # Element types with a safe_ wrapper (m_type set above) are filled via initialize()
+                        # below; this wrapper-constructor form is only kept as a well-formed fallback
+                        if 'safe_' in m_type:
+                            array_element = '%s(&in_struct->%s[i])' % (m_type, member.name)
+                        construct_txt += '    if (%s && in_struct->%s) {\n' % (member.len, member.name)
+                        construct_txt += '        %s = new %s[%s];\n' % (member.name, m_type, member.len)
+                        destruct_txt += '    if (%s)\n' % member.name
+                        destruct_txt += '        delete[] %s;\n' % member.name
+                        construct_txt += '        for (uint32_t i=0; i<%s; ++i) {\n' % (member.len)
+                        if 'safe_' in m_type:
+                            construct_txt += '            %s[i].initialize(&in_struct->%s[i]);\n' % (member.name, member.name)
                         else:
-                            if member.len is not None:
-                                safe_struct_body += '        struct_size += struct_ptr->%s * sizeof(%s);\n' % (member.len, member.name)
-            safe_struct_body += '    }\n'
-            safe_struct_body += '    return struct_size\n'
-            safe_struct_body += '}\n'
+                            construct_txt += '            %s[i] = %s;\n' % (member.name, array_element)
+                        construct_txt += '        }\n'
+                        construct_txt += '    }\n'
+                elif member.ispointer == True:
+                    default_init_list += '\n    %s(nullptr),' % member.name  # destructor deletes this ptr, so default ctor must null it
+                    construct_txt += '    if (in_struct->%s)\n' % member.name
+                    construct_txt += '        %s = new %s(in_struct->%s);\n' % (member.name, m_type, member.name)
+                    construct_txt += '    else\n        %s = NULL;\n' % member.name
+                    destruct_txt += '    if (%s)\n' % member.name
+                    destruct_txt += '        delete %s;\n' % member.name
+                elif 'safe_' in m_type:
+                    init_list += '\n    %s(&in_struct->%s),' % (member.name, member.name)
+                    init_func_txt += '    %s.initialize(&in_struct->%s);\n' % (member.name, member.name)
+                else:
+                    init_list += '\n    %s(in_struct->%s),' % (member.name, member.name)
+                    init_func_txt += '    %s = in_struct->%s;\n' % (member.name, member.name)
+            if '' != init_list:
+                init_list = " :%s" % (init_list[:-1])  # hack off final comma; emit ':' only when there are initializers
+            if item.name in custom_construct_txt:
+                construct_txt = custom_construct_txt[item.name]
+            safe_struct_body.append("\n%s::%s(const %s* in_struct)%s\n{\n%s}" % (ss_name, ss_name, item.name, init_list, construct_txt))
+            if '' != default_init_list:
+                default_init_list = " :%s" % (default_init_list[:-1])
+            safe_struct_body.append("\n%s::%s()%s\n{}" % (ss_name, ss_name, default_init_list))
+            # Create slight variation of init and construct txt for copy constructor that takes a src object reference vs. struct ptr
+            copy_construct_init = init_func_txt.replace('in_struct->', 'src.')
+            copy_construct_txt = construct_txt.replace(' (in_struct->', ' (src.')     # Exclude 'if' blocks from next line
+            copy_construct_txt = copy_construct_txt.replace('(in_struct->', '(*src.') # Pass object to copy constructors
+            copy_construct_txt = copy_construct_txt.replace('in_struct->', 'src.')    # Modify remaining struct refs for src object
+            safe_struct_body.append("\n%s::%s(const %s& src)\n{\n%s%s}" % (ss_name, ss_name, ss_name, copy_construct_init, copy_construct_txt)) # Copy constructor
+            safe_struct_body.append("\n%s::~%s()\n{\n%s}" % (ss_name, ss_name, destruct_txt))
+            safe_struct_body.append("\nvoid %s::initialize(const %s* in_struct)\n{\n%s%s}" % (ss_name, item.name, init_func_txt, construct_txt))
+            # Copy initializer uses same txt as copy constructor but has a ptr and not a reference
+            init_copy = copy_construct_init.replace('src.', 'src->')
+            init_construct = copy_construct_txt.replace('src.', 'src->')
+            safe_struct_body.append("\nvoid %s::initialize(const %s* src)\n{\n%s%s}" % (ss_name, ss_name, init_copy, init_construct))
             if item.ifdef_protect != None:
-                safe_struct_body += '#endif // %s\n' % item.ifdef_protect
-        return safe_struct_body
+                safe_struct_body.append("#endif // %s\n" % item.ifdef_protect)
+        return "\n".join(safe_struct_body)
     #
     # Create a helper file and return it as a string
     def OutputDestFile(self):
@@ -545,5 +683,5 @@
         elif self.helper_file_type == 'safe_struct_source':
             return self.GenerateSafeStructHelperSource()
         else:
-            return 'Bad Helper Generator Option %s' % self.helper_file_type
+            return 'Bad Helper File Generator Option %s' % self.helper_file_type