[MLIR] Basic infrastructure for vectorization test
This CL implements a very simple loop vectorization **test** and the basic
infrastructure to support it.
The test simply consists of:
1. matching the loops in the MLFunction and all the Load/Store operations
nested under each loop;
2. testing whether all the Load/Store accesses are contiguous along the
innermost memory dimension with respect to that particular loop. If any
reference is non-contiguous (i.e. the ForStmt's SSAValue appears in the
indexing expression), then the loop is not vectorizable.
The simple test above can gradually be extended with more interesting
behaviors, e.g. to account for the fact that a layout permutation may exist
that enables contiguity. All of this will come in due time, but it is worth
noting that the test already supports detection of outer-vectorizable loops;
a minimal usage sketch follows.
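For illustration, a driver for the test could look roughly like the sketch
below (assumptions: `f` is an MLFunction*, the helper name and surrounding
boilerplate are mine; it simply mirrors the Vectorize pass added in this CL):

  #include "mlir/Analysis/LoopAnalysis.h"      // isVectorizableLoop
  #include "mlir/Analysis/MLFunctionMatcher.h" // matcher sugar
  using namespace mlir;

  static void testVectorizableLoops(MLFunction *f) {
    using namespace matcher;
    MLFunctionMatcherContext context;  // RAII storage for matcher results.
    auto loops = Doall();              // matches every ForStmt for now.
    for (auto m : loops.match(f)) {
      auto *forStmt = cast<ForStmt>(m.first);
      if (isVectorizableLoop(*forStmt)) {
        // Every Load/Store nested under forStmt is either invariant along
        // its induction variable or contiguous along the fastest varying
        // memory dimension, whether forStmt is innermost or not.
      }
    }
  }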
In implementing this test, I also added a recursive MLFunctionMatcher and
some sugar that can capture patterns such as
`auto gemmLike = Doall(Doall(Red(LoadStores())))` and allows iterating over
the matched IR structures. For now it uses in-order traversal, but post-order
DFS will be useful in the future once IR rewrites start occurring.
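Concretely, composing and consuming the sugar looks like the following
sketch (assumptions: `f` is an MLFunction* and `walkGemmLike` is a
hypothetical helper of mine):

  #include "mlir/Analysis/MLFunctionMatcher.h"
  using namespace mlir;

  static void walkGemmLike(MLFunction *f) {
    using namespace matcher;
    MLFunctionMatcherContext context;
    auto gemmLike = Doall(Doall(Red(LoadStores())));
    for (auto m : gemmLike.match(f)) {
      // m.first is the outermost matched ForStmt; m.second recursively
      // holds the matches produced by the nested Doall/Red/LoadStores.
      auto *outerLoop = cast<ForStmt>(m.first);
      (void)outerLoop;
    }
  }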
One may note that the memory management design decision follows a different
pattern from MLIR. After evaluating different designs and how they quickly
increase cognitive overhead, I decided to opt for the simplest solution in my
view: a class-wide (thread-safe) RAII context.
This way, a pass that needs MLFunctionMatcher can just have its own locally
scoped BumpPtrAllocator and everything is cleaned up when the pass is destroyed.
If passes are expected to have a longer lifetime, then the contexts can easily
be scoped inside the runOnMLFunction call and storage lifetime reduced.
Lastly, whatever the scope of threading (module, function or pass), this is
also expected to be future-proof with respect to concurrency (but this is a
detail for now).
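As an illustrative sketch of these two scoping options (names mirror the
Vectorize pass added below; this is not prescriptive):

  #include "mlir/Analysis/MLFunctionMatcher.h"
  #include "mlir/Transforms/Pass.h"
  using namespace mlir;

  // Option A (what the Vectorize pass below does): the context, and thus
  // the BumpPtrAllocator behind it, lives as long as the pass object.
  struct Vectorize : public MLFunctionPass {
    PassResult runOnMLFunction(MLFunction *f) override;
    MLFunctionMatcherContext MLContext;  // freed when the pass is destroyed
  };

  // Option B: for longer-lived passes, scope a context per run instead, so
  // matcher storage is reclaimed as soon as runOnMLFunction returns.
  PassResult Vectorize::runOnMLFunction(MLFunction *f) {
    MLFunctionMatcherContext localContext;
    // ... build matchers and query matches here ...
    return PassResult::Success;
  }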
PiperOrigin-RevId: 217622889
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index f8bd126..173e8b5 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -21,6 +21,7 @@
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
@@ -242,11 +243,11 @@
// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in
// 'valuesToSearch'. Returns false otherwise.
static bool findIndex(MLValue *valueToMatch, ArrayRef<MLValue *> valuesToSearch,
- unsigned &indexOfMatch) {
+ unsigned *indexOfMatch) {
unsigned size = valuesToSearch.size();
for (unsigned i = 0; i < size; ++i) {
if (valueToMatch == valuesToSearch[i]) {
- indexOfMatch = i;
+ *indexOfMatch = i;
return true;
}
}
@@ -354,7 +355,7 @@
auto *inputOperand =
cast<MLValue>(const_cast<SSAValue *>(inputOp.getOperand(i)));
unsigned outputIndex;
- if (findIndex(inputOperand, outputOperands, outputIndex)) {
+ if (findIndex(inputOperand, outputOperands, &outputIndex)) {
mapUpdate.inputDimMap[i] = outputIndex;
} else {
mapUpdate.inputDimMap[i] = outputOperandPosition++;
@@ -387,7 +388,7 @@
cast<MLValue>(const_cast<SSAValue *>(inputOp.getOperand(i)));
// Find output operand index of 'inputOperand' dup.
unsigned outputIndex;
- if (findIndex(inputOperand, outputOperands, outputIndex)) {
+ if (findIndex(inputOperand, outputOperands, &outputIndex)) {
unsigned outputSymbolPosition = outputIndex - outputNumDims;
mapUpdate.inputSymbolMap[inputSymbolPosition] = outputSymbolPosition;
} else {
@@ -412,6 +413,17 @@
return map.isMultipleOf(idx, factor);
}
+/// This method uses the invariant that operands are always positionally aligned
+/// with the AffineDimExpr in the underlying AffineMap.
+bool AffineValueMap::isFunctionOf(unsigned idx, MLValue *value) const {
+ unsigned index;
+ findIndex(value, operands, &index);
+ auto expr = const_cast<AffineValueMap *>(this)->getAffineMap().getResult(idx);
+ // TODO(ntv): this is better implemented on a flattened representation.
+ // At least for now it is conservative.
+ return expr.isFunctionOfDim(index);
+}
+
unsigned AffineValueMap::getNumOperands() const { return operands.size(); }
SSAValue *AffineValueMap::getOperand(unsigned i) const {
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index a17cb39..232d162 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -22,9 +22,14 @@
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/MLFunctionMatcher.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
+#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/MathExtras.h"
using namespace mlir;
@@ -113,3 +118,82 @@
// Trip count is not a known constant; return its largest known divisor.
return tripCountExpr.getLargestKnownDivisor();
}
+
+/// Given a MemRef accessed by `indices` and a dimension `dim`, determines
+/// whether indices[dim] is independent of the value `input`.
+// For now we assume no layout map or identity layout map in the MemRef.
+// TODO(ntv): support more than identity layout map.
+static bool isAccessInvariant(MLValue *input, MemRefType *memRefType,
+ ArrayRef<MLValue *> indices, unsigned dim) {
+ assert(indices.size() == memRefType->getRank());
+ assert(dim < indices.size());
+ auto layoutMap = memRefType->getAffineMaps();
+ assert(layoutMap.size() <= 1);
+ // TODO(ntv): remove dependency on Builder once we support non-identity
+ // layout map.
+ Builder b(memRefType->getContext());
+ assert(layoutMap.empty() ||
+ layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
+
+ SmallVector<OperationStmt *, 4> affineApplyOps;
+ getReachableAffineApplyOps({indices[dim]}, affineApplyOps);
+
+ if (affineApplyOps.empty()) {
+ // Pointer equality test because of MLValue pointer semantics.
+ return indices[dim] != input;
+ }
+
+ assert(affineApplyOps.size() == 1 &&
+ "CompositionAffineMapsPass must have "
+ "been run: there should be at most one AffineApplyOp");
+ auto composeOp = affineApplyOps[0]->getAs<AffineApplyOp>();
+ return !AffineValueMap(*composeOp).isFunctionOf(dim, input);
+}
+
+/// Determines whether a load or a store has a contiguous access along the
+/// value `input`. Contiguous is defined as either invariant or varying only
+/// along the fastest varying memory dimension.
+// TODO(ntv): allow more advanced notions of contiguity (non-fastest varying,
+// check strides, ...).
+template <typename LoadOrStoreOpPointer>
+static bool isContiguousAccess(MLValue *input, LoadOrStoreOpPointer memoryOp) {
+ auto indicesAsOperandIterators = memoryOp->getIndices();
+ auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
+ SmallVector<MLValue *, 4> indices;
+ for (auto *it : indicesAsOperandIterators) {
+ indices.push_back(cast<MLValue>(it));
+ }
+ unsigned numIndices = indices.size();
+ for (unsigned d = 0; d < numIndices - 1; ++d) {
+ if (!isAccessInvariant(input, memRefType, indices, d)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+/// Checks whether all the LoadOp and StoreOp matched have access indexing
+/// functions that are either:
+/// 1. invariant along the `loop` induction variable;
+/// 2. varying along the fastest varying memory dimension only.
+// TODO(ntv): Also need to check the contiguous dimension to discriminate
+// between broadcast (i.e. stride 0), stride 1 and stride > 1 and return the
+// information so we can build a cost model.
+bool mlir::isVectorizableLoop(const ForStmt &loop) {
+ // TODO(ntv): check parallel or reduction loop semantics
+ using matcher::LoadStores;
+ auto *forStmt = &const_cast<ForStmt &>(loop);
+ auto loadAndStores = LoadStores();
+ auto &matches = loadAndStores.match(forStmt);
+ for (auto ls : matches) {
+ auto *op = cast<OperationStmt>(ls.first);
+ auto load = op->getAs<LoadOp>();
+ auto store = op->getAs<StoreOp>();
+ bool contiguous = load ? isContiguousAccess(forStmt, load)
+ : isContiguousAccess(forStmt, store);
+ if (!contiguous) {
+ return false;
+ }
+ }
+ return true;
+}
diff --git a/lib/Analysis/MLFunctionMatcher.cpp b/lib/Analysis/MLFunctionMatcher.cpp
new file mode 100644
index 0000000..8739edb
--- /dev/null
+++ b/lib/Analysis/MLFunctionMatcher.cpp
@@ -0,0 +1,260 @@
+//===- MLFunctionMatcher.cpp - MLFunctionMatcher Impl ----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Analysis/MLFunctionMatcher.h"
+#include "mlir/StandardOps/StandardOps.h"
+
+#include "llvm/Support/Allocator.h"
+
+namespace mlir {
+
+/// Underlying storage for MLFunctionMatches.
+struct MLFunctionMatchesStorage {
+ MLFunctionMatchesStorage(MLFunctionMatches::EntryType e) : matches({e}) {}
+
+ SmallVector<MLFunctionMatches::EntryType, 8> matches;
+};
+
+/// Underlying storage for MLFunctionMatcher.
+struct MLFunctionMatcherStorage {
+ MLFunctionMatcherStorage(Statement::Kind k,
+ MutableArrayRef<MLFunctionMatcher> c,
+ FilterFunctionType filter)
+ : kind(k), childrenMLFunctionMatchers(c.begin(), c.end()),
+ filter(filter) {}
+
+ Statement::Kind kind;
+ SmallVector<MLFunctionMatcher, 4> childrenMLFunctionMatchers;
+ FilterFunctionType filter;
+};
+
+} // end namespace mlir
+
+using namespace mlir;
+
+llvm::BumpPtrAllocator *&MLFunctionMatches::allocator() {
+ static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+ return allocator;
+}
+
+void MLFunctionMatches::append(Statement *stmt, MLFunctionMatches children) {
+ if (!storage) {
+ storage = allocator()->Allocate<MLFunctionMatchesStorage>();
+ new (storage) MLFunctionMatchesStorage(std::make_pair(stmt, children));
+ } else {
+ storage->matches.push_back(std::make_pair(stmt, children));
+ }
+}
+MLFunctionMatches::iterator MLFunctionMatches::begin() {
+ return storage->matches.begin();
+}
+MLFunctionMatches::iterator MLFunctionMatches::end() {
+ return storage->matches.end();
+}
+
+/// Return the combination of multiple MLFunctionMatches as a new object.
+static MLFunctionMatches combine(ArrayRef<MLFunctionMatches> matches) {
+ MLFunctionMatches res;
+ for (auto s : matches) {
+ for (auto ss : s) {
+ res.append(ss.first, ss.second);
+ }
+ }
+ return res;
+}
+
+/// Calls walk on `function`.
+MLFunctionMatches &MLFunctionMatcher::match(MLFunction *function) {
+ assert(!matches && "MLFunctionMatcher already matched!");
+ this->walk(function);
+ return matches;
+}
+
+/// Calls walk on `statement`.
+MLFunctionMatches &MLFunctionMatcher::match(Statement *statement) {
+ assert(!matches && "MLFunctionMatcher already matched!");
+ this->walk(statement);
+ return matches;
+}
+
+/// matchOrSkipOne is needed so that we can implement match without switching on
+/// the type of the Statement.
+/// The idea is that a MLFunctionMatcher first checks if it matches locally and
+/// then recursively applies its children matchers to its elem->children.
+/// Since we want to rely on the StmtWalker impl rather than duplicate its
+/// logic, we allow an off-by-one traversal to account for the fact that
+/// we write:
+///
+/// void match(Statement *elem) {
+/// for (auto &c : getChildrenMLFunctionMatchers()) {
+/// MLFunctionMatcher childMLFunctionMatcher(...);
+/// childMLFunctionMatcher.walk(elem); <~~~ Needs off-by-one traversal.
+///
+void MLFunctionMatcher::matchOrSkipOne(Statement *elem) {
+ if (skipOne) {
+ skipOne = false;
+ return;
+ }
+ matchOne(elem);
+}
+
+/// Matches a single statement in the following way:
+/// 1. checks the kind of statement against the matcher, if different then
+/// there is no match;
+/// 2. calls the customizable filter function to refine the single statement
+/// match with extra semantic constraints;
+/// 3. if all is good, recursively matches the children patterns;
+/// 4. if all children match then the single statement matches too and is
+/// appended to the list of matches;
+/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
+/// want to traverse in post-order DFS to avoid invalidating iterators.
+void MLFunctionMatcher::matchOne(Statement *elem) {
+ // Structural filter
+ if (elem->getKind() != getKind()) {
+ return;
+ }
+ // Local custom filter function
+ if (!getFilterFunction()(elem)) {
+ return;
+ }
+ SmallVector<MLFunctionMatches, 8> childrenMLFunctionMatches;
+ for (auto &c : getChildrenMLFunctionMatchers()) {
+    /// We create a new childMLFunctionMatcher here because a matcher holds
+    /// its results. So we concretely need multiple copies of a given matcher,
+    /// one for each matching result.
+ MLFunctionMatcher childMLFunctionMatcher = forkMLFunctionMatcher(c);
+ childMLFunctionMatcher.walk(elem);
+ if (!childMLFunctionMatcher.matches) {
+ return;
+ }
+ childrenMLFunctionMatches.push_back(childMLFunctionMatcher.matches);
+ }
+ matches.append(elem, combine(childrenMLFunctionMatches));
+}
+
+llvm::BumpPtrAllocator *&MLFunctionMatcher::allocator() {
+ static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+ return allocator;
+}
+
+MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
+ FilterFunctionType filter)
+ : storage(allocator()->Allocate<MLFunctionMatcherStorage>()),
+ skipOne(false) {
+ // Initialize with placement new.
+ new (storage) MLFunctionMatcherStorage(k, {child}, filter);
+}
+
+MLFunctionMatcher::MLFunctionMatcher(
+ Statement::Kind k, MutableArrayRef<MLFunctionMatcher> children,
+ FilterFunctionType filter)
+ : storage(allocator()->Allocate<MLFunctionMatcherStorage>()),
+ skipOne(false) {
+ // Initialize with placement new.
+ new (storage) MLFunctionMatcherStorage(k, children, filter);
+}
+
+MLFunctionMatcher
+MLFunctionMatcher::forkMLFunctionMatcher(MLFunctionMatcher tmpl) {
+ MLFunctionMatcher res(tmpl.getKind(), tmpl.getChildrenMLFunctionMatchers(),
+ tmpl.getFilterFunction());
+ res.skipOne = true;
+ return res;
+}
+
+Statement::Kind MLFunctionMatcher::getKind() { return storage->kind; }
+
+MutableArrayRef<MLFunctionMatcher>
+MLFunctionMatcher::getChildrenMLFunctionMatchers() {
+ return storage->childrenMLFunctionMatchers;
+}
+
+FilterFunctionType MLFunctionMatcher::getFilterFunction() {
+ return storage->filter;
+}
+
+namespace mlir {
+namespace matcher {
+
+MLFunctionMatcher Op(FilterFunctionType filter) {
+ return MLFunctionMatcher(Statement::Kind::Operation, {}, filter);
+}
+
+MLFunctionMatcher If(MLFunctionMatcher child) {
+ return MLFunctionMatcher(Statement::Kind::If, child, defaultFilterFunction);
+}
+MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child) {
+ return MLFunctionMatcher(Statement::Kind::If, child, filter);
+}
+MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children) {
+ return MLFunctionMatcher(Statement::Kind::If, children,
+ defaultFilterFunction);
+}
+MLFunctionMatcher If(FilterFunctionType filter,
+ MutableArrayRef<MLFunctionMatcher> children) {
+ return MLFunctionMatcher(Statement::Kind::If, children, filter);
+}
+
+MLFunctionMatcher For(MLFunctionMatcher child) {
+ return MLFunctionMatcher(Statement::Kind::For, child, defaultFilterFunction);
+}
+MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child) {
+ return MLFunctionMatcher(Statement::Kind::For, child, filter);
+}
+MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children) {
+ return MLFunctionMatcher(Statement::Kind::For, children,
+ defaultFilterFunction);
+}
+MLFunctionMatcher For(FilterFunctionType filter,
+ MutableArrayRef<MLFunctionMatcher> children) {
+ return MLFunctionMatcher(Statement::Kind::For, children, filter);
+}
+
+// TODO(ntv): parallel annotation on loops.
+FilterFunctionType isParallelLoop = [](Statement *stmt) {
+ auto *loop = cast<ForStmt>(stmt);
+ return (void *)loop || true; // loop->isParallel();
+};
+MLFunctionMatcher Doall(MLFunctionMatcher child) {
+ return MLFunctionMatcher(Statement::Kind::For, child, isParallelLoop);
+}
+MLFunctionMatcher Doall(MutableArrayRef<MLFunctionMatcher> children) {
+ return MLFunctionMatcher(Statement::Kind::For, children, isParallelLoop);
+}
+
+// TODO(ntv): reduction annotation on loops.
+FilterFunctionType isReductionLoop = [](Statement *stmt) {
+ auto *loop = cast<ForStmt>(stmt);
+ return (void *)loop || true; // loop->isReduction();
+};
+MLFunctionMatcher Red(MLFunctionMatcher child) {
+ return MLFunctionMatcher(Statement::Kind::For, child, isReductionLoop);
+}
+MLFunctionMatcher Red(MutableArrayRef<MLFunctionMatcher> children) {
+ return MLFunctionMatcher(Statement::Kind::For, children, isReductionLoop);
+}
+
+FilterFunctionType isLoadOrStore = [](Statement *stmt) {
+ auto *opStmt = dyn_cast<OperationStmt>(stmt);
+ return opStmt && (opStmt->is<LoadOp>() || opStmt->is<StoreOp>());
+};
+MLFunctionMatcher LoadStores() {
+ return MLFunctionMatcher(Statement::Kind::Operation, {}, isLoadOrStore);
+}
+
+} // end namespace matcher
+} // end namespace mlir
diff --git a/lib/IR/AffineExpr.cpp b/lib/IR/AffineExpr.cpp
index 2966819..ef19d20 100644
--- a/lib/IR/AffineExpr.cpp
+++ b/lib/IR/AffineExpr.cpp
@@ -146,6 +146,17 @@
}
}
+bool AffineExpr::isFunctionOfDim(unsigned position) const {
+ if (getKind() == AffineExprKind::DimId) {
+ return *this == mlir::getAffineDimExpr(position, getContext());
+ }
+ if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
+ return expr.getLHS().isFunctionOfDim(position) ||
+ expr.getRHS().isFunctionOfDim(position);
+ }
+ return false;
+}
+
AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
: AffineExpr(ptr) {}
AffineExpr AffineBinaryOpExpr::getLHS() const {
diff --git a/lib/Transforms/Vectorize.cpp b/lib/Transforms/Vectorize.cpp
new file mode 100644
index 0000000..37c4f3a
--- /dev/null
+++ b/lib/Transforms/Vectorize.cpp
@@ -0,0 +1,76 @@
+//===- Vectorize.cpp - Vectorize Pass Impl ----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements vectorization of loops, operations and data types to
+// a target-independent, n-D virtual vector abstraction.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/MLFunctionMatcher.h"
+#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/Transforms/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+using namespace mlir;
+
+static cl::list<unsigned> clVirtualVectorSize(
+ "virtual-vector-size",
+ cl::desc("Specify n-D virtual vector size for vectorization"),
+ cl::ZeroOrMore);
+
+namespace {
+
+struct Vectorize : public MLFunctionPass {
+ PassResult runOnMLFunction(MLFunction *f) override;
+
+ // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
+ MLFunctionMatcherContext MLContext;
+};
+
+} // end anonymous namespace
+
+PassResult Vectorize::runOnMLFunction(MLFunction *f) {
+ using matcher::Doall;
+ /// TODO(ntv): support at least 4 cases for each load/store:
+ /// 1. invariant along the loop index -> 1-D vectorizable with broadcast
+ /// 2. contiguous along the fastest varying dimension wrt the loop index
+ /// -> a. 1-D vectorizable via stripmine/sink if loop is not innermost
+ /// -> b. 1-D vectorizable if loop is innermost
+ /// 3. contiguous along non-fastest varying dimension wrt the loop index
+ /// -> needs data layout + copy to vectorize 1-D
+ /// 4. not contiguous => not vectorizable
+ auto pointwiseLike = Doall();
+ auto &matches = pointwiseLike.match(f);
+ for (auto loop : matches) {
+ auto *doall = cast<ForStmt>(loop.first);
+ if (!isVectorizableLoop(*doall)) {
+ outs() << "\nNon-vectorizable loop: ";
+ doall->print(outs());
+ continue;
+ }
+ outs() << "\nVectorizable loop: ";
+ doall->print(outs());
+ }
+ return PassResult::Success;
+}
+
+MLFunctionPass *mlir::createVectorizePass() { return new Vectorize(); }