Add support for constant-folding non-square matrix multiplications (matrix*matrix, matrix*vector and vector*matrix).
TRAC #23081
Signed-off-by: Geoff Lang
Signed-off-by: Nicolas Capens
Author: Jamie Madill
git-svn-id: https://angleproject.googlecode.com/svn/branches/es3proto@2396 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/compiler/Intermediate.cpp b/src/compiler/Intermediate.cpp
index b81ed4f..adf35dd 100644
--- a/src/compiler/Intermediate.cpp
+++ b/src/compiler/Intermediate.cpp
@@ -1248,17 +1248,29 @@
return 0;
}
- int cols = getCols();
- int rows = getRows();
- tempConstArray = new ConstantUnion[cols*rows];
- for (int row = 0; row < rows; row++) {
- for (int column = 0; column < cols; column++) {
- tempConstArray[rows * column + row].setFConst(0.0f);
- for (int i = 0; i < cols; i++) {
- tempConstArray[rows * column + row].setFConst(tempConstArray[rows * column + row].getFConst() + unionArray[i * rows + row].getFConst() * (rightUnionArray[column * rows + i].getFConst()));
+ const int leftCols = getCols();
+ const int leftRows = getRows();
+ const int rightCols = constantNode->getType().getCols();
+ const int rightRows = constantNode->getType().getRows();
+ const int resultCols = rightCols;
+ const int resultRows = leftRows;
+
+ tempConstArray = new ConstantUnion[resultCols*resultRows];
+ for (int row = 0; row < resultRows; row++)
+ {
+ for (int column = 0; column < resultCols; column++)
+ {
+ tempConstArray[resultRows * column + row].setFConst(0.0f);
+ for (int i = 0; i < leftCols; i++)
+ {
+ tempConstArray[resultRows * column + row].setFConst(tempConstArray[resultRows * column + row].getFConst() + unionArray[i * leftRows + row].getFConst() * (rightUnionArray[column * rightRows + i].getFConst()));
}
}
}
+
+ // update return type for matrix product
+ returnType.setPrimarySize(resultCols);
+ returnType.setSecondarySize(resultRows);
}
break;
@@ -1308,16 +1320,25 @@
infoSink.info.message(EPrefixInternalError, "Constant Folding cannot be done for matrix times vector", getLine());
return 0;
}
- tempConstArray = new ConstantUnion[getNominalSize()];
- for (int size = getNominalSize(), i = 0; i < size; i++) {
- tempConstArray[i].setFConst(0.0f);
- for (int j = 0; j < size; j++) {
- tempConstArray[i].setFConst(tempConstArray[i].getFConst() + ((unionArray[j*size + i].getFConst()) * rightUnionArray[j].getFConst()));
+ const int matrixCols = getCols();
+ const int matrixRows = getRows();
+
+ tempConstArray = new ConstantUnion[matrixRows];
+
+ for (int matrixRow = 0; matrixRow < matrixRows; matrixRow++)
+ {
+ tempConstArray[matrixRow].setFConst(0.0f);
+ for (int col = 0; col < matrixCols; col++)
+ {
+ tempConstArray[matrixRow].setFConst(tempConstArray[matrixRow].getFConst() + ((unionArray[col * matrixRows + matrixRow].getFConst()) * rightUnionArray[col].getFConst()));
}
}
- tempNode = new TIntermConstantUnion(tempConstArray, node->getType());
+ returnType = node->getType();
+ returnType.setPrimarySize(matrixRows);
+
+ tempNode = new TIntermConstantUnion(tempConstArray, returnType);
tempNode->setLine(getLine());
return tempNode;
@@ -1331,15 +1352,21 @@
return 0;
}
- tempConstArray = new ConstantUnion[getNominalSize()];
- for (int size = getNominalSize(), i = 0; i < size; i++)
+ const int matrixCols = constantNode->getType().getCols();
+ const int matrixRows = constantNode->getType().getRows();
+
+ tempConstArray = new ConstantUnion[matrixCols];
+
+ for (int matrixCol = 0; matrixCol < matrixCols; matrixCol++)
{
- tempConstArray[i].setFConst(0.0f);
- for (int j = 0; j < size; j++)
+ tempConstArray[matrixCol].setFConst(0.0f);
+ for (int matrixRow = 0; matrixRow < matrixRows; matrixRow++)
{
- tempConstArray[i].setFConst(tempConstArray[i].getFConst() + ((unionArray[j].getFConst()) * rightUnionArray[i*size + j].getFConst()));
+ tempConstArray[matrixCol].setFConst(tempConstArray[matrixCol].getFConst() + ((unionArray[matrixRow].getFConst()) * rightUnionArray[matrixCol * matrixRows + matrixRow].getFConst()));
}
}
+
+ returnType.setPrimarySize(matrixCols);
}
break;