Clean up binary operation constant folding code
Fix mixed-up comments, remove unnecessary type conversions, clarify
variable names, and improve formatting in a few places.
TEST=angle_unittests, WebGL conformance tests
BUG=angleproject:913
Change-Id: Ice8fe3682d8e97f42747752302a1fba116132df4
Reviewed-on: https://chromium-review.googlesource.com/266843
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Tested-by: Olli Etuaho <oetuaho@nvidia.com>
Reviewed-by: Zhenyao Mo <zmo@chromium.org>
diff --git a/src/compiler/translator/IntermNode.cpp b/src/compiler/translator/IntermNode.cpp
index 67ec256..4d6da5d 100644
--- a/src/compiler/translator/IntermNode.cpp
+++ b/src/compiler/translator/IntermNode.cpp
@@ -666,7 +666,7 @@
// Returns the node to keep using, which may or may not be the node passed in.
//
TIntermTyped *TIntermConstantUnion::fold(
- TOperator op, TIntermTyped *constantNode, TInfoSink &infoSink)
+ TOperator op, TIntermConstantUnion *rightNode, TInfoSink &infoSink)
{
ConstantUnion *unionArray = getUnionArrayPointer();
@@ -675,39 +675,38 @@
size_t objectSize = getType().getObjectSize();
- if (constantNode)
+ if (rightNode)
{
// binary operations
- TIntermConstantUnion *node = constantNode->getAsConstantUnion();
- ConstantUnion *rightUnionArray = node->getUnionArrayPointer();
+ ConstantUnion *rightUnionArray = rightNode->getUnionArrayPointer();
TType returnType = getType();
if (!rightUnionArray)
return nullptr;
- // for a case like float f = 1.2 + vec4(2,3,4,5);
- if (constantNode->getType().getObjectSize() == 1 && objectSize > 1)
+ // for a case like vec4 v = vec4(2, 3, 4, 5) + 1.2;
+ if (rightNode->getType().getObjectSize() == 1 && objectSize > 1)
{
rightUnionArray = new ConstantUnion[objectSize];
for (size_t i = 0; i < objectSize; ++i)
{
- rightUnionArray[i] = *node->getUnionArrayPointer();
+ rightUnionArray[i] = *rightNode->getUnionArrayPointer();
}
returnType = getType();
}
- else if (constantNode->getType().getObjectSize() > 1 && objectSize == 1)
+ else if (rightNode->getType().getObjectSize() > 1 && objectSize == 1)
{
- // for a case like float f = vec4(2,3,4,5) + 1.2;
- unionArray = new ConstantUnion[constantNode->getType().getObjectSize()];
- for (size_t i = 0; i < constantNode->getType().getObjectSize(); ++i)
+ // for a case like vec4 v = 1.2 + vec4(2, 3, 4, 5);
+ unionArray = new ConstantUnion[rightNode->getType().getObjectSize()];
+ for (size_t i = 0; i < rightNode->getType().getObjectSize(); ++i)
{
unionArray[i] = *getUnionArrayPointer();
}
- returnType = node->getType();
- objectSize = constantNode->getType().getObjectSize();
+ returnType = rightNode->getType();
+ objectSize = rightNode->getType().getObjectSize();
}
- ConstantUnion *tempConstArray = NULL;
+ ConstantUnion *tempConstArray = nullptr;
TIntermConstantUnion *tempNode;
bool boolNodeFlag = false;
@@ -735,7 +734,7 @@
case EOpMatrixTimesMatrix:
{
if (getType().getBasicType() != EbtFloat ||
- node->getBasicType() != EbtFloat)
+ rightNode->getBasicType() != EbtFloat)
{
infoSink.info.message(
EPrefixInternalError, getLine(),
@@ -745,12 +744,12 @@
const int leftCols = getCols();
const int leftRows = getRows();
- const int rightCols = constantNode->getType().getCols();
- const int rightRows = constantNode->getType().getRows();
+ const int rightCols = rightNode->getType().getCols();
+ const int rightRows = rightNode->getType().getRows();
const int resultCols = rightCols;
const int resultRows = leftRows;
- tempConstArray = new ConstantUnion[resultCols*resultRows];
+ tempConstArray = new ConstantUnion[resultCols * resultRows];
for (int row = 0; row < resultRows; row++)
{
for (int column = 0; column < resultCols; column++)
@@ -862,7 +861,7 @@
case EOpMatrixTimesVector:
{
- if (node->getBasicType() != EbtFloat)
+ if (rightNode->getBasicType() != EbtFloat)
{
infoSink.info.message(
EPrefixInternalError, getLine(),
@@ -887,7 +886,7 @@
}
}
- returnType = node->getType();
+ returnType = rightNode->getType();
returnType.setPrimarySize(static_cast<unsigned char>(matrixRows));
tempNode = new TIntermConstantUnion(tempConstArray, returnType);
@@ -906,8 +905,8 @@
return nullptr;
}
- const int matrixCols = constantNode->getType().getCols();
- const int matrixRows = constantNode->getType().getRows();
+ const int matrixCols = rightNode->getType().getCols();
+ const int matrixRows = rightNode->getType().getRows();
tempConstArray = new ConstantUnion[matrixCols];
@@ -1035,8 +1034,8 @@
case EOpEqual:
if (getType().getBasicType() == EbtStruct)
{
- if (!CompareStructure(node->getType(),
- node->getUnionArrayPointer(),
+ if (!CompareStructure(rightNode->getType(),
+ rightNode->getUnionArrayPointer(),
unionArray))
{
boolNodeFlag = true;
@@ -1073,8 +1072,8 @@
case EOpNotEqual:
if (getType().getBasicType() == EbtStruct)
{
- if (CompareStructure(node->getType(),
- node->getUnionArrayPointer(),
+ if (CompareStructure(rightNode->getType(),
+ rightNode->getUnionArrayPointer(),
unionArray))
{
boolNodeFlag = true;
diff --git a/src/compiler/translator/IntermNode.h b/src/compiler/translator/IntermNode.h
index 81231c0..8fd8f45 100644
--- a/src/compiler/translator/IntermNode.h
+++ b/src/compiler/translator/IntermNode.h
@@ -293,7 +293,7 @@
virtual void traverse(TIntermTraverser *);
virtual bool replaceChildNode(TIntermNode *, TIntermNode *) { return false; }
- TIntermTyped *fold(TOperator, TIntermTyped *, TInfoSink &);
+ TIntermTyped *fold(TOperator op, TIntermConstantUnion *rightNode, TInfoSink &infoSink);
protected:
ConstantUnion *mUnionArrayPointer;
diff --git a/src/compiler/translator/Intermediate.cpp b/src/compiler/translator/Intermediate.cpp
index ecc6915..ebe6c39 100644
--- a/src/compiler/translator/Intermediate.cpp
+++ b/src/compiler/translator/Intermediate.cpp
@@ -143,7 +143,7 @@
if (childTempConstant)
{
- TIntermTyped *newChild = childTempConstant->fold(op, 0, mInfoSink);
+ TIntermTyped *newChild = childTempConstant->fold(op, nullptr, mInfoSink);
if (newChild)
return newChild;