layers: Refactor ObjectTracker codegen into pre/post-call format

Codegen for ObjectTracker now generates per-API PreCallValidate,
PreCallRecord, and PostCallRecord functions and calls them from each
generated API entry point.

Change-Id: I788daf849b4733a911d410ef4195734100067af9
diff --git a/scripts/object_tracker_generator.py b/scripts/object_tracker_generator.py
index 8f9926b..f08c42c 100644
--- a/scripts/object_tracker_generator.py
+++ b/scripts/object_tracker_generator.py
@@ -697,9 +697,6 @@
             if cmd_info[-1].len is not None:
                 object_array = True;
             handle_name = params[-1].find('name')
-            create_obj_code += '%sif (VK_SUCCESS == result) {\n' % (indent)
-            indent = self.incIndent(indent)
-            create_obj_code += '%sstd::lock_guard<std::mutex> lock(global_lock);\n' % (indent)
             object_dest = '*%s' % handle_name.text
             if object_array == True:
                 create_obj_code += '%sfor (uint32_t index = 0; index < %s; index++) {\n' % (indent, cmd_info[-1].len)
@@ -710,12 +707,12 @@
                 indent = self.decIndent(indent)
                 create_obj_code += '%s}\n' % indent
             indent = self.decIndent(indent)
-            create_obj_code += '%s}\n' % (indent)
         return create_obj_code
     #
     # Generate source for destroying a non-dispatchable object
     def generate_destroy_object_code(self, indent, proto, cmd_info):
-        destroy_obj_code = ''
+        validate_code = ''
+        record_code = ''
         object_array = False
         if True in [destroy_txt in proto.text for destroy_txt in ['Destroy', 'Free']]:
             # Check for special case where multiple handles are returned
@@ -731,16 +728,12 @@
             if self.isHandleTypeObject(cmd_info[param].type) == True:
                 if object_array == True:
                     # This API is freeing an array of handles -- add loop control
-                    destroy_obj_code += 'HEY, NEED TO DESTROY AN ARRAY\n'
+                    validate_code += 'HEY, NEED TO DESTROY AN ARRAY\n'
                 else:
                     # Call Destroy a single time
-                    destroy_obj_code += '%sif (skip) return;\n' % indent
-                    destroy_obj_code += '%s{\n' % indent
-                    destroy_obj_code += '%s    std::lock_guard<std::mutex> lock(global_lock);\n' % indent
-                    destroy_obj_code += '%s    ValidateDestroyObject(%s, %s, %s, pAllocator, %s, %s);\n' % (indent, cmd_info[0].name, cmd_info[param].name, self.GetVulkanObjType(cmd_info[param].type), compatalloc_vuid, nullalloc_vuid)
-                    destroy_obj_code += '%s    RecordDestroyObject(%s, %s, %s);\n' % (indent, cmd_info[0].name, cmd_info[param].name, self.GetVulkanObjType(cmd_info[param].type))
-                    destroy_obj_code += '%s}\n' % indent
-        return object_array, destroy_obj_code
+                    validate_code += '%s    skip |= ValidateDestroyObject(%s, %s, %s, pAllocator, %s, %s);\n' % (indent, cmd_info[0].name, cmd_info[param].name, self.GetVulkanObjType(cmd_info[param].type), compatalloc_vuid, nullalloc_vuid)
+                    record_code += '%s    RecordDestroyObject(%s, %s, %s);\n' % (indent, cmd_info[0].name, cmd_info[param].name, self.GetVulkanObjType(cmd_info[param].type))
+        return object_array, validate_code, record_code
     #
     # Output validation for a single object (obj_count is NULL) or a counted list of objects
     def outputObjects(self, obj_type, obj_name, obj_count, prefix, index, indent, disp_name, parent_name, null_allowed, top_level):
@@ -819,7 +812,15 @@
     #
     # For a particular API, generate the object handling code
     def generate_wrapping_code(self, cmd):
-        indent = '    '
+        indent = ''
+        pre_call_validate = ''
+        pre_call_record = ''
+        post_call_record = ''
+
+        destroy_array = False
+        validate_destroy_code = ''
+        record_destroy_code = ''
+
         proto = cmd.find('proto/name')
         params = cmd.findall('param')
         if proto.text is not None:
@@ -828,31 +829,16 @@
             disp_name = cmd_info[0].name
             # Handle object create operations if last parameter is created by this call
             if cmddata.iscreate:
