HLSL: add helper access methods to TAttributeMap
There was some code replication around getting string and integer
values out of an attribute map. This adds new methods to the
TAttributeMap class to encapsulate some accessor details.
diff --git a/hlsl/hlslAttributes.cpp b/hlsl/hlslAttributes.cpp
index 61ef805..2a8e370 100644
--- a/hlsl/hlslAttributes.cpp
+++ b/hlsl/hlslAttributes.cpp
@@ -36,6 +36,7 @@
#include "hlslAttributes.h"
#include <cstdlib>
#include <cctype>
+#include <algorithm>
namespace glslang {
// Map the given string to an attribute enum from TAttributeType,
@@ -131,4 +132,51 @@
return attributes.find(attr) != attributes.end();
}
+ // extract integers out of attribute arguments stored in attribute aggregate
+ bool TAttributeMap::getInt(TAttributeType attr, int& value, int argNum) const
+ {
+ const TConstUnion* intConst = getConstUnion(attr, EbtInt, argNum);
+
+ if (intConst == nullptr)
+ return false;
+
+ value = intConst->getIConst();
+ return true;
+ };
+
+ // extract strings out of attribute arguments stored in attribute aggregate.
+ // convert to lower case if converToLower is true (for case-insensitive compare convenience)
+ bool TAttributeMap::getString(TAttributeType attr, TString& value, int argNum, bool convertToLower) const
+ {
+ const TConstUnion* stringConst = getConstUnion(attr, EbtString, argNum);
+
+ if (stringConst == nullptr)
+ return false;
+
+ value = *stringConst->getSConst();
+
+ // Convenience.
+ if (convertToLower)
+ std::transform(value.begin(), value.end(), value.begin(), ::tolower);
+
+ return true;
+ };
+
+ // Helper to get attribute const union. Returns nullptr on failure.
+ const TConstUnion* TAttributeMap::getConstUnion(TAttributeType attr, TBasicType basicType, int argNum) const
+ {
+ const TIntermAggregate* attrAgg = (*this)[attr];
+ if (attrAgg == nullptr)
+ return nullptr;
+
+ if (argNum >= int(attrAgg->getSequence().size()))
+ return nullptr;
+
+ const TConstUnion* constVal = &attrAgg->getSequence()[argNum]->getAsConstantUnion()->getConstArray()[0];
+ if (constVal == nullptr || constVal->getType() != basicType)
+ return nullptr;
+
+ return constVal;
+ }
+
} // end namespace glslang
diff --git a/hlsl/hlslAttributes.h b/hlsl/hlslAttributes.h
index 16ec31d..2d7b6c7 100644
--- a/hlsl/hlslAttributes.h
+++ b/hlsl/hlslAttributes.h
@@ -93,7 +93,16 @@
// True if entry exists in map (even if value is nullptr)
bool contains(TAttributeType) const;
+ // Obtain attribute as integer
+ bool getInt(TAttributeType attr, int& value, int argNum = 0) const;
+
+ // Obtain attribute as string, with optional to-lower transform
+ bool getString(TAttributeType attr, TString& value, int argNum = 0, bool convertToLower = true) const;
+
protected:
+ // Helper to get attribute const union
+ const TConstUnion* getConstUnion(TAttributeType attr, TBasicType, int argNum) const;
+
// Find an attribute enum given its name.
static TAttributeType attributeFromName(const TString& nameSpace, const TString& name);
diff --git a/hlsl/hlslParseHelper.cpp b/hlsl/hlslParseHelper.cpp
index 706173d..9899cc5 100755
--- a/hlsl/hlslParseHelper.cpp
+++ b/hlsl/hlslParseHelper.cpp
@@ -1717,36 +1717,33 @@
}
// MaxVertexCount
- const TIntermAggregate* maxVertexCount = attributes[EatMaxVertexCount];
- if (maxVertexCount != nullptr) {
- if (! intermediate.setVertices(maxVertexCount->getSequence()[0]->getAsConstantUnion()->
- getConstArray()[0].getIConst())) {
- error(loc, "cannot change previously set maxvertexcount attribute", "", "");
+ if (attributes.contains(EatMaxVertexCount)) {
+ int maxVertexCount;
+
+ if (! attributes.getInt(EatMaxVertexCount, maxVertexCount)) {
+ error(loc, "invalid maxvertexcount", "", "");
+ } else {
+ if (! intermediate.setVertices(maxVertexCount))
+ error(loc, "cannot change previously set maxvertexcount attribute", "", "");
}
}
// Handle [patchconstantfunction("...")]
- const TIntermAggregate* pcfAttr = attributes[EatPatchConstantFunc];
- if (pcfAttr != nullptr) {
- const TConstUnion& pcfName = pcfAttr->getSequence()[0]->getAsConstantUnion()->getConstArray()[0];
-
- if (pcfName.getType() != EbtString) {
+ if (attributes.contains(EatPatchConstantFunc)) {
+ TString pcfName;
+ if (! attributes.getString(EatPatchConstantFunc, pcfName, 0, false)) {
error(loc, "invalid patch constant function", "", "");
} else {
- patchConstantFunctionName = *pcfName.getSConst();
+ patchConstantFunctionName = pcfName;
}
}
// Handle [domain("...")]
- const TIntermAggregate* domainAttr = attributes[EatDomain];
- if (domainAttr != nullptr) {
- const TConstUnion& domainType = domainAttr->getSequence()[0]->getAsConstantUnion()->getConstArray()[0];
- if (domainType.getType() != EbtString) {
+ if (attributes.contains(EatDomain)) {
+ TString domainStr;
+ if (! attributes.getString(EatDomain, domainStr)) {
error(loc, "invalid domain", "", "");
} else {
- TString domainStr = *domainType.getSConst();
- std::transform(domainStr.begin(), domainStr.end(), domainStr.begin(), ::tolower);
-
TLayoutGeometry domain = ElgNone;
if (domainStr == "tri") {
@@ -1770,15 +1767,11 @@
}
// Handle [outputtopology("...")]
- const TIntermAggregate* topologyAttr = attributes[EatOutputTopology];
- if (topologyAttr != nullptr) {
- const TConstUnion& topoType = topologyAttr->getSequence()[0]->getAsConstantUnion()->getConstArray()[0];
- if (topoType.getType() != EbtString) {
+ if (attributes.contains(EatOutputTopology)) {
+ TString topologyStr;
+ if (! attributes.getString(EatOutputTopology, topologyStr)) {
error(loc, "invalid outputtopology", "", "");
} else {
- TString topologyStr = *topoType.getSConst();
- std::transform(topologyStr.begin(), topologyStr.end(), topologyStr.begin(), ::tolower);
-
TVertexOrder vertexOrder = EvoNone;
TLayoutGeometry primitive = ElgNone;
@@ -1808,15 +1801,11 @@
}
// Handle [partitioning("...")]
- const TIntermAggregate* partitionAttr = attributes[EatPartitioning];
- if (partitionAttr != nullptr) {
- const TConstUnion& partType = partitionAttr->getSequence()[0]->getAsConstantUnion()->getConstArray()[0];
- if (partType.getType() != EbtString) {
+ if (attributes.contains(EatPartitioning)) {
+ TString partitionStr;
+ if (! attributes.getString(EatPartitioning, partitionStr)) {
error(loc, "invalid partitioning", "", "");
} else {
- TString partitionStr = *partType.getSConst();
- std::transform(partitionStr.begin(), partitionStr.end(), partitionStr.begin(), ::tolower);
-
TVertexSpacing partitioning = EvsNone;
if (partitionStr == "integer") {
@@ -1837,14 +1826,11 @@
}
// Handle [outputcontrolpoints("...")]
- const TIntermAggregate* outputControlPoints = attributes[EatOutputControlPoints];
- if (outputControlPoints != nullptr) {
- const TConstUnion& ctrlPointConst =
- outputControlPoints->getSequence()[0]->getAsConstantUnion()->getConstArray()[0];
- if (ctrlPointConst.getType() != EbtInt) {
+ if (attributes.contains(EatOutputControlPoints)) {
+ int ctrlPoints;
+ if (! attributes.getInt(EatOutputControlPoints, ctrlPoints)) {
error(loc, "invalid outputcontrolpoints", "", "");
} else {
- const int ctrlPoints = ctrlPointConst.getIConst();
if (! intermediate.setVertices(ctrlPoints)) {
error(loc, "cannot change previously set outputcontrolpoints attribute", "", "");
}
@@ -1856,37 +1842,23 @@
// attributes.
void HlslParseContext::transferTypeAttributes(const TAttributeMap& attributes, TType& type)
{
- // extract integers out of attribute arguments stored in attribute aggregate
- const auto getInt = [&](TAttributeType attr, int argNum, int& value) -> bool {
- const TIntermAggregate* attrAgg = attributes[attr];
- if (attrAgg == nullptr)
- return false;
- if (argNum >= (int)attrAgg->getSequence().size())
- return false;
- const TConstUnion& intConst = attrAgg->getSequence()[argNum]->getAsConstantUnion()->getConstArray()[0];
- if (intConst.getType() != EbtInt)
- return false;
- value = intConst.getIConst();
- return true;
- };
-
// location
int value;
- if (getInt(EatLocation, 0, value))
+ if (attributes.getInt(EatLocation, value))
type.getQualifier().layoutLocation = value;
// binding
- if (getInt(EatBinding, 0, value)) {
+ if (attributes.getInt(EatBinding, value)) {
type.getQualifier().layoutBinding = value;
type.getQualifier().layoutSet = 0;
}
// set
- if (getInt(EatBinding, 1, value))
+ if (attributes.getInt(EatBinding, value, 1))
type.getQualifier().layoutSet = value;
// input attachment
- if (getInt(EatInputAttachment, 0, value))
+ if (attributes.getInt(EatInputAttachment, value))
type.getQualifier().layoutAttachment = value;
}