[MLIR] Support for vectorizing operations.
This CL adds support for vectorizing operations, together with a vectorization
test that vectorizes a scalar 2-D addf.
Notably, the extension comprises:
1. extend the vectorizable test to exclude vector_transfer operations and
expose them to LoopAnalysis where they are needed. This is a temporary
solution until a concrete MLIR Op exists;
2. add more functional sugar: mapKeys, apply and ScopeGuard (ScopeGuard
became relevant again);
3. fix improper shifting during coarsening;
4. rename unaligned load/store to vector_transfer_read/write and simplify the
design by removing the unnecessary AllocOps that were introduced prematurely.
   vector_transfer_read currently has the form:
     (memref<?x?x?xf32>, index, index, index) -> vector<32x64x256xf32>
   vector_transfer_write currently has the form:
     (vector<32x64x256xf32>, memref<?x?x?xf32>, index, index, index) -> ()
5. add vectorizeOperations, which traverses the operations in a ForStmt and
rewrites them into their vector form;
6. add support for vector splat from a constant.
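Item 1 can be pictured with a standalone sketch of the name-based filter: a loop body that contains any vector_transfer op is rejected before the load/store checks run. This is not MLIR code. Loop bodies are modeled as flat lists of op names, and the strings "vector_transfer_read"/"vector_transfer_write" are assumed values for kVectorTransferReadOpName and kVectorTransferWriteOpName, whose definitions are not part of this hunk. isTransferOpName and isBodyVectorizable are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Assumed op-name strings standing in for kVectorTransferReadOpName and
// kVectorTransferWriteOpName (their real values are not shown in this CL).
const std::string kAssumedReadName = "vector_transfer_read";
const std::string kAssumedWriteName = "vector_transfer_write";

// Hypothetical name-based check: a stopgap until the transfers become real
// op classes that can be queried with isa<>.
bool isTransferOpName(const std::string &opName) {
  return opName == kAssumedReadName || opName == kAssumedWriteName;
}

// Toy loop body: a flat list of op names (real bodies hold nested
// statements). Any transfer op makes the body non-vectorizable, mirroring
// the early `return false` added to the vectorizable-loop check.
bool isBodyVectorizable(const std::vector<std::string> &bodyOps) {
  return std::none_of(bodyOps.begin(), bodyOps.end(), isTransferOpName);
}
```

A body such as {"load", "addf", "store"} passes, while one containing "vector_transfer_write" does not, which keeps already-vectorized loops from being vectorized a second time.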
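For readers unfamiliar with the helpers named in item 2, here is a minimal sketch of what mapKeys, apply and a ScopeGuard typically look like. The signatures are assumptions for illustration and are not the ones in mlir/Support/Functional.h. The helpers live in a hypothetical namespace functional_sketch, and callers should qualify apply so argument-dependent lookup does not pick up std::apply.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <utility>
#include <vector>

namespace functional_sketch {

// Collect the keys of an associative container, in iteration order.
template <typename Map>
std::vector<typename Map::key_type> mapKeys(const Map &m) {
  std::vector<typename Map::key_type> keys;
  keys.reserve(m.size());
  for (const auto &kv : m)
    keys.push_back(kv.first);
  return keys;
}

// Call a side-effecting function on every element of a range.
template <typename Fun, typename Range>
void apply(Fun fun, const Range &range) {
  for (const auto &elem : range)
    fun(elem);
}

// Run a cleanup callback when the guard leaves scope, unless dismissed.
class ScopeGuard {
public:
  explicit ScopeGuard(std::function<void()> fn) : cleanup(std::move(fn)) {}
  ScopeGuard(const ScopeGuard &) = delete;
  ScopeGuard &operator=(const ScopeGuard &) = delete;
  ~ScopeGuard() {
    if (cleanup)
      cleanup();
  }
  void dismiss() { cleanup = nullptr; }

private:
  std::function<void()> cleanup;
};

} // namespace functional_sketch
```

The guard is the RAII idiom: state modified speculatively during a rewrite can be restored on every exit path, and dismiss() keeps the change once it is known to be valid.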
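The vector_transfer_read/write forms in item 4 take a memref plus one index per dimension. The sketch below shows one plausible reading: the transfer moves a contiguous block whose first element sits at those indices. This is a plain C++ model, not the op's implementation. It assumes a row-major buffer and in-bounds accesses with no masking or padding. MemRef3D, transferRead and transferWrite are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for memref<?x?x?xf32>: runtime sizes, row-major storage.
struct MemRef3D {
  std::array<std::size_t, 3> sizes;
  std::vector<float> data;
  float &at(std::size_t i, std::size_t j, std::size_t k) {
    return data[(i * sizes[1] + j) * sizes[2] + k];
  }
  float at(std::size_t i, std::size_t j, std::size_t k) const {
    return data[(i * sizes[1] + j) * sizes[2] + k];
  }
};

// Read a V0xV1xV2 block starting at (i, j, k). Bounds are assumed valid.
template <std::size_t V0, std::size_t V1, std::size_t V2>
std::array<float, V0 * V1 * V2> transferRead(const MemRef3D &m, std::size_t i,
                                             std::size_t j, std::size_t k) {
  std::array<float, V0 * V1 * V2> v{};
  for (std::size_t a = 0; a < V0; ++a)
    for (std::size_t b = 0; b < V1; ++b)
      for (std::size_t c = 0; c < V2; ++c)
        v[(a * V1 + b) * V2 + c] = m.at(i + a, j + b, k + c);
  return v;
}

// Write a V0xV1xV2 block back starting at (i, j, k); returns nothing, like
// the "-> ()" form. The shape must be given explicitly at the call site.
template <std::size_t V0, std::size_t V1, std::size_t V2>
void transferWrite(const std::array<float, V0 * V1 * V2> &v, MemRef3D &m,
                   std::size_t i, std::size_t j, std::size_t k) {
  for (std::size_t a = 0; a < V0; ++a)
    for (std::size_t b = 0; b < V1; ++b)
      for (std::size_t c = 0; c < V2; ++c)
        m.at(i + a, j + b, k + c) = v[(a * V1 + b) * V2 + c];
}
```

With a 3x3x3 buffer holding 0..26, transferRead<2, 2, 2>(m, 1, 1, 1) starts at element 13 and ends at element 26. Because the vector shape is static and the memref shape is dynamic, no AllocOp is needed to stage the data.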
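Items 5 and 6 can be illustrated by the semantics the rewrite must preserve. Each scalar op in the loop body is replaced by a lane-wise counterpart, and a scalar constant operand is splatted once into a vector. The sketch below is a plain C++ analogy rather than the MLIR transformation. The width kWidth = 4 and the helper names (splat, loadVec, addfVec, scalarLoop, vectorLoop) are assumptions for illustration, and the trip count is assumed to be a multiple of the width.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Vector width chosen for illustration only.
constexpr std::size_t kWidth = 4;
using Vec = std::array<float, kWidth>;

// Splat: broadcast a scalar constant into every lane.
Vec splat(float c) {
  Vec v;
  v.fill(c);
  return v;
}

// Load kWidth consecutive scalars starting at offset i.
Vec loadVec(const std::vector<float> &src, std::size_t i) {
  Vec v;
  for (std::size_t l = 0; l < kWidth; ++l)
    v[l] = src[i + l];
  return v;
}

// Lane-wise counterpart of the scalar addf.
Vec addfVec(const Vec &x, const Vec &y) {
  Vec r;
  for (std::size_t l = 0; l < kWidth; ++l)
    r[l] = x[l] + y[l];
  return r;
}

// Scalar reference loop: c[i] = (a[i] + b[i]) + k.
void scalarLoop(const std::vector<float> &a, const std::vector<float> &b,
                float k, std::vector<float> &c) {
  for (std::size_t i = 0; i < c.size(); ++i)
    c[i] = (a[i] + b[i]) + k;
}

// Vectorized loop: same op order, each op in vector form, k splatted once.
// Assumes c.size() is a multiple of kWidth.
void vectorLoop(const std::vector<float> &a, const std::vector<float> &b,
                float k, std::vector<float> &c) {
  const Vec vk = splat(k);
  for (std::size_t i = 0; i < c.size(); i += kWidth) {
    Vec r = addfVec(addfVec(loadVec(a, i), loadVec(b, i)), vk);
    for (std::size_t l = 0; l < kWidth; ++l)
      c[i + l] = r[l];
  }
}
```

Keeping the scalar evaluation order, (a + b) + k, makes the two loops agree bit for bit. Splatting the constant outside the loop mirrors the idea that a constant needs only one vector materialization.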
The relevant tests are also updated.
PiperOrigin-RevId: 221421426
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 3455783..78a8e2d 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -30,6 +30,7 @@
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/SmallString.h"
using namespace mlir;
@@ -189,6 +190,14 @@
return memRefType.getElementType().template isa<VectorType>();
}
+// TODO(ntv): make the following into MLIR instructions, then use isa<>.
+static bool isVectorTransferReadOrWrite(const Statement &stmt) {
+ const auto *opStmt = cast<OperationStmt>(&stmt);
+ llvm::SmallString<16> name(opStmt->getName().getStringRef());
+ return name == kVectorTransferReadOpName ||
+ name == kVectorTransferWriteOpName;
+}
+
using VectorizableStmtFun =
std::function<bool(const ForStmt &, const OperationStmt &)>;
@@ -206,6 +215,12 @@
return false;
}
+ auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
+ auto vectorTransfersMatched = vectorTransfers.match(forStmt);
+ if (!vectorTransfersMatched.empty()) {
+ return false;
+ }
+
auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
auto loadAndStoresMatched = loadAndStores.match(forStmt);
for (auto ls : loadAndStoresMatched) {