-                create_obj_code = self.generate_create_object_code(indent, proto, params, cmd_info, cmddata.allocator)
-            else:
-                create_obj_code = ''
+                post_call_record += self.generate_create_object_code(indent, proto, params, cmd_info, cmddata.allocator)
             # Handle object destroy operations
             if cmddata.isdestroy:
-                (destroy_array, destroy_object_code) = self.generate_destroy_object_code(indent, proto, cmd_info)
-            else:
-                destroy_array = False
-                destroy_object_code = ''
-            param_pre_code = ''
-            param_post_code = ''
-            create_func = True if create_obj_code else False
-            destroy_func = True if destroy_object_code else False
-            param_pre_code = self.validate_objects(cmd_info, indent, '', 0, disp_name, proto.text, True)
-            param_post_code += create_obj_code
-            if destroy_object_code:
-                if destroy_array == True:
-                    param_post_code += destroy_object_code
-                else:
-                    param_pre_code += destroy_object_code
-            if param_pre_code:
-                if (not destroy_func) or (destroy_array):
-                    param_pre_code = '%s{\n%s%s%s%s}\n' % ('    ', indent, self.lock_guard(indent), param_pre_code, indent)
+                (destroy_array, validate_destroy_code, record_destroy_code) = self.generate_destroy_object_code(indent, proto, cmd_info)
 
-        return param_pre_code, param_post_code
+            pre_call_record += record_destroy_code
+            pre_call_validate += self.validate_objects(cmd_info, indent, '', 0, disp_name, proto.text, True)
+            pre_call_validate += validate_destroy_code
+
+        return pre_call_validate, pre_call_record, post_call_record
     #
     # Capture command parameter info needed to create, destroy, and validate objects
     def genCmd(self, cmdinfo, cmdname, alias):
@@ -929,22 +915,64 @@
                 self.intercepts += [ '    {"%s", (void *)%s},' % (cmdname,cmdname[2:]) ]
                 continue
             # Generate object handling code
-            (api_pre, api_post) = self.generate_wrapping_code(cmdinfo.elem)
+            (pre_call_validate, pre_call_record, post_call_record) = self.generate_wrapping_code(cmdinfo.elem)
+
             # If API doesn't contain any object handles, don't fool with it
-            if not api_pre and not api_post:
+            if not pre_call_validate and not pre_call_record and not post_call_record:
                 continue
+
             feature_extra_protect = cmddata.extra_protect
             if (feature_extra_protect is not None):
                 self.appendSection('command', '')
                 self.appendSection('command', '#ifdef '+ feature_extra_protect)
                 self.intercepts += [ '#ifdef %s' % feature_extra_protect ]
+
             # Add intercept to procmap
             self.intercepts += [ '    {"%s", (void*)%s},' % (cmdname,cmdname[2:]) ]
+
             decls = self.makeCDecls(cmdinfo.elem)
+
+            # Gather the parameter items
+            params = cmdinfo.elem.findall('param/name')
+            # Pull out the text for each of the parameters, separate them by commas in a list
+            paramstext = ', '.join([str(param.text) for param in params])
+            # Generate the API call template
+            fcn_call = cmdinfo.elem.attrib.get('name').replace('vk', 'TOKEN', 1) + '(' + paramstext + ');'
+
+            func_decl_template = decls[0][:-1].split('VKAPI_CALL ')
+            func_decl_template = func_decl_template[1] + ' {'
+
+            # Output PreCallValidateAPI function if necessary
+            if pre_call_validate:
+                pre_cv_func_decl = 'static bool PreCallValidate' + func_decl_template
+                self.appendSection('command', '')
+                self.appendSection('command', pre_cv_func_decl)
+                self.appendSection('command', '    bool skip = false;')
+                self.appendSection('command', pre_call_validate)
+                self.appendSection('command', '    return skip;')
+                self.appendSection('command', '}')
+
+            # Output PreCallRecordAPI function if necessary
+            if pre_call_record:
+                pre_cr_func_decl = 'static void PreCallRecord' + func_decl_template
+                self.appendSection('command', '')
+                self.appendSection('command', pre_cr_func_decl)
+                self.appendSection('command', pre_call_record)
+                self.appendSection('command', '}')
+
+            # Output PostCallRecordAPI function if necessary
+            if post_call_record:
+                post_cr_func_decl = 'static void PostCallRecord' + func_decl_template
+                self.appendSection('command', '')
+                self.appendSection('command', post_cr_func_decl)
+                self.appendSection('command', post_call_record)
+                self.appendSection('command', '}')
+
+            # Output API function:
             self.appendSection('command', '')
             self.appendSection('command', decls[0][:-1])
             self.appendSection('command', '{')
