[MLIR] Add VectorTransferOps
This CL implements and uses VectorTransferOps in lieu of the former custom
call op. Tests are updated accordingly.
VectorTransferOps come in two flavors: VectorTransferReadOp and
VectorTransferWriteOp.
VectorTransferOps can be thought of as a backend-independent
pseudo-op/library call that needs to be legalized into explicit MLIR
(whiteboxed) before it can be lowered to backend-dependent IR.
Note that the current implementation does not yet support a real permutation
map; proper support will come in a follow-up CL.
VectorTransferReadOp
====================
VectorTransferReadOp performs a blocking read from a scalar memref
location into a super-vector of the same elemental type. The operation is
called 'read', as opposed to 'load', because the super-vector granularity
is generally not representable with a single hardware register. As a
consequence, memory transfers will generally be required when lowering
VectorTransferReadOp. A VectorTransferReadOp is thus a mid-level abstraction
that supports super-vectorization with non-effecting padding for
full-tile-only code.
A vector transfer read has semantics similar to a vector load, with additional
support for:
1. an optional value of the elemental type of the MemRef. This value
supports non-effecting padding and is inserted in places where the
vector read exceeds the MemRef bounds. If the value is not specified,
the access is statically guaranteed to be within bounds;
2. an attribute of type AffineMap to specify a slice of the original
MemRef access and its transposition into the super-vector shape. The
permutation_map is an unbounded AffineMap that must represent a
permutation from the MemRef dim space projected onto the vector dim
space.
Example:
```mlir
%src = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
...
%val = `ssa-value` : f32
// let %i, %j, %k, %l be ssa-values of type index
%v0 = vector_transfer_read %src, %i, %j, %k, %l
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(memref<?x?x?x?xf32>, index, index, index, index) ->
vector<16x32x64xf32>
%v1 = vector_transfer_read %src, %i, %j, %k, %l, %val
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(memref<?x?x?x?xf32>, index, index, index, index, f32) ->
vector<16x32x64xf32>
```
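In the examples above, the permutation_map (d0, d1, d2, d3) -> (d3, d1, d2)
means that vector dimension 0 (size 16) is read along memref dimension d3,
vector dimension 1 (size 32) along d1, and vector dimension 2 (size 64) along
d2; d0 does not appear in the results, so the read does not vary along it.
For contrast, here is a minimal sketch of the full-rank case (assuming a
hypothetical 3-D memref %src3d), where an identity map performs no
transposition:
```mlir
// Hypothetical full-rank variant: with an identity permutation_map the
// vector dimensions line up with the memref dimensions in order.
%v2 = vector_transfer_read %src3d, %i, %j, %k
  {permutation_map: (d0, d1, d2) -> (d0, d1, d2)} :
  (memref<?x?x?xf32>, index, index, index) -> vector<16x32x64xf32>
```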
VectorTransferWriteOp
=====================
VectorTransferWriteOp performs a blocking write from a super-vector to
a scalar memref of the same elemental type. The operation is
called 'write', as opposed to 'store', because the super-vector
granularity is generally not representable with a single hardware register. As
a consequence, memory transfers will generally be required when lowering
VectorTransferWriteOp. A VectorTransferWriteOp is thus a mid-level
abstraction that supports super-vectorization with non-effecting padding
for full-tile-only code.
A vector transfer write has semantics similar to a vector store, with
additional support for handling out-of-bounds situations.
Example:
```mlir
%src = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
%val = `ssa-value` : vector<16x32x64xf32>
// let %i, %j, %k, %l be ssa-values of type index
vector_transfer_write %val, %src, %i, %j, %k, %l
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index)
```
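Taken together, a matching read/write pair sketches a super-vector tile copy.
This is a sketch only: %dst is a hypothetical second memref of the same shape,
and both ops use the minor-identity map that makePermutationMap in the diff
below constructs (leading memref dimension projected out, trailing dimensions
mapped in order):
```mlir
// Hypothetical tile copy: read a super-vector from %src and write it
// back to %dst at the same indices, without transposition.
%tile = vector_transfer_read %src, %i, %j, %k, %l
  {permutation_map: (d0, d1, d2, d3) -> (d1, d2, d3)} :
  (memref<?x?x?x?xf32>, index, index, index, index) -> vector<16x32x64xf32>
vector_transfer_write %tile, %dst, %i, %j, %k, %l
  {permutation_map: (d0, d1, d2, d3) -> (d1, d2, d3)} :
  (vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index)
```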
PiperOrigin-RevId: 223873234
diff --git a/lib/Analysis/VectorAnalysis.cpp b/lib/Analysis/VectorAnalysis.cpp
index 75f6229..9c2160c 100644
--- a/lib/Analysis/VectorAnalysis.cpp
+++ b/lib/Analysis/VectorAnalysis.cpp
@@ -18,6 +18,7 @@
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
+#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/STLExtras.h"
@@ -28,14 +29,6 @@
using namespace mlir;
-bool mlir::isaVectorTransferRead(const OperationStmt &stmt) {
- return stmt.getName().getStringRef().str() == kVectorTransferReadOpName;
-}
-
-bool mlir::isaVectorTransferWrite(const OperationStmt &stmt) {
- return stmt.getName().getStringRef().str() == kVectorTransferWriteOpName;
-}
-
Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(ArrayRef<int> superShape,
ArrayRef<int> subShape) {
if (superShape.size() < subShape.size()) {
@@ -83,6 +76,20 @@
return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
}
+AffineMap mlir::makePermutationMap(MemRefType memrefType,
+ VectorType vectorType) {
+ unsigned memRefRank = memrefType.getRank();
+ unsigned vectorRank = vectorType.getRank();
+ assert(memRefRank >= vectorRank && "Broadcast not supported");
+ unsigned offset = memRefRank - vectorRank;
+ SmallVector<AffineExpr, 4> perm;
+ perm.reserve(memRefRank);
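+  // Map the trailing vectorRank memref dimensions, in order, onto the
+  // vector dimensions; leading memref dimensions are projected out.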
+ for (unsigned i = 0; i < vectorRank; ++i) {
+ perm.push_back(getAffineDimExpr(offset + i, memrefType.getContext()));
+ }
+ return AffineMap::get(memRefRank, 0, perm, {});
+}
+
bool mlir::matcher::operatesOnStrictSuperVectors(const OperationStmt &opStmt,
VectorType subVectorType) {
// First, extract the vector type and distinguish between:
@@ -96,15 +103,11 @@
/// do not have to special case. Maybe a trait, or just a method, unclear atm.
bool mustDivide = false;
VectorType superVectorType;
- if (isaVectorTransferRead(opStmt)) {
- superVectorType = opStmt.getResult(0)->getType().cast<VectorType>();
+ if (auto read = opStmt.dyn_cast<VectorTransferReadOp>()) {
+ superVectorType = read->getResultType();
mustDivide = true;
- } else if (isaVectorTransferWrite(opStmt)) {
- // TODO(ntv): if vector_transfer_write had store-like semantics we could
- // have written something similar to:
- // auto store = storeOp->cast<StoreOp>();
- // auto *value = store->getValueToStore();
- superVectorType = opStmt.getOperand(0)->getType().cast<VectorType>();
+ } else if (auto write = opStmt.dyn_cast<VectorTransferWriteOp>()) {
+ superVectorType = write->getVectorType();
mustDivide = true;
} else if (opStmt.getNumResults() == 0) {
assert(opStmt.isa<ReturnOp>() &&