layers: Add NULL/sType checks for struct members
Add param_checker support for validating struct members. Messages
are logged for the following conditions, as specified in vk.xml:
- A pointer is NULL and is not marked as optional in the XML
- An array is NULL and is not marked as optional in the XML, unless
its count is 0
- An array count is 0 and is not marked as optional in the XML
- A structure's sType value does not match the value specified
in the XML
Addresses GL105, GL109, GH82
Change-Id: I7063fe2582b30fdfc0006fe945a0f9c84a2aa66a
diff --git a/generator.py b/generator.py
index df5b7fe..cdca69a 100644
--- a/generator.py
+++ b/generator.py
@@ -2707,19 +2707,28 @@
OutputGenerator.__init__(self, errFile, warnFile, diagFile)
self.INDENT_SPACES = 4
# Commands to ignore
- self.blacklist = ['vkCreateInstance', 'vkCreateDevice', 'vkGetInstanceProcAddr', 'vkGetDeviceProcAddr',
- 'vkEnumerateInstanceLayerProperties', 'vkEnumerateInstanceExtensionsProperties',
- 'vkEnumerateDeviceLayerProperties', 'vkEnumerateDeviceExtensionsProperties',
- 'vkCreateDebugReportCallbackEXT', 'vkDebugReportMessageEXT']
+ self.blacklist = [
+ 'vkCreateInstance', 'vkCreateDevice',
+ 'vkGetInstanceProcAddr', 'vkGetDeviceProcAddr',
+ 'vkEnumerateInstanceLayerProperties',
+ 'vkEnumerateInstanceExtensionProperties',
+ 'vkEnumerateDeviceLayerProperties',
+ 'vkEnumerateDeviceExtensionProperties',
+ 'vkCreateDebugReportCallbackEXT',
+ 'vkDebugReportMessageEXT']
# Internal state - accumulators for different inner block text
self.sections = dict([(section, []) for section in self.ALL_SECTIONS])
- self.stypes = []
- self.structTypes = dict()
- self.commands = []
+ self.structNames = [] # List of Vulkan struct typenames
+ self.stypes = [] # Values from the VkStructureType enumeration
+ self.structTypes = dict() # Map of struct typename to StructType(name of its sType member, required VkStructureType value)
+ self.commands = [] # List of CommandData records for all Vulkan commands
+ self.structMembers = [] # List of StructMemberData records for all Vulkan structs
+ self.validatedStructs = set() # Set of structs containing members that require validation
# Named tuples to store struct and command data
self.StructType = namedtuple('StructType', ['name', 'value'])
self.CommandParam = namedtuple('CommandParam', ['type', 'name', 'ispointer', 'isstaticarray', 'isoptional', 'iscount', 'len', 'cdecl'])
self.CommandData = namedtuple('CommandData', ['name', 'params', 'cdecl'])
+ self.StructMemberData = namedtuple('StructMemberData', ['name', 'members'])
#
def incIndent(self, indent):
inc = ' ' * self.INDENT_SPACES
@@ -2774,9 +2783,12 @@
# end function prototypes separately for this feature. They're only
# printed in endFeature().
self.sections = dict([(section, []) for section in self.ALL_SECTIONS])
+ self.structNames = []
self.stypes = []
self.structTypes = dict()
self.commands = []
+ self.structMembers = []
+ self.validatedStructs = set()
def endFeature(self):
# C-specific
# Actually write the interface to the output file.
@@ -2787,7 +2799,10 @@
# or move it below the 'for section...' loop.
if (self.featureExtraProtect != None):
write('#ifdef', self.featureExtraProtect, file=self.outFile)
- # Generate the command text from the captured data
+ # Generate the struct member checking functions first; the command checks generated below call param_check_<struct>, so those definitions must precede them in the 'command' section
+ self.prepareStructMemberData()
+ self.processStructMemberData()
+ # Generate the command parameter checking code from the captured data
self.processCmdData()
if (self.sections['command']):
if (self.genOpts.protectProto):
@@ -2814,6 +2829,7 @@
# generating a structure. Otherwise, emit the tag text.
category = typeElem.get('category')
if (category == 'struct' or category == 'union'):
+ self.structNames.append(name)
self.genStruct(typeinfo, name)
#
# Struct parameter check generation.
@@ -2825,18 +2841,29 @@
# structs etc.)
def genStruct(self, typeinfo, typeName):
OutputGenerator.genStruct(self, typeinfo, typeName)
- for member in typeinfo.elem.findall('.//member'):
+ members = typeinfo.elem.findall('.//member')
+ #
+ # Iterate over members once to get length parameters for arrays
+ lens = set()
+ for member in members:
+ memberLen = self.getLen(member)
+ if memberLen:
+ lens.add(memberLen)
+ #
+ # Generate member info
+ membersInfo = []
+ for member in members:
# Get the member's type and name
- t = self.getTypeNameTuple(member)
- type = t[0]
- name = t[1]
- value = ''
+ info = self.getTypeNameTuple(member)
+ type = info[0]
+ name = info[1]
+ value = ''
# Process VkStructureType
if type == 'VkStructureType':
# Extract the required struct type value from the comments
# embedded in the original text defining the 'typeinfo' element
rawXml = etree.tostring(typeinfo.elem).decode('ascii')
- result = re.search('VK_STRUCTURE_TYPE_\w+', rawXml)
+ result = re.search(r'VK_STRUCTURE_TYPE_\w+', rawXml)
if result:
value = result.group(0)
# Make sure value is valid
@@ -2844,10 +2871,33 @@
# print('WARNING: {} is not part of the VkStructureType enumeration [{}]'.format(value, typeName))
else:
value = '<ERROR>'
- # Store the required value
+ # Store the required type value
self.structTypes[typeName] = self.StructType(name=name, value=value)
+ #
+ # Record pointer/array/optional info used to generate member checks
+ # A member is a count if another member's 'len' attribute names it
+ iscount = False
+ if name in lens:
+ iscount = True
+ # The pNext members are not tagged as optional, but are treated as
+ # optional for parameter NULL checks. Static array members
+ # are also treated as optional to skip NULL pointer validation, as
+ # they won't be NULL.
+ isstaticarray = self.paramIsStaticArray(member)
+ isoptional = False
+ if self.paramIsOptional(member) or (name == 'pNext') or (isstaticarray):
+ isoptional = True
+ membersInfo.append(self.CommandParam(type=type, name=name,
+ ispointer=self.paramIsPointer(member),
+ isstaticarray=isstaticarray,
+ isoptional=isoptional,
+ iscount=iscount,
+ len=self.getLen(member),
+ cdecl=self.makeCParamDecl(member, 0)))
+ self.structMembers.append(self.StructMemberData(name=typeName, members=membersInfo))
#
- # Group (e.g. C "enum" type) generation.
+ # Capture group (e.g. C "enum" type) info to be used for
+ # param check code generation.
# These are concatenated together with other types.
def genGroup(self, groupinfo, groupName):
OutputGenerator.genGroup(self, groupinfo, groupName)
@@ -2857,7 +2907,8 @@
name = elem.get('name')
self.stypes.append(name)
#
- # Command generation
+ # Capture command parameter info to be used for param
+ # check code generation.
def genCmd(self, cmdinfo, name):
OutputGenerator.genCmd(self, cmdinfo, name)
if name not in self.blacklist:
@@ -2887,18 +2938,18 @@
#
# Check if the parameter passed in is a pointer
def paramIsPointer(self, param):
- ispointer = False
+ ispointer = 0
paramtype = param.find('type')
- if paramtype.tail is not None and '*' in paramtype.tail:
- ispointer = True
+ if (paramtype.tail is not None) and ('*' in paramtype.tail):
+ ispointer = paramtype.tail.count('*') # Pointer depth, e.g. 2 for a type tail of '* const*'
return ispointer
#
# Check if the parameter passed in is a static array
def paramIsStaticArray(self, param):
- isstaticarray = False
- tail = param.find('name').tail
- if tail and tail[0] == '[':
- isstaticarray = True
+ isstaticarray = 0
+ paramname = param.find('name')
+ if (paramname.tail is not None) and ('[' in paramname.tail):
+ isstaticarray = paramname.tail.count('[') # Number of '[' dimensions following the name
return isstaticarray
#
# Check if the parameter passed in is optional
@@ -2926,10 +2977,18 @@
#
# Retrieve the value of the len tag
def getLen(self, param):
+ result = None
len = param.attrib.get('len')
if len and len != 'null-terminated':
- return len
- return None
+ # For string arrays, 'len' can look like 'count,null-terminated',
+ # indicating a count-sized array of null-terminated strings. Only
+ # the part before the first ',' (the count parameter name) is
+ # returned; a plain 'null-terminated' len yields None
+ if 'null-terminated' in len:
+ result = len.split(',')[0]
+ else:
+ result = len
+ return result
#
# Retrieve the type and name for a parameter
def getTypeNameTuple(self, param):
@@ -2948,113 +3007,35 @@
if param.name == name:
return param
return None
+ #
+ # Get the length parameter record for the specified 'len' value (None when it can't be resolved)
+ def getLenParam(self, params, name):
+ lenParam = None
+ if name:
+ if '->' in name:
+ # The count is obtained by dereferencing a member of a struct parameter
+ lenParam = self.CommandParam(name=name, iscount=True, ispointer=False, isoptional=False, type=None, len=None, isstaticarray=None, cdecl=None)
+ elif 'latexmath' in name:
+ result = re.search(r'mathit\{(\w+)\}', name)
+ lenParam = self.getParamByName(params, result.group(1)) if result else None
+ elif '/' in name:
+ # Len specified as an equation such as dataSize/4
+ lenParam = self.getParamByName(params, name.split('/')[0])
+ else:
+ lenParam = self.getParamByName(params, name)
+ return lenParam
+ #
+ # Convert a vulkan.h command declaration into its param_check_<name> definition, with report_data replacing the first parameter
def getCmdDef(self, cmd):
- # TODO: Override makeCDecls
#
# Strip the trailing ';' and split into individual lines
lines = cmd.cdecl[:-1].split('\n')
# Replace Vulkan prototype
lines[0] = 'static VkBool32 param_check_' + cmd.name + '('
# Replace the first argument with debug_report_data
- lines[1] = ' debug_report_data* report_data,'
+ lines[1] = ' debug_report_data*'.ljust(self.genOpts.alignFuncParam) + 'report_data,' # assumes cdecl has one parameter per line, so lines[1] is the first one - TODO confirm
return '\n'.join(lines)
#
- # Generate the command param check code from the captured data
- def processCmdData(self):
- indent = self.incIndent(None)
- for command in self.commands:
- cmdBody = ''
- unused = []
- for param in command.params:
- #
- # Check for NULL pointers, ignore the inout count parameters that
- # will be validated with their associated array
- if (param.ispointer or param.isstaticarray) and not param.iscount:
- #
- # Parameters for function argument generation
- checkExpr = '' # Code to check the current parameter
- req = 'VK_TRUE' # Paramerter can be NULL
- cpReq = 'VK_TRUE' # Count pointer can be NULL
- cvReq = 'VK_TRUE' # Count value can be 0
- lenParam = None
- #
- # Generate required parameter string for the pointer and count values
- if param.isoptional:
- req = 'VK_FALSE'
- if param.len:
- # The parameter is an array with an explicit count parameter
- # TODO: Better handling for special case counts and counts from struct members
- if param.len in ['pAllocateInfo->descriptorSetCount', 'pAllocateInfo->commandBufferCount']:
- lenParam = self.CommandParam(name=param.len, iscount=True, ispointer=False, isoptional=False, type=None, len=None, isstaticarray=None, cdecl=None)
- elif param.len == 'dataSize/4':
- lenParam = self.getParamByName(command.params, 'dataSize')
- else:
- lenParam = self.getParamByName(command.params, param.len)
- if lenParam.ispointer:
- # Count parameters that are pointers are inout
- if type(lenParam.isoptional) is list:
- if lenParam.isoptional[0]:
- cpReq = 'VK_FALSE'
- if lenParam.isoptional[1]:
- cvReq = 'VK_FALSE'
- else:
- if lenParam.isoptional:
- cpReq = 'VK_FALSE'
- else:
- if lenParam.isoptional:
- cvReq = 'VK_FALSE'
- #
- # If this is a pointer to a struct with an sType field, verify the type
- if param.type in self.structTypes:
- stype = self.structTypes[param.type]
- if lenParam:
- # This is an array
- if lenParam.ispointer:
- # When the length parameter is a pointer, there is an extra Boolean parameter in the function call to indicate if it is required
- checkExpr = 'skipCall |= validate_struct_type_array(report_data, "{}", "{}", "{}", "{}", {}, {}, {}, {}, {}, {});\n'.format(command.name, lenParam.name, param.name, stype.value, lenParam.name, param.name, stype.value, cpReq, cvReq, req)
- else:
- checkExpr = 'skipCall |= validate_struct_type_array(report_data, "{}", "{}", "{}", "{}", {}, {}, {}, {}, {});\n'.format(command.name, lenParam.name, param.name, stype.value, lenParam.name, param.name, stype.value, cvReq, req)
- else:
- checkExpr = 'skipCall |= validate_struct_type(report_data, "{}", "{}", "{}", {}, {}, {});\n'.format(command.name, param.name, stype.value, param.name, stype.value, req)
- else:
- if lenParam:
- # This is an array
- if lenParam.ispointer:
- # When the length parameter is a pointer, there is an extra Boolean parameter in the function call to indicate if it is required
- checkExpr = 'skipCall |= validate_array(report_data, "{}", "{}", "{}", {}, {}, {}, {}, {});\n'.format(command.name, lenParam.name, param.name, lenParam.name, param.name, cpReq, cvReq, req)
- else:
- checkExpr = 'skipCall |= validate_array(report_data, "{}", "{}", "{}", {}, {}, {}, {});\n'.format(command.name, lenParam.name, param.name, lenParam.name, param.name, cvReq, req)
- elif not param.isoptional:
- checkExpr = indent + 'skipCall |= validate_required_pointer(report_data, "{}", "{}", {});\n'.format(command.name, param.name, param.name)
- else:
- unused.append(param.name)
- # Append the parameter check to the function body for the current command
- if checkExpr:
- cmdBody += '\n'
- if lenParam and ('->' in lenParam.name):
- # Add checks to ensure the validation call does not dereference a NULL pointer to obtain the count
- cmdBody += self.genCheckedLengthCall(indent, lenParam.name, checkExpr)
- else:
- cmdBody += indent + checkExpr
- elif not param.iscount:
- unused.append(param.name)
- if cmdBody:
- cmdDef = self.getCmdDef(command) + '\n'
- cmdDef += '{\n'
- indent = self.incIndent(None)
- # Ignore the first dispatch handle parameter, which is not
- # processed by param_check
- for name in unused[1:]:
- cmdDef += indent + 'UNUSED_PARAMETER({});\n'.format(name)
- if len(unused) > 1:
- cmdDef += '\n'
- cmdDef += indent + 'VkBool32 skipCall = VK_FALSE;\n'
- cmdDef += cmdBody
- cmdDef += '\n'
- cmdDef += indent + 'return skipCall;\n'
- cmdDef += '}\n'
- self.appendSection('command', cmdDef)
- #
# Generate the code to check for a NULL dereference before calling the
# validation function
def genCheckedLengthCall(self, indent, name, expr):
@@ -3076,5 +3057,183 @@
return checkedExpr
# No if statements were required
return indent + expr
+ #
+ # Generate the parameter checking code for 'values'; returns (body text, names of values the body never references)
+ def genFuncBody(self, indent, name, values, valuePrefix, variablePrefix):
+ funcBody = ''
+ unused = []
+ for value in values:
+ checkExpr = '' # Code to check the current parameter
+ lenParam = None # Count parameter record; only set for arrays, reset for every value
+ # Check for NULL pointers, ignore the inout count parameters that
+ # will be validated with their associated array
+ if (value.ispointer or value.isstaticarray) and not value.iscount:
+ #
+ # Generate the full name of the value, which will be printed in
+ # the error message, by adding the variable prefix to the
+ # value name
+ valueDisplayName = '(std::string({}) + std::string("{}")).c_str()'.format(variablePrefix, value.name) if variablePrefix else '"{}"'.format(value.name)
+ #
+ # Parameters for function argument generation; 'VK_TRUE' marks the
+ # value as required and 'VK_FALSE' marks it as optional
+ req = 'VK_TRUE' # Pointer must not be NULL
+ cpReq = 'VK_TRUE' # Count pointer must not be NULL
+ cvReq = 'VK_TRUE' # Count value must not be 0
+ #
+ # Generate required/optional parameter strings for the pointer and count values
+ if value.isoptional:
+ req = 'VK_FALSE'
+ if value.len:
+ # The parameter is an array with an explicit count parameter
+ lenParam = self.getLenParam(values, value.len)
+ if lenParam is None: raise ValueError('Unable to find the length parameter "{}" for "{}"'.format(value.len, value.name))
+ if lenParam.ispointer:
+ # Count parameters that are pointers are inout
+ if isinstance(lenParam.isoptional, list):
+ if lenParam.isoptional[0]:
+ cpReq = 'VK_FALSE'
+ if lenParam.isoptional[1]:
+ cvReq = 'VK_FALSE'
+ else:
+ if lenParam.isoptional:
+ cpReq = 'VK_FALSE'
+ else:
+ if lenParam.isoptional:
+ cvReq = 'VK_FALSE'
+ #
+ # If this is a pointer to a struct with an sType field, verify the type
+ if value.type in self.structTypes:
+ stype = self.structTypes[value.type]
+ if lenParam:
+ # This is an array
+ if lenParam.ispointer:
+ # When the length parameter is a pointer, there is an extra Boolean parameter in the function call to indicate if it is required
+ checkExpr = 'skipCall |= validate_struct_type_array(report_data, {}, "{ln}", {dn}, "{sv}", {pf}{ln}, {pf}{vn}, {sv}, {}, {}, {});\n'.format(name, cpReq, cvReq, req, ln=lenParam.name, dn=valueDisplayName, vn=value.name, sv=stype.value, pf=valuePrefix)
+ else:
+ checkExpr = 'skipCall |= validate_struct_type_array(report_data, {}, "{ln}", {dn}, "{sv}", {pf}{ln}, {pf}{vn}, {sv}, {}, {});\n'.format(name, cvReq, req, ln=lenParam.name, dn=valueDisplayName, vn=value.name, sv=stype.value, pf=valuePrefix)
+ else:
+ checkExpr = 'skipCall |= validate_struct_type(report_data, {}, {}, "{sv}", {}{vn}, {sv}, {});\n'.format(name, valueDisplayName, valuePrefix, req, vn=value.name, sv=stype.value)
+ else:
+ if lenParam:
+ # This is an array
+ if lenParam.ispointer:
+ # If count and array parameters are optional, there will be no validation
+ if req == 'VK_TRUE' or cpReq == 'VK_TRUE' or cvReq == 'VK_TRUE':
+ # When the length parameter is a pointer, there is an extra Boolean parameter in the function call to indicate if it is required
+ checkExpr = 'skipCall |= validate_array(report_data, {}, "{ln}", {dn}, {pf}{ln}, {pf}{vn}, {}, {}, {});\n'.format(name, cpReq, cvReq, req, ln=lenParam.name, dn=valueDisplayName, vn=value.name, pf=valuePrefix)
+ else:
+ # If count and array parameters are optional, there will be no validation
+ if req == 'VK_TRUE' or cvReq == 'VK_TRUE':
+ checkExpr = 'skipCall |= validate_array(report_data, {}, "{ln}", {dn}, {pf}{ln}, {pf}{vn}, {}, {});\n'.format(name, cvReq, req, ln=lenParam.name, dn=valueDisplayName, vn=value.name, pf=valuePrefix)
+ elif not value.isoptional:
+ checkExpr = 'skipCall |= validate_required_pointer(report_data, {}, {}, {}{vn});\n'.format(name, valueDisplayName, valuePrefix, vn=value.name)
+ else:
+ unused.append(value.name)
+ #
+ # If this is a pointer to a struct, see if it contains members
+ # that need to be checked
+ if value.type in self.validatedStructs:
+ if checkExpr:
+ checkExpr += '\n' + indent
+ #
+ # The name prefix used when reporting an error with a struct member (eg. the 'pCreateInfo->' in 'pCreateInfo->sType')
+ prefix = '(std::string({}) + std::string("{}->")).c_str()'.format(variablePrefix, value.name) if variablePrefix else '"{}->"'.format(value.name)
+ checkExpr += 'skipCall |= param_check_{}(report_data, {}, {}, {}{});\n'.format(value.type, name, prefix, valuePrefix, value.name)
+ elif value.type in self.validatedStructs:
+ # The name prefix used when reporting an error with a struct member (eg. the 'createInfo.' in 'createInfo.sType')
+ prefix = '(std::string({}) + std::string("{}.")).c_str()'.format(variablePrefix, value.name) if variablePrefix else '"{}."'.format(value.name)
+ checkExpr += 'skipCall |= param_check_{}(report_data, {}, {}, &({}{}));\n'.format(value.type, name, prefix, valuePrefix, value.name)
+ elif not value.iscount:
+ unused.append(value.name)
+ #
+ # Append the parameter check to the function body for the current command
+ if checkExpr:
+ funcBody += '\n'
+ if lenParam and ('->' in lenParam.name):
+ # Add checks to ensure the validation call does not dereference a NULL pointer to obtain the count
+ funcBody += self.genCheckedLengthCall(indent, lenParam.name, checkExpr)
+ else:
+ funcBody += indent + checkExpr
+ elif lenParam and ('->' not in lenParam.name):
+ unused.extend([value.name, lenParam.name]) # Fully optional array: neither it nor its count is referenced
+ return funcBody, unused
+ #
+ # Post-process the collected struct member data to create a list of structs
+ # with members that need to be validated
+ def prepareStructMemberData(self):
+ for struct in self.structMembers:
+ for member in struct.members:
+ if not member.iscount:
+ lenParam = self.getLenParam(struct.members, member.len)
+ # The sType of a struct-typed member needs to be validated
+ # A required array/count needs to be validated
+ # A required pointer needs to be validated
+ validated = False
+ if member.type in self.structTypes:
+ validated = True
+ elif member.ispointer and lenParam: # This is an array
+ # Mirror genFuncBody: validate unless the array and its count are all optional
+ if lenParam.ispointer:
+ if not isinstance(lenParam.isoptional, list) or not lenParam.isoptional[0] or not lenParam.isoptional[1] or not member.isoptional:
+ validated = True
+ else:
+ if not lenParam.isoptional or not member.isoptional:
+ validated = True
+ elif member.ispointer and not member.isoptional:
+ validated = True
+ #
+ if validated:
+ self.validatedStructs.add(struct.name)
+ # Second pass to check for struct members that are structs requiring
+ # validation (only structs processed earlier in self.structMembers are seen)
+ for member in struct.members:
+ if member.type in self.validatedStructs:
+ self.validatedStructs.add(struct.name)
+ #
+ # Generate a param_check_<struct> function for each struct whose members need checks
+ def processStructMemberData(self):
+ indent = self.incIndent(None)
+ for struct in self.structMembers:
+ # genFuncBody output is nested in the pStruct NULL check, so it gets one extra indent level
+ funcBody, _ = self.genFuncBody(self.incIndent(indent), 'pFuncName', struct.members, 'pStruct->', 'pVariableName')
+ if funcBody:
+ cmdDef = 'static VkBool32 param_check_{}(\n'.format(struct.name)
+ # Parameter names start at alignFuncParam (as in getCmdDef), with at least one space after long type names
+ cmdDef += ' debug_report_data*'.ljust(self.genOpts.alignFuncParam - 1) + ' report_data,\n'
+ cmdDef += ' const char*'.ljust(self.genOpts.alignFuncParam - 1) + ' pFuncName,\n'
+ cmdDef += ' const char*'.ljust(self.genOpts.alignFuncParam - 1) + ' pVariableName,\n'
+ cmdDef += ' const {}*'.format(struct.name).ljust(self.genOpts.alignFuncParam - 1) + ' pStruct)\n'
+ cmdDef += '{\n'
+ cmdDef += indent + 'VkBool32 skipCall = VK_FALSE;\n'
+ cmdDef += '\n'
+ cmdDef += indent + 'if (pStruct != NULL) {'
+ cmdDef += funcBody
+ cmdDef += indent + '}\n'
+ cmdDef += '\n'
+ cmdDef += indent + 'return skipCall;\n'
+ cmdDef += '}\n'
+ self.appendSection('command', cmdDef)
+ #
+ # Generate a param_check_<command> function for each captured command that needs checks
+ def processCmdData(self):
+ indent = self.incIndent(None)
+ for command in self.commands:
+ cmdBody, unused = self.genFuncBody(indent, '"{}"'.format(command.name), command.params, '', None)
+ if cmdBody:
+ cmdDef = self.getCmdDef(command) + '\n'
+ cmdDef += '{\n'
+ # Mark parameters the generated checks never reference with UNUSED_PARAMETER
+ # Ignore the first dispatch handle parameter, which is not
+ # processed by param_check (getCmdDef replaced it with report_data)
+ for name in unused[1:]:
+ cmdDef += indent + 'UNUSED_PARAMETER({});\n'.format(name)
+ if len(unused) > 1:
+ cmdDef += '\n'
+ cmdDef += indent + 'VkBool32 skipCall = VK_FALSE;\n'
+ cmdDef += cmdBody
+ cmdDef += '\n'
+ cmdDef += indent + 'return skipCall;\n'
+ cmdDef += '}\n'
+ self.appendSection('command', cmdDef)