[MLIR] Add VectorTransferOps
This CL implements and uses VectorTransferOps in lieu of the former custom
call op. Tests are updated accordingly.
VectorTransferOps come in 2 flavors: VectorTransferReadOp and
VectorTransferWriteOp.
VectorTransferOps can be thought of as backend-independent
pseudo ops/library calls that need to be legalized to MLIR (whiteboxed) before
they can be lowered to backend-dependent IR.
Note that the current implementation does not yet support a real permutation
map; proper support will come in a follow-up CL.
VectorTransferReadOp
====================
VectorTransferReadOp performs a blocking read from a scalar memref
location into a super-vector of the same elemental type. The operation is
called 'read', in contrast to 'load', because the super-vector granularity
is generally not representable with a single hardware register. As a
consequence, memory transfers will generally be required when lowering a
VectorTransferReadOp. A VectorTransferReadOp is thus a mid-level abstraction
that supports super-vectorization with non-effecting padding for full-tile-only
code.
A vector transfer read has semantics similar to a vector load, with additional
support for:
1. an optional value of the elemental type of the MemRef. This value
supports non-effecting padding and is inserted wherever the vector read
exceeds the MemRef bounds. If the value is not specified, the access must
be statically guaranteed to be within bounds;
2. an attribute of type AffineMap to specify a slice of the original
MemRef access and its transposition into the super-vector shape. The
permutation_map is an unbounded AffineMap that must represent a
permutation from the MemRef dim space projected onto the vector dim
space.
Example:
```mlir
%A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
...
%val = `ssa-value` : f32
// let %i, %j, %k, %l be ssa-values of type index
%v0 = vector_transfer_read %A, %i, %j, %k, %l
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(memref<?x?x?x?xf32>, index, index, index, index) ->
vector<16x32x64xf32>
%v1 = vector_transfer_read %A, %i, %j, %k, %l, %val
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(memref<?x?x?x?xf32>, index, index, index, index, f32) ->
vector<16x32x64xf32>
```
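For reference, the same kind of read can be created programmatically through
the new build method added in StandardOps.cpp. The sketch below is
illustrative only: `b`, `loc`, `memRef`, and `indices` are assumed to be an
in-scope MLFuncBuilder, Location, memref SSAValue, and its rank-many index
values, and makePermutationMap yields the default minor identity map rather
than the transposing map shown above.
```c++
// Sketch only: build a vector_transfer_read via the new builder
// (all names here are assumed to be in scope, see lead-in above).
auto memRefType = memRef->getType().cast<MemRefType>();
auto vectorType =
    VectorType::get({16, 32, 64}, Type::getF32(b.getContext()));
auto read = b.create<VectorTransferReadOp>(
    loc, vectorType, memRef, indices,
    makePermutationMap(memRefType, vectorType));
```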
VectorTransferWriteOp
=====================
VectorTransferWriteOp performs a blocking write from a super-vector to
a scalar memref of the same elemental type. The operation is
called 'write', in contrast to 'store', because the super-vector
granularity is generally not representable with a single hardware register. As
a consequence, memory transfers will generally be required when lowering a
VectorTransferWriteOp. A VectorTransferWriteOp is thus a mid-level
abstraction that supports super-vectorization with non-effecting padding
for full-tile-only code.
A vector transfer write has semantics similar to a vector store, with
additional support for handling out-of-bounds situations.
Example:
```mlir
%A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
%val = `ssa-value` : vector<16x32x64xf32>
// let %i, %j, %k, %l be ssa-values of type index
vector_transfer_write %val, %A, %i, %j, %k, %l
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index)
```
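The write side mirrors the builder use in Vectorize.cpp; again a sketch under
the same assumed names, with `vectorValue` standing in for the super-vector
value to store:
```c++
// Sketch only: build a vector_transfer_write (zero results) via the
// new builder; names are assumed as in the read sketch above.
auto memRefType = memRef->getType().cast<MemRefType>();
auto vectorType = vectorValue->getType().cast<VectorType>();
b.create<VectorTransferWriteOp>(
    loc, vectorValue, memRef, indices,
    makePermutationMap(memRefType, vectorType));
```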
PiperOrigin-RevId: 223873234
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 8406a37..de98849 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -194,7 +194,8 @@
// TODO(ntv): make the following into MLIR instructions, then use isa<>.
static bool isVectorTransferReadOrWrite(const Statement &stmt) {
const auto *opStmt = cast<OperationStmt>(&stmt);
- return isaVectorTransferRead(*opStmt) || isaVectorTransferWrite(*opStmt);
+ return opStmt->isa<VectorTransferReadOp>() ||
+ opStmt->isa<VectorTransferWriteOp>();
}
using VectorizableStmtFun =
diff --git a/lib/Analysis/VectorAnalysis.cpp b/lib/Analysis/VectorAnalysis.cpp
index 75f6229..9c2160c 100644
--- a/lib/Analysis/VectorAnalysis.cpp
+++ b/lib/Analysis/VectorAnalysis.cpp
@@ -18,6 +18,7 @@
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
+#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/STLExtras.h"
@@ -28,14 +29,6 @@
using namespace mlir;
-bool mlir::isaVectorTransferRead(const OperationStmt &stmt) {
- return stmt.getName().getStringRef().str() == kVectorTransferReadOpName;
-}
-
-bool mlir::isaVectorTransferWrite(const OperationStmt &stmt) {
- return stmt.getName().getStringRef().str() == kVectorTransferWriteOpName;
-}
-
Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(ArrayRef<int> superShape,
ArrayRef<int> subShape) {
if (superShape.size() < subShape.size()) {
@@ -83,6 +76,20 @@
return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
}
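+/// Returns the AffineMap that maps the trailing `vectorType.getRank()`
+/// dimensions of `memrefType` onto the vector dimensions, in order: the
+/// default minor identity permutation, without transposition.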
+AffineMap mlir::makePermutationMap(MemRefType memrefType,
+ VectorType vectorType) {
+ unsigned memRefRank = memrefType.getRank();
+ unsigned vectorRank = vectorType.getRank();
+ assert(memRefRank >= vectorRank && "Broadcast not supported");
+ unsigned offset = memRefRank - vectorRank;
+ SmallVector<AffineExpr, 4> perm;
+ perm.reserve(memRefRank);
+ for (unsigned i = 0; i < vectorRank; ++i) {
+ perm.push_back(getAffineDimExpr(offset + i, memrefType.getContext()));
+ }
+ return AffineMap::get(memRefRank, 0, perm, {});
+}
+
bool mlir::matcher::operatesOnStrictSuperVectors(const OperationStmt &opStmt,
VectorType subVectorType) {
// First, extract the vector type and distinguish between:
@@ -96,15 +103,11 @@
/// do not have to special case. Maybe a trait, or just a method, unclear atm.
bool mustDivide = false;
VectorType superVectorType;
- if (isaVectorTransferRead(opStmt)) {
- superVectorType = opStmt.getResult(0)->getType().cast<VectorType>();
+ if (auto read = opStmt.dyn_cast<VectorTransferReadOp>()) {
+ superVectorType = read->getResultType();
mustDivide = true;
- } else if (isaVectorTransferWrite(opStmt)) {
- // TODO(ntv): if vector_transfer_write had store-like semantics we could
- // have written something similar to:
- // auto store = storeOp->cast<StoreOp>();
- // auto *value = store->getValueToStore();
- superVectorType = opStmt.getOperand(0)->getType().cast<VectorType>();
+ } else if (auto write = opStmt.dyn_cast<VectorTransferWriteOp>()) {
+ superVectorType = write->getVectorType();
mustDivide = true;
} else if (opStmt.getNumResults() == 0) {
assert(opStmt.isa<ReturnOp>() &&
diff --git a/lib/StandardOps/StandardOps.cpp b/lib/StandardOps/StandardOps.cpp
index 4de951a..4d71dde 100644
--- a/lib/StandardOps/StandardOps.cpp
+++ b/lib/StandardOps/StandardOps.cpp
@@ -40,7 +40,8 @@
addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, CmpIOp,
DeallocOp, DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp,
LoadOp, MemRefCastOp, MulFOp, MulIOp, SelectOp, StoreOp, SubFOp,
- SubIOp, TensorCastOp>();
+ SubIOp, TensorCastOp, VectorTransferReadOp,
+ VectorTransferWriteOp>();
}
//===----------------------------------------------------------------------===//
@@ -1321,3 +1322,427 @@
return false;
}
+
+//===----------------------------------------------------------------------===//
+// VectorTransferReadOp
+//===----------------------------------------------------------------------===//
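+/// Emits an error via `emitOpError` and returns true unless every result of
+/// `permutationMap` is a distinct AffineDimExpr; returns false on success.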
+template <typename EmitFun>
+static bool verifyPermutationMap(AffineMap permutationMap,
+ EmitFun emitOpError) {
+ SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
+ for (auto expr : permutationMap.getResults()) {
+ auto dim = expr.dyn_cast<AffineDimExpr>();
+ if (!dim) {
+ return emitOpError(
+ "requires a permutation_map that is an actual permutation");
+ }
+ if (seen[dim.getPosition()]) {
+ return emitOpError(
+ "requires a permutation_map that is a full column-rank "
+ "permutation (i.e. a permutation composed with an "
+ "orthogonal projection)");
+ }
+ seen[dim.getPosition()] = true;
+ }
+ return false;
+}
+
+void VectorTransferReadOp::build(Builder *builder, OperationState *result,
+ VectorType vectorType, SSAValue *srcMemRef,
+ ArrayRef<SSAValue *> srcIndices,
+ AffineMap permutationMap,
+ Optional<SSAValue *> paddingValue) {
+ result->addOperands(srcMemRef);
+ result->addOperands(srcIndices);
+ if (paddingValue) {
+ result->addOperands({*paddingValue});
+ }
+ result->addAttribute(getPermutationMapAttrName(),
+ builder->getAffineMapAttr(permutationMap));
+ result->addTypes(vectorType);
+}
+
+llvm::iterator_range<Operation::operand_iterator>
+VectorTransferReadOp::getIndices() {
+ auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+ auto end = begin + getMemRefType().getRank();
+ return {begin, end};
+}
+
+llvm::iterator_range<Operation::const_operand_iterator>
+VectorTransferReadOp::getIndices() const {
+ auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+ auto end = begin + getMemRefType().getRank();
+ return {begin, end};
+}
+
+Optional<SSAValue *> VectorTransferReadOp::getPaddingValue() {
+ auto memRefRank = getMemRefType().getRank();
+ if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
+ return None;
+ }
+ return Optional<SSAValue *>(
+ getOperand(Offsets::FirstIndexOffset + memRefRank));
+}
+
+Optional<const SSAValue *> VectorTransferReadOp::getPaddingValue() const {
+ auto memRefRank = getMemRefType().getRank();
+ if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
+ return None;
+ }
+ return Optional<const SSAValue *>(
+ getOperand(Offsets::FirstIndexOffset + memRefRank));
+}
+
+AffineMap VectorTransferReadOp::getPermutationMap() const {
+ return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
+}
+
+void VectorTransferReadOp::print(OpAsmPrinter *p) const {
+ *p << getOperationName() << " ";
+ p->printOperand(getMemRef());
+ *p << ", ";
+ p->printOperands(getIndices());
+ auto optionalPaddingValue = getPaddingValue();
+ if (optionalPaddingValue) {
+ *p << ", ";
+ p->printOperand(*optionalPaddingValue);
+ }
+ p->printOptionalAttrDict(getAttrs());
+ // Construct the FunctionType and print it.
+ llvm::SmallVector<Type, 8> inputs{getMemRefType()};
+ // Must have at least one actual index; see verify().
+ const SSAValue *firstIndex = *(getIndices().begin());
+ Type indexType = firstIndex->getType();
+ inputs.append(getMemRefType().getRank(), indexType);
+ if (optionalPaddingValue) {
+ inputs.push_back((*optionalPaddingValue)->getType());
+ }
+ *p << " : "
+ << FunctionType::get(inputs, {getResultType()}, indexType.getContext());
+}
+
+bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) {
+ SmallVector<OpAsmParser::OperandType, 8> parsedOperands;
+ Type type;
+
+ // Parsing with support for optional paddingValue.
+ auto fail = parser->parseOperandList(parsedOperands) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type);
+ if (fail) {
+ return true;
+ }
+
+ // Resolution.
+ auto funType = type.dyn_cast<FunctionType>();
+ if (!funType) {
+ parser->emitError(parser->getNameLoc(), "Function type expected");
+ return true;
+ }
+ if (funType.getNumInputs() < 1) {
+ parser->emitError(parser->getNameLoc(),
+ "Function type expects at least one input");
+ return true;
+ }
+ MemRefType memrefType =
+ funType.getInput(Offsets::MemRefOffset).dyn_cast<MemRefType>();
+ if (!memrefType) {
+ parser->emitError(parser->getNameLoc(),
+ "MemRef type expected for first input");
+ return true;
+ }
+ if (funType.getNumResults() < 1) {
+ parser->emitError(parser->getNameLoc(),
+ "Function type expects exactly one vector result");
+ return true;
+ }
+ VectorType vectorType = funType.getResult(0).dyn_cast<VectorType>();
+ if (!vectorType) {
+ parser->emitError(parser->getNameLoc(),
+ "Vector type expected for first result");
+ return true;
+ }
+ if (parsedOperands.size() != funType.getNumInputs()) {
+ parser->emitError(parser->getNameLoc(), "requires " +
+ Twine(funType.getNumInputs()) +
+ " operands");
+ return true;
+ }
+
+ // Extract optional paddingValue.
+ OpAsmParser::OperandType memrefInfo = parsedOperands[0];
+ // At this point, indexInfo may contain the optional paddingValue; pop it out.
+ SmallVector<OpAsmParser::OperandType, 8> indexInfo{
+ parsedOperands.begin() + Offsets::FirstIndexOffset, parsedOperands.end()};
+ Type paddingType;
+ OpAsmParser::OperandType paddingValue;
+ bool hasPaddingValue = indexInfo.size() > memrefType.getRank();
+ unsigned expectedNumOperands = Offsets::FirstIndexOffset +
+ memrefType.getRank() +
+ (hasPaddingValue ? 1 : 0);
+ if (hasPaddingValue) {
+ paddingType = funType.getInputs().back();
+ paddingValue = indexInfo.pop_back_val();
+ }
+ if (funType.getNumInputs() != expectedNumOperands) {
+ parser->emitError(
+ parser->getNameLoc(),
+ "requires actual number of operands to match function type");
+ return true;
+ }
+
+ auto indexType = parser->getBuilder().getIndexType();
+ return parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
+ parser->resolveOperands(indexInfo, indexType, result->operands) ||
+ (hasPaddingValue && parser->resolveOperand(paddingValue, paddingType,
+ result->operands)) ||
+ parser->addTypeToList(vectorType, result->types);
+}
+
+bool VectorTransferReadOp::verify() const {
+ // Consistency of memref type in function type.
+ if (llvm::empty(getOperands())) {
+ return emitOpError(
+ "requires at least a memref operand followed by 'rank' indices");
+ }
+ if (!getMemRef()->getType().isa<MemRefType>()) {
+ return emitOpError("requires a memref as first operand");
+ }
+ // Consistency of vector type in function type.
+ if (!getResult()->getType().isa<VectorType>()) {
+ return emitOpError("should have a vector result type in function type: "
+ "(memref_type [, elemental_type]) -> vector_type");
+ }
+ // Consistency of elemental types in memref and vector.
+ MemRefType memrefType = getMemRefType();
+ VectorType vectorType = getResultType();
+ if (memrefType.getElementType() != vectorType.getElementType())
+ return emitOpError(
+ "requires memref and vector types of the same elemental type");
+ // Consistency of number of input types.
+ auto optionalPaddingValue = getPaddingValue();
+ unsigned expectedNumOperands = Offsets::FirstIndexOffset +
+ memrefType.getRank() +
+ (optionalPaddingValue ? 1 : 0);
+ // Checks on the actual operands and their types.
+ if (getNumOperands() != expectedNumOperands) {
+ return emitOpError("expects " + Twine(expectedNumOperands) +
+ " operands to match the types");
+ }
+ // Consistency of padding value with vector type.
+ if (optionalPaddingValue) {
+ auto paddingValue = *optionalPaddingValue;
+ auto elementalType = paddingValue->getType();
+ if (!VectorType::isValidElementType(elementalType)) {
+ return emitOpError("requires valid padding vector elemental type");
+ }
+ if (elementalType != vectorType.getElementType()) {
+ return emitOpError(
+ "requires formal padding and vector of the same elemental type");
+ }
+ }
+ // Consistency of indices types.
+ unsigned numIndices = 0;
+ for (auto *idx : getIndices()) {
+ if (!idx->getType().isIndex()) {
+ return emitOpError(
+ "index to vector_transfer_read must have 'index' type");
+ }
+ ++numIndices;
+ }
+ if (numIndices != memrefType.getRank()) {
+ return emitOpError("requires at least a memref operand followed by " +
+ Twine(memrefType.getRank()) + " indices");
+ }
+
+ // Consistency of AffineMap attribute.
+ if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
+ return emitOpError("requires an AffineMapAttr named 'permutation_map'");
+ }
+ auto permutationMap = getPermutationMap();
+ if (!permutationMap.getRangeSizes().empty()) {
+ return emitOpError("requires an unbounded permutation_map");
+ }
+ if (permutationMap.getNumSymbols() != 0) {
+ return emitOpError("requires a permutation_map without symbols");
+ }
+ if (permutationMap.getNumInputs() != memrefType.getRank()) {
+ return emitOpError("requires a permutation_map with input dims of the "
+ "same rank as the memref type");
+ }
+ if (permutationMap.getNumResults() != vectorType.getRank()) {
+ return emitOpError("requires a permutation_map with result dims of the "
+ "same rank as the vector type");
+ }
+ return verifyPermutationMap(permutationMap,
+ [this](Twine t) { return emitOpError(t); });
+}
+
+//===----------------------------------------------------------------------===//
+// VectorTransferWriteOp
+//===----------------------------------------------------------------------===//
+void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
+ SSAValue *srcVector, SSAValue *dstMemRef,
+ ArrayRef<SSAValue *> dstIndices,
+ AffineMap permutationMap) {
+ result->addOperands({srcVector, dstMemRef});
+ result->addOperands(dstIndices);
+ result->addAttribute(getPermutationMapAttrName(),
+ builder->getAffineMapAttr(permutationMap));
+}
+
+llvm::iterator_range<Operation::operand_iterator>
+VectorTransferWriteOp::getIndices() {
+ auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+ auto end = begin + getMemRefType().getRank();
+ return {begin, end};
+}
+
+llvm::iterator_range<Operation::const_operand_iterator>
+VectorTransferWriteOp::getIndices() const {
+ auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+ auto end = begin + getMemRefType().getRank();
+ return {begin, end};
+}
+
+AffineMap VectorTransferWriteOp::getPermutationMap() const {
+ return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
+}
+
+void VectorTransferWriteOp::print(OpAsmPrinter *p) const {
+ *p << getOperationName();
+ *p << " " << *getVector();
+ *p << ", " << *getMemRef();
+ *p << ", ";
+ p->printOperands(getIndices());
+ p->printOptionalAttrDict(getAttrs());
+ Type indexType = (*getIndices().begin())->getType();
+ *p << " : ";
+ p->printType(getVectorType());
+ *p << ", ";
+ p->printType(getMemRefType());
+ for (unsigned r = 0, n = getMemRefType().getRank(); r < n; ++r) {
+ *p << ", ";
+ p->printType(indexType);
+ }
+}
+
+bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) {
+ SmallVector<OpAsmParser::OperandType, 8> parsedOperands;
+ SmallVector<Type, 8> types;
+
+ // Parsing; unlike the read op, there is no optional paddingValue here.
+ auto fail = parser->parseOperandList(parsedOperands) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonTypeList(types);
+ if (fail) {
+ return true;
+ }
+
+ // Resolution.
+ if (parsedOperands.size() != types.size()) {
+ parser->emitError(parser->getNameLoc(),
+ "requires number of operands and input types to match");
+ return true;
+ }
+ if (parsedOperands.size() < Offsets::FirstIndexOffset) {
+ parser->emitError(parser->getNameLoc(),
+ "requires at least vector and memref operands");
+ return true;
+ }
+ VectorType vectorType = types[Offsets::VectorOffset].dyn_cast<VectorType>();
+ if (!vectorType) {
+ parser->emitError(parser->getNameLoc(),
+ "Vector type expected for first input type");
+ return true;
+ }
+ MemRefType memrefType = types[Offsets::MemRefOffset].dyn_cast<MemRefType>();
+ if (!memrefType) {
+ parser->emitError(parser->getNameLoc(),
+ "MemRef type expected for second input type");
+ return true;
+ }
+
+ unsigned expectedNumOperands =
+ Offsets::FirstIndexOffset + memrefType.getRank();
+ if (parsedOperands.size() != expectedNumOperands) {
+ parser->emitError(parser->getNameLoc(),
+ "requires " + Twine(expectedNumOperands) + " operands");
+ return true;
+ }
+
+ OpAsmParser::OperandType vectorInfo = parsedOperands[Offsets::VectorOffset];
+ OpAsmParser::OperandType memrefInfo = parsedOperands[Offsets::MemRefOffset];
+ SmallVector<OpAsmParser::OperandType, 8> indexInfo{
+ parsedOperands.begin() + Offsets::FirstIndexOffset, parsedOperands.end()};
+ auto indexType = parser->getBuilder().getIndexType();
+ return parser->resolveOperand(vectorInfo, vectorType, result->operands) ||
+ parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
+ parser->resolveOperands(indexInfo, indexType, result->operands);
+}
+
+bool VectorTransferWriteOp::verify() const {
+ // Consistency of memref type in function type.
+ if (llvm::empty(getOperands())) {
+ return emitOpError(
+ "requires at least a memref operand followed by 'rank' indices");
+ }
+ if (!getMemRef()->getType().isa<MemRefType>()) {
+ return emitOpError("requires a memref first operand");
+ }
+ // Consistency of vector type in function type.
+ if (!getVector()->getType().isa<VectorType>()) {
+ return emitOpError("should have a vector input type in function type: "
+ "(vector_type, memref_type [, elemental_type]) -> ()");
+ }
+ // Consistency of elemental types in memref and vector.
+ MemRefType memrefType = getMemRefType();
+ VectorType vectorType = getVectorType();
+ if (memrefType.getElementType() != vectorType.getElementType())
+ return emitOpError(
+ "requires memref and vector types of the same elemental type");
+ // Consistency of number of input types.
+ unsigned expectedNumOperands =
+ Offsets::FirstIndexOffset + memrefType.getRank();
+ // Checks on the actual operands and their types.
+ if (getNumOperands() != expectedNumOperands) {
+ return emitOpError("expects " + Twine(expectedNumOperands) +
+ " operands to match the types");
+ }
+ // Consistency of indices types.
+ unsigned numIndices = 0;
+ for (auto *idx : getIndices()) {
+ if (!idx->getType().isIndex()) {
+ return emitOpError(
+ "index to vector_transfer_write must have 'index' type");
+ }
+ numIndices++;
+ }
+ if (numIndices != memrefType.getRank()) {
+ return emitOpError("requires at least a memref operand followed by " +
+ Twine(memrefType.getRank()) + " indices");
+ }
+
+ // Consistency of AffineMap attribute.
+ if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
+ return emitOpError("requires an AffineMapAttr named 'permutation_map'");
+ }
+ auto permutationMap = getPermutationMap();
+ if (!permutationMap.getRangeSizes().empty()) {
+ return emitOpError("requires an unbounded permutation_map");
+ }
+ if (permutationMap.getNumSymbols() != 0) {
+ return emitOpError("requires a permutation_map without symbols");
+ }
+ if (permutationMap.getNumInputs() != memrefType.getRank()) {
+ return emitOpError("requires a permutation_map with input dims of the "
+ "same rank as the memref type");
+ }
+ if (permutationMap.getNumResults() != vectorType.getRank()) {
+ return emitOpError("requires a permutation_map with result dims of the "
+ "same rank as the vector type");
+ }
+ return verifyPermutationMap(permutationMap,
+ [this](Twine t) { return emitOpError(t); });
+}
diff --git a/lib/Transforms/MaterializeVectors.cpp b/lib/Transforms/MaterializeVectors.cpp
index 60f0c06..400b4fd 100644
--- a/lib/Transforms/MaterializeVectors.cpp
+++ b/lib/Transforms/MaterializeVectors.cpp
@@ -89,6 +89,7 @@
using namespace mlir;
+using functional::makePtrDynCaster;
using functional::map;
static llvm::cl::list<int>
@@ -243,11 +244,11 @@
/// TODO(ntv): support a concrete AffineMap and compose with it.
/// TODO(ntv): these implementation details should be captured in a
/// vectorization trait at the op level directly.
-static SmallVector<MLValue *, 8>
-reindexAffineIndices(MLFuncBuilder *b, Type hwVectorType,
+static SmallVector<SSAValue *, 8>
+reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType,
ArrayRef<unsigned> hwVectorInstance,
ArrayRef<SSAValue *> memrefIndices) {
- auto vectorShape = hwVectorType.cast<VectorType>().getShape();
+ auto vectorShape = hwVectorType.getShape();
assert(hwVectorInstance.size() >= vectorShape.size());
unsigned numIndices = memrefIndices.size();
@@ -287,78 +288,21 @@
// TODO(ntv): support a concrete map and composition.
auto app = b->create<AffineApplyOp>(b->getInsertionPoint()->getLoc(),
affineMap, memrefIndices);
- unsigned numResults = app->getNumResults();
- SmallVector<MLValue *, 8> res;
- for (unsigned i = 0; i < numResults; ++i) {
- res.push_back(cast<MLValue>(app->getResult(i)));
- }
- return res;
+ return SmallVector<SSAValue *, 8>{app->getResults()};
}
-/// Returns the cloned operands of `opStmt` for the instance of
-/// `hwVectorInstance` when lowering from a super-vector type to
-/// `hwVectorType`. `hwVectorInstance` represents one particular instance of
-/// `hwVectorType` int the covering of the super-vector type. For a more
-/// detailed description of the problem, see the description of
-/// reindexAffineIndices.
-static SmallVector<MLValue *, 8>
-cloneAndUnrollOperands(OperationStmt *opStmt, Type hwVectorType,
- ArrayRef<unsigned> hwVectorInstance,
- DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
- using functional::map;
-
- // For Ops that are not vector_transfer_read/vector_transfer_write we can just
- // substitute and be done.
- if (!isaVectorTransferRead(*opStmt) && !isaVectorTransferWrite(*opStmt)) {
- return map([substitutionsMap](
- SSAValue *v) { return substitute(v, *substitutionsMap); },
- opStmt->getOperands());
- }
-
- // TODO(ntv): this error-prone boilerplate can be removed once we have a
- // proper Op for vectr_transfer.
- unsigned offset = 0;
- unsigned numIndices = 0;
- SmallVector<MLValue *, 8> res;
- auto operands = opStmt->getOperands();
- if (isaVectorTransferRead(*opStmt)) {
- offset = 1;
- numIndices = opStmt->getNumOperands() - 1;
- } else if (isaVectorTransferWrite(*opStmt)) {
- offset = 2;
- numIndices = opStmt->getNumOperands() - 2;
- }
- // Copy as-is the [optional valueToStore], memref.
- for (unsigned i = 0; i < offset; ++i) {
- res.push_back(substitute(*(operands.begin() + i), *substitutionsMap));
- }
-
- MLFuncBuilder b(opStmt);
- // TODO(ntv): indices extraction is brittle and unsafe before we have an Op.
- SmallVector<SSAValue *, 8> indices;
- for (auto it = operands.begin() + offset; it != operands.end(); ++it) {
- indices.push_back(*it);
- }
- auto affineValues =
- reindexAffineIndices(&b, hwVectorType, hwVectorInstance, indices);
- res.append(affineValues.begin(), affineValues.end());
-
- return res;
-}
-
-// Returns attributes with the following substitutions applied:
-// - splat of `superVectorType` is replaced by splat of `hwVectorType`.
-// TODO(ntv): add more substitutions on a per-need basis.
-static SmallVector<NamedAttribute, 2>
+/// Returns attributes with the following substitutions applied:
+/// - splat of `superVectorType` is replaced by splat of `hwVectorType`.
+/// TODO(ntv): add more substitutions on a per-need basis.
+static SmallVector<NamedAttribute, 1>
materializeAttributes(OperationStmt *opStmt, VectorType superVectorType,
VectorType hwVectorType) {
- SmallVector<NamedAttribute, 2> res;
+ SmallVector<NamedAttribute, 1> res;
for (auto a : opStmt->getAttrs()) {
auto splat = a.second.dyn_cast<SplatElementsAttr>();
bool splatOfSuperVectorType = splat && (splat.getType() == superVectorType);
if (splatOfSuperVectorType) {
- auto attr = SplatElementsAttr::get(hwVectorType.cast<VectorType>(),
- splat.getValue());
+ auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue());
res.push_back(NamedAttribute(a.first, attr));
} else {
res.push_back(a);
@@ -367,6 +311,70 @@
return res;
}
+/// Creates an instantiated version of `opStmt`.
+/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no
+/// affine reindexing. Just substitute their SSAValue* operands and be done. For
+/// this case the actual instance is irrelevant. Just use the SSA values in
+/// substitutionsMap.
+static OperationStmt *
+instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType superVectorType,
+ VectorType hwVectorType,
+ DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
+ assert(!opStmt->isa<VectorTransferReadOp>() &&
+ "Should call the function specialized for VectorTransferReadOp");
+ assert(!opStmt->isa<VectorTransferWriteOp>() &&
+ "Should call the function specialized for VectorTransferWriteOp");
+ auto operands =
+ map([substitutionsMap](
+ SSAValue *v) { return substitute(v, *substitutionsMap); },
+ opStmt->getOperands());
+ return b->createOperation(
+ opStmt->getLoc(), opStmt->getName(), operands, {hwVectorType},
+ materializeAttributes(opStmt, superVectorType, hwVectorType));
+}
+
+/// Creates an instantiated version of `read` for the instance of
+/// `hwVectorInstance` when lowering from a super-vector type to
+/// `hwVectorType`. `hwVectorInstance` represents one particular instance of
+/// `hwVectorType` in the covering of the super-vector type. For a more
+/// detailed description of the problem, see the description of
+/// reindexAffineIndices.
+static OperationStmt *
+instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
+ VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
+ DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
+ SmallVector<SSAValue *, 8> indices =
+ map(makePtrDynCaster<SSAValue>(), read->getIndices());
+ auto affineIndices =
+ reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
+ auto cloned = b->create<VectorTransferReadOp>(
+ read->getLoc(), hwVectorType, read->getMemRef(), affineIndices,
+ makePermutationMap(read->getMemRefType(), hwVectorType),
+ read->getPaddingValue());
+ return cast<OperationStmt>(cloned->getOperation());
+}
+
+/// Creates an instantiated version of `write` for the instance of
+/// `hwVectorInstance` when lowering from a super-vector type to
+/// `hwVectorType`. `hwVectorInstance` represents one particular instance of
+/// `hwVectorType` in the covering of the super-vector type. For a more
+/// detailed description of the problem, see the description of
+/// reindexAffineIndices.
+static OperationStmt *
+instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write,
+ VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
+ DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
+ SmallVector<SSAValue *, 8> indices =
+ map(makePtrDynCaster<SSAValue>(), write->getIndices());
+ auto affineIndices =
+ reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
+ auto cloned = b->create<VectorTransferWriteOp>(
+ write->getLoc(), substitute(write->getVector(), *substitutionsMap),
+ write->getMemRef(), affineIndices,
+ makePermutationMap(write->getMemRefType(), hwVectorType));
+ return cast<OperationStmt>(cloned->getOperation());
+}
+
/// Returns `true` if stmt instance is properly cloned and inserted, false
/// otherwise.
/// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of
@@ -386,45 +394,52 @@
/// type, all operands are substituted according to `substitutions`. Thanks
/// to the topological order of a slice, the substitution is always
/// possible.
-static bool cloneAndInsertHardwareVectorInstance(Statement *stmt,
- MaterializationState *state) {
- LLVM_DEBUG(dbgs() << "\nclone" << *stmt);
- if (auto *opStmt = dyn_cast<OperationStmt>(stmt)) {
- // TODO(ntv): Is it worth considering an OperationStmt.clone operation
- // which changes the type so we can promote an OperationStmt with less
- // boilerplate?
- assert(opStmt->getNumResults() <= 1 && "NYI: opStmt has > 1 results");
- auto operands = cloneAndUnrollOperands(opStmt, state->hwVectorType,
- state->hwVectorInstance,
- state->substitutionsMap);
- MLFuncBuilder b(stmt);
- if (opStmt->getNumResults() == 0) {
- // vector_transfer_write
- b.createOperation(stmt->getLoc(), opStmt->getName(), operands, {},
- materializeAttributes(opStmt, state->superVectorType,
- state->hwVectorType));
- } else {
- // vector_transfer_read
- auto *cloned = b.createOperation(
- stmt->getLoc(), opStmt->getName(), operands, {state->hwVectorType},
- materializeAttributes(opStmt, state->superVectorType,
- state->hwVectorType));
- state->substitutionsMap->insert(std::make_pair(
- cast<MLValue>(opStmt->getResult(0)),
- cast<MLValue>(cast<OperationStmt>(cloned)->getResult(0))));
- }
- return false;
- }
+static bool instantiateMaterialization(Statement *stmt,
+ MaterializationState *state) {
+ LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt);
+ // Fail hard and wake up when needed.
if (isa<ForStmt>(stmt)) {
- // Fail hard and wake up when needed.
stmt->emitError("NYI path ForStmt");
return true;
}
// Fail hard and wake up when needed.
- stmt->emitError("NYI path IfStmt");
- return true;
+ if (isa<IfStmt>(stmt)) {
+ stmt->emitError("NYI path IfStmt");
+ return true;
+ }
+
+ // Create a builder here for unroll-and-jam effects.
+ MLFuncBuilder b(stmt);
+ auto *opStmt = cast<OperationStmt>(stmt);
+ if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) {
+ instantiate(&b, &*write, state->hwVectorType, state->hwVectorInstance,
+ state->substitutionsMap);
+ return false;
+ } else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) {
+ auto *clone = instantiate(&b, &*read, state->hwVectorType,
+ state->hwVectorInstance, state->substitutionsMap);
+ state->substitutionsMap->insert(std::make_pair(
+ cast<MLValue>(read->getResult()), cast<MLValue>(clone->getResult(0))));
+ return false;
+ }
+ // The only ops with 0 results reaching this point must, by construction, be
+ // VectorTransferWriteOps and have been caught above. Ops with >= 2 results
+ // are not yet supported, so only the single-result case remains.
+ if (opStmt->getNumResults() != 1) {
+ stmt->emitError("NYI: ops with != 1 results");
+ return true;
+ }
+ if (opStmt->getResult(0)->getType() != state->superVectorType) {
+ stmt->emitError("Op does not return a supervector.");
+ return true;
+ }
+ auto *clone = instantiate(&b, opStmt, state->superVectorType,
+ state->hwVectorType, state->substitutionsMap);
+ state->substitutionsMap->insert(std::make_pair(
+ cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0))));
+ return false;
}
/// Takes a slice and rewrites the operations in it so that occurrences
@@ -463,15 +478,22 @@
scopedState.substitutionsMap = &substitutionMap;
// slice are topologically sorted, we can just clone them in order.
for (auto *stmt : *slice) {
- auto fail = cloneAndInsertHardwareVectorInstance(stmt, &scopedState);
+ auto fail = instantiateMaterialization(stmt, &scopedState);
(void)fail;
assert(!fail && "Unhandled super-vector materialization failure");
}
}
+
+ LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
+ LLVM_DEBUG(
+ cast<OperationStmt>((*slice)[0])->getOperationFunction()->print(dbgs()));
+
// slice are topologically sorted, we can just erase them in reverse
// order. Reverse iterator does not just work simply with an operator*
// dereference.
for (int idx = slice->size() - 1; idx >= 0; --idx) {
+ LLVM_DEBUG(dbgs() << "\nErase: ");
+ LLVM_DEBUG((*slice)[idx]->print(dbgs()));
(*slice)[idx]->erase();
}
}
@@ -497,25 +519,21 @@
const SetVector<OperationStmt *> &terminators,
MaterializationState *state) {
DenseSet<Statement *> seen;
- for (auto terminator : terminators) {
- LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *terminator);
-
+ for (auto *term : terminators) {
// Short-circuit test, a given terminator may have been reached by some
// other previous transitive use-def chains.
- if (seen.count(terminator) > 0) {
+ if (seen.count(term) > 0) {
continue;
}
- // Terminators are vector_transfer_write with 0 results by construction atm.
- assert(isaVectorTransferWrite(*terminator) && "");
- assert(terminator->getNumResults() == 0 &&
- "NYI: terminators must have 0 results");
+ auto terminator = term->cast<VectorTransferWriteOp>();
+ LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *term);
// Get the transitive use-defs starting from terminator, limited to the
// current enclosing scope of the terminator. See the top of the function
// Note for the justification of this restriction.
// TODO(ntv): relax scoping constraints.
- auto *enclosingScope = terminator->getParentStmt();
+ auto *enclosingScope = term->getParentStmt();
auto keepIfInSameScope = [enclosingScope](Statement *stmt) {
assert(stmt && "NULL stmt");
if (!enclosingScope) {
@@ -525,7 +543,7 @@
return properlyDominates(*enclosingScope, *stmt);
};
SetVector<Statement *> slice =
- getSlice(terminator, keepIfInSameScope, keepIfInSameScope);
+ getSlice(term, keepIfInSameScope, keepIfInSameScope);
assert(!slice.empty());
// Sanity checks: transitive slice must be completely disjoint from
@@ -540,10 +558,9 @@
// Emit the current slice.
// Set scoped super-vector and corresponding hw vector types.
- state->superVectorType =
- terminator->getOperand(0)->getType().cast<VectorType>();
+ state->superVectorType = terminator->getVectorType();
assert((state->superVectorType.getElementType() ==
- Type::getF32(terminator->getContext())) &&
+ Type::getF32(term->getContext())) &&
"Only f32 supported for now");
state->hwVectorType = VectorType::get(
state->hwVectorSize, state->superVectorType.getElementType());
@@ -568,7 +585,7 @@
// super-vector of subVectorType.
auto filter = [subVectorType](const Statement &stmt) {
const auto &opStmt = cast<OperationStmt>(stmt);
- if (!isaVectorTransferWrite(opStmt)) {
+ if (!opStmt.isa<VectorTransferWriteOp>()) {
return false;
}
return matcher::operatesOnStrictSuperVectors(opStmt, subVectorType);
diff --git a/lib/Transforms/Vectorize.cpp b/lib/Transforms/Vectorize.cpp
index 5a408b0..e4822c2 100644
--- a/lib/Transforms/Vectorize.cpp
+++ b/lib/Transforms/Vectorize.cpp
@@ -541,6 +541,7 @@
#define DEBUG_TYPE "early-vect"
using functional::apply;
+using functional::makePtrDynCaster;
using functional::map;
using functional::ScopeGuard;
using llvm::dbgs;
@@ -820,23 +821,15 @@
/// TODO(andydavis,bondhugula,ntv):
/// 1. generalize to support padding semantics and offsets within vector type.
static OperationStmt *
-createVectorTransferRead(MLFuncBuilder *b, Location loc, VectorType vectorType,
+createVectorTransferRead(OperationStmt *loadOp, VectorType vectorType,
SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices) {
- SmallVector<SSAValue *, 8> operands;
- operands.reserve(1 + srcIndices.size());
- operands.insert(operands.end(), srcMemRef);
- operands.insert(operands.end(), srcIndices.begin(), srcIndices.end());
- OperationState opState(b->getContext(), loc, kVectorTransferReadOpName,
- operands, vectorType);
- return b->createOperation(opState);
-}
-
-/// Unwraps a pointer type to another type (possibly the same).
-/// Used in particular to allow easier compositions of
-/// llvm::iterator_range<ForStmt::operand_iterator> types.
-template <typename T, typename ToType = T>
-static std::function<ToType *(T *)> unwrapPtr() {
- return [](T *val) { return dyn_cast<ToType>(val); };
+ auto memRefType = srcMemRef->getType().cast<MemRefType>();
+ MLFuncBuilder b(loadOp);
+ // TODO(ntv): neutral for noneffective padding.
+ auto transfer = b.create<VectorTransferReadOp>(
+ loadOp->getLoc(), vectorType, srcMemRef, srcIndices,
+ makePermutationMap(memRefType, vectorType));
+ return cast<OperationStmt>(transfer->getOperation());
}
/// Handles the vectorization of load and store MLIR operations.
@@ -865,15 +858,14 @@
// Materialize a MemRef with 1 vector.
auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());
- MLFuncBuilder b(opStmt);
// For now, vector_transfers must be aligned, operate only on indices with an
// identity subset of AffineMap and do not change layout.
// TODO(ntv): increase the expressiveness power of vector_transfer operations
// as needed by various targets.
if (opStmt->template isa<LoadOp>()) {
auto *transfer = createVectorTransferRead(
- &b, opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
- map(unwrapPtr<SSAValue>(), memoryOp->getIndices()));
+ opStmt, vectorType, memoryOp->getMemRef(),
+ map(makePtrDynCaster<SSAValue>(), memoryOp->getIndices()));
state->registerReplacement(opStmt, transfer);
} else {
state->registerTerminator(opStmt);
@@ -1008,7 +1000,7 @@
auto *splat = cast<OperationStmt>(b.createOperation(
loc, constantOpStmt->getName(), {}, {vectorType},
{make_pair(Identifier::get("value", b.getContext()), attr)}));
- return cast<MLValue>(cast<OperationStmt>(splat)->getResult(0));
+ return cast<MLValue>(splat->getResult(0));
}
/// Returns a uniqu'ed VectorType.
@@ -1106,17 +1098,17 @@
static OperationStmt *createVectorTransferWrite(OperationStmt *storeOp,
VectorizationState *state) {
auto store = storeOp->cast<StoreOp>();
+ auto *memRef = store->getMemRef();
+ auto memRefType = memRef->getType().cast<MemRefType>();
auto *value = store->getValueToStore();
- auto indices = map(unwrapPtr<SSAValue>(), store->getIndices());
- SmallVector<SSAValue *, 8> operands;
- operands.reserve(1 + 1 + indices.size());
- operands.insert(operands.end(), vectorizeOperand(value, storeOp, state));
- operands.insert(operands.end(), store->getMemRef());
- operands.insert(operands.end(), indices.begin(), indices.end());
+ auto *vectorValue = vectorizeOperand(value, storeOp, state);
+ auto vectorType = vectorValue->getType().cast<VectorType>();
+ auto indices = map(makePtrDynCaster<SSAValue>(), store->getIndices());
MLFuncBuilder b(storeOp);
- OperationState opState(b.getContext(), storeOp->getLoc(),
- kVectorTransferWriteOpName, operands, {});
- return b.createOperation(opState);
+ auto transfer = b.create<VectorTransferWriteOp>(
+ storeOp->getLoc(), vectorValue, memRef, indices,
+ makePermutationMap(memRefType, vectorType));
+ return cast<OperationStmt>(transfer->getOperation());
}
/// Encodes OperationStmt-specific behavior for vectorization. In general we
@@ -1134,9 +1126,9 @@
// Sanity checks.
assert(!stmt->isa<LoadOp>() &&
"all loads must have already been fully vectorized independently");
- assert(!isaVectorTransferRead(*stmt) &&
+ assert(!stmt->isa<VectorTransferReadOp>() &&
"vector_transfer_read cannot be further vectorized");
- assert(!isaVectorTransferWrite(*stmt) &&
+ assert(!stmt->isa<VectorTransferWriteOp>() &&
"vector_transfer_write cannot be further vectorized");
if (stmt->isa<StoreOp>()) {