[MLIR] Basic infrastructure for vectorization test
This CL implements a very simple loop vectorization **test** and the basic
infrastructure to support it.
The test consists of two steps:
1. matching the loops in the MLFunction and all the Load/Store operations
nested under each loop;
2. testing whether all the Load/Store operations are contiguous along the
innermost memory dimension with respect to that particular loop. If any
reference is non-contiguous (i.e. the ForStmt SSAValue appears in the index
expression of any memory dimension other than the innermost one), then the
loop is not vectorizable.
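As a rough illustration of step 2, here is a minimal, self-contained sketch on a toy model rather than on MLIR itself. It assumes each Load/Store is summarized, per memory dimension, by the set of loop ids whose induction variables appear in that dimension's index; `Access`, `isContiguousAlong` and `isVectorizableAlong` are hypothetical names, not part of this CL.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Hypothetical summary of one Load/Store: for each memory dimension
// (outermost first, innermost = fastest varying last), the ids of the loops
// whose induction variables appear in that dimension's index expression.
struct Access {
  std::vector<std::set<int>> dimUses;
};

// Contiguous along `loop`: the loop's induction variable appears in no
// dimension other than (possibly) the innermost one.
bool isContiguousAlong(int loop, const Access &access) {
  const std::vector<std::set<int>> &dims = access.dimUses;
  // `d + 1 < size` also handles rank-0 accesses without unsigned underflow.
  for (std::size_t d = 0; d + 1 < dims.size(); ++d)
    if (dims[d].count(loop) != 0)
      return false;
  return true;
}

// A loop passes the test if every access nested under it is contiguous.
bool isVectorizableAlong(int loop, const std::vector<Access> &accesses) {
  for (const Access &a : accesses)
    if (!isContiguousAlong(loop, a))
      return false;
  return true;
}
```

In a `for i { for j { ... A[i][j] ... } }` nest (loop ids 0 and 1), `A[i][j]` is contiguous along `j` but not along `i`.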
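
The simple test above can gradually be extended with more interesting
behaviors, e.g. to account for a layout permutation that enables contiguity.
These will come in due time, but it is worth noting that the test already
supports detection of outer-vectorizable loops: since all Load/Store operations
nested under a loop are checked, a loop that is not innermost is accepted as
long as its induction variable only appears in the innermost memory dimension.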
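To make the outer-vectorizable case concrete, the following sketch (again a toy model, not the MLIR code) applies the same criterion to every loop of a nest. It assumes access functions are affine and summarized as a hypothetical `AccessMatrix` of induction-variable coefficients; `vectorizableLoops` is likewise an illustrative name. For `for i { for j { C[j][i] = A[j][i] } }`, only the outer `i` loop qualifies.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical affine summary of one access: coeffs[d][l] is the coefficient
// of loop l's induction variable (loop 0 = outermost) in the index expression
// of memory dimension d (last dimension = fastest varying).
using AccessMatrix = std::vector<std::vector<int>>;

// Returns the loops of a `depth`-deep nest along which every access is
// contiguous: their coefficient column is zero in every row but the last.
std::vector<int> vectorizableLoops(int depth,
                                   const std::vector<AccessMatrix> &accesses) {
  std::vector<int> result;
  for (int l = 0; l < depth; ++l) {
    bool contiguous = true;
    for (const AccessMatrix &m : accesses) {
      for (std::size_t d = 0; d + 1 < m.size(); ++d) {
        assert(static_cast<int>(m[d].size()) == depth && "one column per loop");
        if (m[d][l] != 0)
          contiguous = false;
      }
    }
    if (contiguous)
      result.push_back(l);
  }
  return result;
}
```

With a row-major access such as `B[i][j]` the same nest yields only the inner `j` loop, and mixing both access kinds yields no vectorizable loop.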

In implementing this test, I also added a recursive MLFunctionMatcher and some
syntactic sugar that can capture patterns such as
`auto gemmLike = Doall(Doall(Red(LoadStore())))` and allows iterating over the
matched IR structures. For now it only uses in-order traversal, but post-order
DFS will be useful in the future once IR rewrites start occurring.

Note that the memory management design follows a different pattern from the
rest of MLIR. After evaluating several designs and seeing how quickly they
increase cognitive overhead, I opted for what I consider the simplest solution:
a class-wide (thread-safe) RAII context.
This way, a pass that needs MLFunctionMatcher can just have its own locally
scoped BumpPtrAllocator and everything is cleaned up when the pass is destroyed.
If passes are expected to have a longer lifetime, the contexts can easily be
scoped inside the runOnMLFunction call instead, reducing storage lifetime.

Lastly, whatever the scope of threading (module, function, pass), this design
is also expected to be future-proof with respect to concurrency (though this is
a detail at the moment).
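A minimal sketch of such a class-wide RAII context follows. It assumes C++17's `std::pmr::monotonic_buffer_resource` as a stand-in for LLVM's `BumpPtrAllocator` and a `thread_local` pointer as the thread-safety mechanism; `MatcherContext` and its members are hypothetical and not taken from this CL.

```cpp
#include <cassert>
#include <cstddef>
#include <memory_resource>

// Hypothetical class-wide RAII context. Constructing one makes its arena the
// current one for this thread; destroying it restores the previous arena and
// frees everything allocated from it at once.
class MatcherContext {
public:
  MatcherContext() : previous_(current_) { current_ = &arena_; }
  ~MatcherContext() { current_ = previous_; }
  MatcherContext(const MatcherContext &) = delete;
  MatcherContext &operator=(const MatcherContext &) = delete;

  // The arena of the innermost live context on this thread, or nullptr.
  static std::pmr::memory_resource *current() { return current_; }

  // Allocates matcher storage; nothing is freed until the context dies.
  static void *allocate(std::size_t bytes,
                        std::size_t align = alignof(std::max_align_t)) {
    assert(current_ && "no MatcherContext in scope on this thread");
    return current_->allocate(bytes, align);
  }

private:
  // Bump-pointer style arena (stand-in for llvm::BumpPtrAllocator).
  std::pmr::monotonic_buffer_resource arena_;
  std::pmr::memory_resource *previous_;
  // thread_local: each thread has its own current context (assumed
  // thread-safety mechanism).
  static thread_local std::pmr::memory_resource *current_;
};

thread_local std::pmr::memory_resource *MatcherContext::current_ = nullptr;
```

A pass would hold a `MatcherContext` as a member, or declare one at the top of `runOnMLFunction` for a shorter storage lifetime; a nested context restores the enclosing one when it goes out of scope.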
PiperOrigin-RevId: 217622889
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index a17cb39..232d162 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -22,9 +22,14 @@
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/MLFunctionMatcher.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
+#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/MathExtras.h"
using namespace mlir;
@@ -113,3 +118,82 @@
// Trip count is not a known constant; return its largest known divisor.
return tripCountExpr.getLargestKnownDivisor();
}
+
+/// Given a MemRef accessed by `indices` and a dimension `dim`, determines
+/// whether indices[dim] is independent of the value `input`.
+// For now we assume no layout map or identity layout map in the MemRef.
+// TODO(ntv): support more than identity layout map.
+static bool isAccessInvariant(MLValue *input, MemRefType *memRefType,
+                              ArrayRef<MLValue *> indices, unsigned dim) {
+  assert(indices.size() == memRefType->getRank());
+  assert(dim < indices.size());
+  auto layoutMap = memRefType->getAffineMaps();
+  assert(layoutMap.size() <= 1);
+  // TODO(ntv): remove dependency on Builder once we support non-identity
+  // layout map.
+  Builder b(memRefType->getContext());
+  assert(layoutMap.empty() ||
+         layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
+
+  SmallVector<OperationStmt *, 4> affineApplyOps;
+  getReachableAffineApplyOps({indices[dim]}, affineApplyOps);
+
+  if (affineApplyOps.empty()) {
+    // Pointer equality test because of MLValue pointer semantics.
+    return indices[dim] != input;
+  }
+
+  assert(affineApplyOps.size() == 1 &&
+         "CompositionAffineMapsPass must have "
+         "been run: there should be at most one AffineApplyOp");
+  auto composeOp = affineApplyOps[0]->getAs<AffineApplyOp>();
+  return !AffineValueMap(*composeOp).isFunctionOf(dim, input);
+}
+
+/// Determines whether a load or a store has a contiguous access along the
+/// value `input`. Contiguous is defined as either invariant or varying only
+/// along the fastest varying memory dimension.
+// TODO(ntv): allow more advanced notions of contiguity (non-fastest varying,
+// check strides, ...).
+template <typename LoadOrStoreOpPointer>
+static bool isContiguousAccess(MLValue *input, LoadOrStoreOpPointer memoryOp) {
+  auto indicesAsOperandIterators = memoryOp->getIndices();
+  auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
+  SmallVector<MLValue *, 4> indices;
+  for (auto *it : indicesAsOperandIterators) {
+    indices.push_back(cast<MLValue>(it));
+  }
+  unsigned numIndices = indices.size();
+  for (unsigned d = 0; d + 1 < numIndices; ++d) {
+    if (!isAccessInvariant(input, memRefType, indices, d)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/// Checks whether all the LoadOp and StoreOp matched have access indexing
+/// functions that are either:
+/// 1. invariant along the `loop` induction variable;
+/// 2. varying along the fastest varying memory dimension only.
+// TODO(ntv): Also need to check the contiguous dimension to discriminate
+// between broadcast (i.e. stride 0), stride 1 and stride > 1 and return the
+// information so we can build a cost model.
+bool mlir::isVectorizableLoop(const ForStmt &loop) {
+  // TODO(ntv): check parallel or reduction loop semantics
+  using matcher::LoadStores;
+  auto *forStmt = &const_cast<ForStmt &>(loop);
+  auto loadAndStores = LoadStores();
+  auto &matches = loadAndStores.match(forStmt);
+  for (auto ls : matches) {
+    auto *op = cast<OperationStmt>(ls.first);
+    auto load = op->getAs<LoadOp>();
+    auto store = op->getAs<StoreOp>();
+    bool contiguous = load ? isContiguousAccess(forStmt, load)
+                           : isContiguousAccess(forStmt, store);
+    if (!contiguous) {
+      return false;
+    }
+  }
+  return true;
+}
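The TODO in `isVectorizableLoop` about discriminating strides could look roughly like the sketch below. It is hypothetical and operates on a toy summary: `coeffs[d]` is the assumed affine coefficient of the loop induction variable in memory dimension `d`; `AccessKind` and `classify` are illustrative names, not part of this CL.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result of the stride check described in the TODO.
enum class AccessKind { NonContiguous, Broadcast, UnitStride, Strided };

// `coeffs[d]` is the (assumed affine) coefficient of the loop induction
// variable in the index of memory dimension d, outermost first.
AccessKind classify(const std::vector<long> &coeffs) {
  // Rank 0: every iteration accesses the same scalar.
  if (coeffs.empty())
    return AccessKind::Broadcast;
  // Same criterion as isContiguousAccess: invariant in all but the last dim.
  for (std::size_t d = 0; d + 1 < coeffs.size(); ++d)
    if (coeffs[d] != 0)
      return AccessKind::NonContiguous;
  long stride = coeffs.back();
  if (stride == 0)
    return AccessKind::Broadcast;  // stride 0
  if (stride == 1)
    return AccessKind::UnitStride; // stride 1
  return AccessKind::Strided;      // any other stride (including negative)
}
```

Such a classification would feed a cost model rather than the yes/no answer that `isVectorizableLoop` currently returns.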