Clean up binary operation constant folding code

Fix mixed-up comments, remove unnecessary type conversions, clarify
variable names, and improve formatting in a few places.

TEST=angle_unittests, WebGL conformance tests
BUG=angleproject:913

Change-Id: Ice8fe3682d8e97f42747752302a1fba116132df4
Reviewed-on: https://chromium-review.googlesource.com/266843
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Tested-by: Olli Etuaho <oetuaho@nvidia.com>
Reviewed-by: Zhenyao Mo <zmo@chromium.org>
diff --git a/src/compiler/translator/IntermNode.cpp b/src/compiler/translator/IntermNode.cpp
index 67ec256..4d6da5d 100644
--- a/src/compiler/translator/IntermNode.cpp
+++ b/src/compiler/translator/IntermNode.cpp
@@ -666,7 +666,7 @@
 // Returns the node to keep using, which may or may not be the node passed in.
 //
 TIntermTyped *TIntermConstantUnion::fold(
-    TOperator op, TIntermTyped *constantNode, TInfoSink &infoSink)
+    TOperator op, TIntermConstantUnion *rightNode, TInfoSink &infoSink)
 {
     ConstantUnion *unionArray = getUnionArrayPointer();
 
@@ -675,39 +675,38 @@
 
     size_t objectSize = getType().getObjectSize();
 
-    if (constantNode)
+    if (rightNode)
     {
         // binary operations
-        TIntermConstantUnion *node = constantNode->getAsConstantUnion();
-        ConstantUnion *rightUnionArray = node->getUnionArrayPointer();
+        ConstantUnion *rightUnionArray = rightNode->getUnionArrayPointer();
         TType returnType = getType();
 
         if (!rightUnionArray)
             return nullptr;
 
-        // for a case like float f = 1.2 + vec4(2,3,4,5);
-        if (constantNode->getType().getObjectSize() == 1 && objectSize > 1)
+        // for a case like float f = vec4(2, 3, 4, 5) + 1.2;
+        if (rightNode->getType().getObjectSize() == 1 && objectSize > 1)
         {
             rightUnionArray = new ConstantUnion[objectSize];
             for (size_t i = 0; i < objectSize; ++i)
             {
-                rightUnionArray[i] = *node->getUnionArrayPointer();
+                rightUnionArray[i] = *rightNode->getUnionArrayPointer();
             }
             returnType = getType();
         }
-        else if (constantNode->getType().getObjectSize() > 1 && objectSize == 1)
+        else if (rightNode->getType().getObjectSize() > 1 && objectSize == 1)
         {
-            // for a case like float f = vec4(2,3,4,5) + 1.2;
-            unionArray = new ConstantUnion[constantNode->getType().getObjectSize()];
-            for (size_t i = 0; i < constantNode->getType().getObjectSize(); ++i)
+            // for a case like float f = 1.2 + vec4(2, 3, 4, 5);
+            unionArray = new ConstantUnion[rightNode->getType().getObjectSize()];
+            for (size_t i = 0; i < rightNode->getType().getObjectSize(); ++i)
             {
                 unionArray[i] = *getUnionArrayPointer();
             }
-            returnType = node->getType();
-            objectSize = constantNode->getType().getObjectSize();
+            returnType = rightNode->getType();
+            objectSize = rightNode->getType().getObjectSize();
         }
 
-        ConstantUnion *tempConstArray = NULL;
+        ConstantUnion *tempConstArray = nullptr;
         TIntermConstantUnion *tempNode;
 
         bool boolNodeFlag = false;
@@ -735,7 +734,7 @@
           case EOpMatrixTimesMatrix:
             {
                 if (getType().getBasicType() != EbtFloat ||
-                    node->getBasicType() != EbtFloat)
+                    rightNode->getBasicType() != EbtFloat)
                 {
                     infoSink.info.message(
                         EPrefixInternalError, getLine(),
@@ -745,12 +744,12 @@
 
                 const int leftCols = getCols();
                 const int leftRows = getRows();
-                const int rightCols = constantNode->getType().getCols();
-                const int rightRows = constantNode->getType().getRows();
+                const int rightCols = rightNode->getType().getCols();
+                const int rightRows = rightNode->getType().getRows();
                 const int resultCols = rightCols;
                 const int resultRows = leftRows;
 
-                tempConstArray = new ConstantUnion[resultCols*resultRows];
+                tempConstArray = new ConstantUnion[resultCols * resultRows];
                 for (int row = 0; row < resultRows; row++)
                 {
                     for (int column = 0; column < resultCols; column++)
@@ -862,7 +861,7 @@
 
           case EOpMatrixTimesVector:
             {
-                if (node->getBasicType() != EbtFloat)
+                if (rightNode->getBasicType() != EbtFloat)
                 {
                     infoSink.info.message(
                         EPrefixInternalError, getLine(),
@@ -887,7 +886,7 @@
                     }
                 }
 
-                returnType = node->getType();
+                returnType = rightNode->getType();
                 returnType.setPrimarySize(static_cast<unsigned char>(matrixRows));
 
                 tempNode = new TIntermConstantUnion(tempConstArray, returnType);
@@ -906,8 +905,8 @@
                     return nullptr;
                 }
 
-                const int matrixCols = constantNode->getType().getCols();
-                const int matrixRows = constantNode->getType().getRows();
+                const int matrixCols = rightNode->getType().getCols();
+                const int matrixRows = rightNode->getType().getRows();
 
                 tempConstArray = new ConstantUnion[matrixCols];
 
@@ -1035,8 +1034,8 @@
           case EOpEqual:
             if (getType().getBasicType() == EbtStruct)
             {
-                if (!CompareStructure(node->getType(),
-                                      node->getUnionArrayPointer(),
+                if (!CompareStructure(rightNode->getType(),
+                                      rightNode->getUnionArrayPointer(),
                                       unionArray))
                 {
                     boolNodeFlag = true;
@@ -1073,8 +1072,8 @@
           case EOpNotEqual:
             if (getType().getBasicType() == EbtStruct)
             {
-                if (CompareStructure(node->getType(),
-                                     node->getUnionArrayPointer(),
+                if (CompareStructure(rightNode->getType(),
+                                     rightNode->getUnionArrayPointer(),
                                      unionArray))
                 {
                     boolNodeFlag = true;
diff --git a/src/compiler/translator/IntermNode.h b/src/compiler/translator/IntermNode.h
index 81231c0..8fd8f45 100644
--- a/src/compiler/translator/IntermNode.h
+++ b/src/compiler/translator/IntermNode.h
@@ -293,7 +293,7 @@
     virtual void traverse(TIntermTraverser *);
     virtual bool replaceChildNode(TIntermNode *, TIntermNode *) { return false; }
 
-    TIntermTyped *fold(TOperator, TIntermTyped *, TInfoSink &);
+    TIntermTyped *fold(TOperator op, TIntermConstantUnion *rightNode, TInfoSink &infoSink);
 
   protected:
     ConstantUnion *mUnionArrayPointer;
diff --git a/src/compiler/translator/Intermediate.cpp b/src/compiler/translator/Intermediate.cpp
index ecc6915..ebe6c39 100644
--- a/src/compiler/translator/Intermediate.cpp
+++ b/src/compiler/translator/Intermediate.cpp
@@ -143,7 +143,7 @@
 
     if (childTempConstant)
     {
-        TIntermTyped *newChild = childTempConstant->fold(op, 0, mInfoSink);
+        TIntermTyped *newChild = childTempConstant->fold(op, nullptr, mInfoSink);
 
         if (newChild)
             return newChild;