[MLIR] Extend vectorization to 2+-D patterns
This CL adds support for vectorization using more interesting 2-D and 3-D
patterns. Note in particular that we match some fairly complex,
imperfectly nested 2-D patterns with a minimal change to the
implementation: we just add a bit of recursion to traverse the matched
patterns and actually vectorize the loops.
For instance, vectorizing the following loop by 128:
```
for %i3 = 0 to %0 {
  %7 = affine_apply (d0) -> (d0)(%i3)
  %8 = load %arg0[%c0_0, %7] : memref<?x?xf32>
}
```
Currently generates:
```
#map0 = ()[s0] -> (s0 + 127)
#map1 = (d0) -> (d0)
for %i3 = 0 to #map0()[%0] step 128 {
  %9 = affine_apply #map1(%i3)
  %10 = alloc() : memref<1xvector<128xf32>>
  %11 = "n_d_unaligned_load"(%arg0, %c0_0, %9, %10, %c0) :
      (memref<?x?xf32>, index, index, memref<1xvector<128xf32>>, index) ->
      (memref<?x?xf32>, index, index, memref<1xvector<128xf32>>, index)
  %12 = load %10[%c0] : memref<1xvector<128xf32>>
}
```
The above is subject to evolution.
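For the curious, the "bit of recursion" amounts to walking the tree of
matched patterns and rewriting each matched loop. Below is a minimal C++
sketch of the idea; the names (`MatchedPattern`, `vectorizeSingleLoop`) are
hypothetical, and the actual traversal and rewriting live in the vectorizer
pass:
```
#include <vector>

// Hypothetical sketch of the recursion described above; the names here
// are illustrative only, not the actual vectorizer API.
struct ForStmt;  // opaque handle to a loop statement

struct MatchedPattern {
  ForStmt *loop;                       // loop matched at this nesting level
  std::vector<MatchedPattern> nested;  // imperfectly nested sub-matches
};

// Assumed helper that rewrites a single loop in place (step adjustment
// plus rewriting contiguous loads/stores into pseudo vector transfers).
void vectorizeSingleLoop(ForStmt *loop, unsigned vectorSize);

// Recurse over the matched patterns, vectorizing each matched loop. The
// inner-first order is one plausible choice, not a claim about the
// actual implementation.
void vectorizeMatchedPattern(const MatchedPattern &match,
                             unsigned vectorSize) {
  for (const MatchedPattern &inner : match.nested)
    vectorizeMatchedPattern(inner, vectorSize);
  vectorizeSingleLoop(match.loop, vectorSize);
}
```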
PiperOrigin-RevId: 219629745
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 1904a63..ce0bc6c 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -119,12 +119,12 @@
}
bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType,
- ArrayRef<MLValue *> indices, unsigned dim) {
+ ArrayRef<const MLValue *> indices, unsigned dim) {
assert(indices.size() == memRefType.getRank());
assert(dim < indices.size());
auto layoutMap = memRefType.getAffineMaps();
assert(memRefType.getAffineMaps().size() <= 1);
- // TODO(ntv): remove dependency on Builder once we support non-identity
+ // TODO(ntv): remove dependence on Builder once we support non-identity
// layout map.
Builder b(memRefType.getContext());
assert(layoutMap.empty() ||
@@ -132,7 +132,8 @@
(void)layoutMap;
SmallVector<OperationStmt *, 4> affineApplyOps;
- getReachableAffineApplyOps({indices[dim]}, affineApplyOps);
+ getReachableAffineApplyOps({const_cast<MLValue *>(indices[dim])},
+ affineApplyOps);
if (affineApplyOps.empty()) {
// Pointer equality test because of MLValue pointer semantics.
@@ -168,7 +169,7 @@
LoadOrStoreOpPointer memoryOp,
unsigned fastestVaryingDim) {
using namespace functional;
- auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); },
+ auto indices = map([](const SSAValue *val) { return dyn_cast<MLValue>(val); },
memoryOp->getIndices());
auto memRefType = memoryOp->getMemRefType();
for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) {
@@ -188,7 +189,11 @@
return memRefType.getElementType().template isa<VectorType>();
}
-bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {
+using VectorizableStmtFun =
+ std::function<bool(const ForStmt &, const OperationStmt &)>;
+
+static bool isVectorizableLoopWithCond(const ForStmt &loop,
+ VectorizableStmtFun isVectorizableStmt) {
if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) {
return false;
}
@@ -214,15 +219,32 @@
if (vector) {
return false;
}
- bool contiguous = load ? isContiguousAccess(loop, load, fastestVaryingDim)
- : isContiguousAccess(loop, store, fastestVaryingDim);
- if (!contiguous) {
+ if (!isVectorizableStmt(loop, *op)) {
return false;
}
}
return true;
}
+bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
+ const ForStmt &loop, unsigned fastestVaryingDim) {
+ VectorizableStmtFun fun(
+ [fastestVaryingDim](const ForStmt &loop, const OperationStmt &op) {
+ auto load = op.dyn_cast<LoadOp>();
+ auto store = op.dyn_cast<StoreOp>();
+ return load ? isContiguousAccess(loop, load, fastestVaryingDim)
+ : isContiguousAccess(loop, store, fastestVaryingDim);
+ });
+ return isVectorizableLoopWithCond(loop, fun);
+}
+
+bool mlir::isVectorizableLoop(const ForStmt &loop) {
+ VectorizableStmtFun fun(
+ // TODO: implement me
+ [](const ForStmt &loop, const OperationStmt &op) { return true; });
+ return isVectorizableLoopWithCond(loop, fun);
+}
+
/// Checks whether SSA dominance would be violated if a for stmt's body
/// statements are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
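For reference, a hedged caller-side sketch contrasting the two entry points
left by this refactoring (assumes the analysis header from this CL is
included and that `loop` comes from an enclosing pass; `inspectLoop` is a
hypothetical name):
```
// Sketch only: contrasts the generic check with the dimension-specific one.
void inspectLoop(const ForStmt &loop) {
  // Generic legality: a parallel or reduction loop with no conditionals and
  // no operations on vector types. The per-statement condition is still a
  // TODO in this CL.
  bool generic = mlir::isVectorizableLoop(loop);

  // Stricter: additionally requires contiguous access when vectorizing
  // along the given fastest varying memref dimension (0 here).
  bool contiguous = mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
      loop, /*fastestVaryingDim=*/0);
  (void)generic;
  (void)contiguous;
}
```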