Add support for constant folding of non-square matrix multiplications (matrix * matrix, matrix * vector, and vector * matrix).

TRAC #23081

Signed-off-by: Geoff Lang
Signed-off-by: Nicolas Capens
Author: Jamie Madill

git-svn-id: https://angleproject.googlecode.com/svn/branches/es3proto@2396 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/compiler/Intermediate.cpp b/src/compiler/Intermediate.cpp
index b81ed4f..adf35dd 100644
--- a/src/compiler/Intermediate.cpp
+++ b/src/compiler/Intermediate.cpp
@@ -1248,17 +1248,29 @@
                     return 0;
                 }
 
-                int cols = getCols();
-                int rows = getRows();
-                tempConstArray = new ConstantUnion[cols*rows];
-                for (int row = 0; row < rows; row++) {
-                    for (int column = 0; column < cols; column++) {
-                        tempConstArray[rows * column + row].setFConst(0.0f);
-                        for (int i = 0; i < cols; i++) {
-                            tempConstArray[rows * column + row].setFConst(tempConstArray[rows * column + row].getFConst() + unionArray[i * rows + row].getFConst() * (rightUnionArray[column * rows + i].getFConst()));
+                const int leftCols = getCols();
+                const int leftRows = getRows();
+                const int rightCols = constantNode->getType().getCols();
+                const int rightRows = constantNode->getType().getRows();
+                const int resultCols = rightCols;
+                const int resultRows = leftRows;
+
+                tempConstArray = new ConstantUnion[resultCols*resultRows];
+                for (int row = 0; row < resultRows; row++)
+                {
+                    for (int column = 0; column < resultCols; column++)
+                    {
+                        tempConstArray[resultRows * column + row].setFConst(0.0f);
+                        for (int i = 0; i < leftCols; i++)
+                        {
+                            tempConstArray[resultRows * column + row].setFConst(tempConstArray[resultRows * column + row].getFConst() + unionArray[i * leftRows + row].getFConst() * (rightUnionArray[column * rightRows + i].getFConst()));
                         }
                     }
                 }
+
+                // Update the return type for the matrix product: the result has rightCols columns and leftRows rows.
+                returnType.setPrimarySize(resultCols);
+                returnType.setSecondarySize(resultRows);
             }
             break;
 
@@ -1308,16 +1320,25 @@
                     infoSink.info.message(EPrefixInternalError, "Constant Folding cannot be done for matrix times vector", getLine());
                     return 0;
                 }
-                tempConstArray = new ConstantUnion[getNominalSize()];
 
-                for (int size = getNominalSize(), i = 0; i < size; i++) {
-                    tempConstArray[i].setFConst(0.0f);
-                    for (int j = 0; j < size; j++) {
-                        tempConstArray[i].setFConst(tempConstArray[i].getFConst() + ((unionArray[j*size + i].getFConst()) * rightUnionArray[j].getFConst()));
+                const int matrixCols = getCols();
+                const int matrixRows = getRows();
+
+                tempConstArray = new ConstantUnion[matrixRows];
+
+                for (int matrixRow = 0; matrixRow < matrixRows; matrixRow++)
+                {
+                    tempConstArray[matrixRow].setFConst(0.0f);
+                    for (int col = 0; col < matrixCols; col++)
+                    {
+                        tempConstArray[matrixRow].setFConst(tempConstArray[matrixRow].getFConst() + ((unionArray[col * matrixRows + matrixRow].getFConst()) * rightUnionArray[col].getFConst()));
                     }
                 }
 
-                tempNode = new TIntermConstantUnion(tempConstArray, node->getType());
+                returnType = node->getType();
+                returnType.setPrimarySize(matrixRows);
+
+                tempNode = new TIntermConstantUnion(tempConstArray, returnType);
                 tempNode->setLine(getLine());
 
                 return tempNode;
@@ -1331,15 +1352,21 @@
                     return 0;
                 }
 
-                tempConstArray = new ConstantUnion[getNominalSize()];
-                for (int size = getNominalSize(), i = 0; i < size; i++)
+                const int matrixCols = constantNode->getType().getCols();
+                const int matrixRows = constantNode->getType().getRows();
+
+                tempConstArray = new ConstantUnion[matrixCols];
+
+                for (int matrixCol = 0; matrixCol < matrixCols; matrixCol++)
                 {
-                    tempConstArray[i].setFConst(0.0f);
-                    for (int j = 0; j < size; j++)
+                    tempConstArray[matrixCol].setFConst(0.0f);
+                    for (int matrixRow = 0; matrixRow < matrixRows; matrixRow++)
                     {
-                        tempConstArray[i].setFConst(tempConstArray[i].getFConst() + ((unionArray[j].getFConst()) * rightUnionArray[i*size + j].getFConst()));
+                        tempConstArray[matrixCol].setFConst(tempConstArray[matrixCol].getFConst() + ((unionArray[matrixRow].getFConst()) * rightUnionArray[matrixCol * matrixRows + matrixRow].getFConst()));
                     }
                 }
+
+                returnType.setPrimarySize(matrixCols);
             }
             break;