-            self.appendSection('command', '    bool skip = false;')
+
             # Handle return values, if any
             resulttype = cmdinfo.elem.find('proto/type')
             if (resulttype is not None and resulttype.text == 'void'):
@@ -953,14 +981,35 @@
                 assignresult = resulttype.text + ' result = '
             else:
                 assignresult = ''
-            # Pre-pend pre-api-call codegen
-            if api_pre:
-                self.appendSection('command', "\n".join(str(api_pre).rstrip().split("\n")))
-            # Generate the API call itself
-            # Gather the parameter items
-            params = cmdinfo.elem.findall('param/name')
-            # Pull out the text for each of the parameters, separate them by commas in a list
-            paramstext = ', '.join([str(param.text) for param in params])
+
+            # Output all pre-API-call source
+            if pre_call_validate:
+                self.appendSection('command', '    bool skip = false;')
+            if pre_call_validate or pre_call_record:
+                self.appendSection('command', '    {')
+                self.appendSection('command', '        std::lock_guard<std::mutex> lock(global_lock);')
+            # If necessary, add call to PreCallValidateApi(...);
+            if pre_call_validate:
+                pcv_call = fcn_call.replace('TOKEN', '        skip |= PreCallValidate', 1)
+                self.appendSection('command', pcv_call)
+                if assignresult != '':
+                    if resulttype.text == 'VkResult':
+                        self.appendSection('command', '        if (skip) return VK_ERROR_VALIDATION_FAILED_EXT;')
+                    elif resulttype.text == 'VkBool32':
+                        self.appendSection('command', '        if (skip) return VK_FALSE;')
+                    else:
+                        raise Exception('Unknown result type ' + resulttype.text)
+                else:
+                    self.appendSection('command', '        if (skip) return;')
+            # If necessary, add call to PreCallRecordApi(...);
+            if pre_call_record:
+                pre_cr_call = fcn_call.replace('TOKEN', '        PreCallRecord', 1)
+                self.appendSection('command', pre_cr_call)
+
+            if pre_call_validate or pre_call_record:
+                self.appendSection('command', '    }')
+
+            # Build down-chain call strings and output source
             # Use correct dispatch table
             disp_name = cmdinfo.elem.find('param/name').text
             disp_type = cmdinfo.elem.find('param/type').text
@@ -969,20 +1018,21 @@
             else:
                 object_type = 'device'
             dispatch_table = 'GetLayerDataPtr(get_dispatch_key(%s), layer_data_map)->%s_dispatch_table.' % (disp_name, object_type)
-            API = cmdinfo.elem.attrib.get('name').replace('vk', dispatch_table, 1)
-            # Put all this together for the final down-chain call
-            if assignresult != '':
-                if resulttype.text == 'VkResult':
-                    self.appendSection('command', '    if (skip) return VK_ERROR_VALIDATION_FAILED_EXT;')
-                elif resulttype.text == 'VkBool32':
-                    self.appendSection('command', '    if (skip) return VK_FALSE;')
-                else:
-                    raise Exception('Unknown result type ' + resulttype.text)
-            else:
-                self.appendSection('command', '    if (skip) return;')
-            self.appendSection('command', '    ' + assignresult + API + '(' + paramstext + ');')
-            # And add the post-API-call codegen
-            self.appendSection('command', "\n".join(str(api_post).rstrip().split("\n")))
+            down_chain_call = fcn_call.replace('TOKEN', dispatch_table, 1)
+            self.appendSection('command', '    ' + assignresult + down_chain_call)
+
+            # If necessary, add call to PostCallRecordApi(...);
+            if post_call_record:
+                if assignresult:
+                    if resulttype.text == 'VkResult':
+                        self.appendSection('command', '    if (VK_SUCCESS == result) {')
+                    elif resulttype.text == 'VkBool32':
+                        self.appendSection('command', '    if (VK_TRUE == result) {')
+                self.appendSection('command', '        std::lock_guard<std::mutex> lock(global_lock);')
+                post_cr_call = fcn_call.replace('TOKEN', '        PostCallRecord', 1)
+                self.appendSection('command', post_cr_call)
+                self.appendSection('command', '    }')
+
             # Handle the return result variable, if any
             if (resulttype is not None):
                 self.appendSection('command', '    return result;')