Extract vector_transfer_* Ops into a SuperVectorDialect.
From the beginning, vector_transfer_read and vector_transfer_write opreations
were intended as a mid-level vectorization abstraction. In particular, they
are lowered to the StandardOps dialect before further processing. As such, it
does not make sense to keep them at the same level as StandardOps. Introduce
the new SuperVectorOps dialect and move vector_transfer_* operations there.
This will be used as a testbed for the generic lowering/legalization pass.
PiperOrigin-RevId: 225554492
diff --git a/g3doc/Dialects/SuperVector.md b/g3doc/Dialects/SuperVector.md
new file mode 100644
index 0000000..a8dfcb4
--- /dev/null
+++ b/g3doc/Dialects/SuperVector.md
@@ -0,0 +1,156 @@
+# SuperVector Dialect
+
+This dialect provides mid-level abstraction for the MLIR super-vectorizer.
+
+[TOC]
+
+## Operations {#operations}
+
+### Vector transfers {#vector-transfers}
+
+#### `vector_transfer_read` operation {#'vector_transfer_read'-operation}
+
+Syntax:
+
+``` {.ebnf}
+operation ::= ssa-id `=` `vector_transfer_read` ssa-use-list `{` attribute-entry `} :` function-type
+```
+
+Examples:
+
+```mlir {.mlir}
+// Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32> and
+// pad with %f0 to handle the boundary case:
+%f0 = constant 0.0f : f32
+for %i0 = 0 to %0 {
+ for %i1 = 0 to %1 step 256 {
+ for %i2 = 0 to %2 step 32 {
+ %v = vector_transfer_read %A, %i0, %i1, %i2, %f0
+ {permutation_map: (d0, d1, d2) -> (d2, d1)} :
+ (memref<?x?x?xf32>, index, index, f32) -> vector<32x256xf32>
+}}}
+
+// Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into
+// vector<128xf32>. The underlying implementation will require a 1-D vector
+// broadcast:
+for %i0 = 0 to %0 {
+ for %i1 = 0 to %1 {
+ %3 = vector_transfer_read %A, %i0, %i1
+ {permutation_map: (d0, d1) -> (0)} :
+ (memref<?x?xf32>, index, index) -> vector<128xf32>
+ }
+}
+```
+
+The `vector_transfer_read` performs a blocking read from a slice within a scalar
+[MemRef](#memref-type) supplied as its first operand into a
+[vector](#vector-type) of the same elemental type. The slice is further defined
+by a full-rank index within the MemRef, supplied as the operands `2 .. 1 +
+rank(memref)`. The permutation_map [attribute](#attributes) is an
+[affine-map](#affine-maps) which specifies the transposition on the slice to
+match the vector shape. The size of the slice is specified by the size of the
+vector, given as the return type. Optionally, an `ssa-value` of the same
+elemental type as the MemRef is provided as the last operand to specify padding
+in the case of out-of-bounds accesses. Absence of the optional padding value
+signifies the `vector_transfer_read` is statically guaranteed to remain within
+the MemRef bounds. This operation is called 'read' by opposition to 'load'
+because the super-vector granularity is generally not representable with a
+single hardware register. A `vector_transfer_read` is thus a mid-level
+abstraction that supports super-vectorization with non-effecting padding for
+full-tile-only code.
+
+More precisely, let's dive deeper into the permutation_map for the following :
+
+```mlir {.mlir}
+vector_transfer_read %A, %expr1, %expr2, %expr3, %expr4
+ { permutation_map : (d0,d1,d2,d3) -> (d2,0,d0) } :
+ (memref<?x?x?x?xf32>, index, index, index, index) -> vector<3x4x5xf32>
+```
+
+This operation always reads a slice starting at `%A[%expr1, %expr2, %expr3,
+%expr4]`. The size of the slice is 3 along d2 and 5 along d0, so the slice is:
+`%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]`
+
+That slice needs to be read into a `vector<3x4x5xf32>`. Since the permutation
+map is not full rank, there must be a broadcast along vector dimension `1`.
+
+A notional lowering of vector_transfer_read could generate code resembling:
+
+```mlir {.mlir}
+// %expr1, %expr2, %expr3, %expr4 defined before this point
+%tmp = alloc() : vector<3x4x5xf32>
+%view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>>
+for %i = 0 to 3 {
+ for %j = 0 to 4 {
+ for %k = 0 to 5 {
+ %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
+ store %tmp[%i, %j, %k] : vector<3x4x5xf32>
+}}}
+%c0 = constant 0 : index
+%vec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
+```
+
+On a GPU one could then map `i`, `j`, `k` to blocks and threads. Notice that the
+temporary storage footprint is `3 * 5` values but `3 * 4 * 5` values are
+actually transferred betwen `%A` and `%tmp`.
+
+Alternatively, if a notional vector broadcast instruction were available, the
+lowered code would resemble:
+
+```mlir {.mlir}
+// %expr1, %expr2, %expr3, %expr4 defined before this point
+%tmp = alloc() : vector<3x4x5xf32>
+%view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>>
+for %i = 0 to 3 {
+ for %k = 0 to 5 {
+ %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
+ store %tmp[%i, 0, %k] : vector<3x4x5xf32>
+}}
+%c0 = constant 0 : index
+%tmpvec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
+%vec = broadcast %tmpvec, 1 : vector<3x4x5xf32>
+```
+
+where `broadcast` broadcasts from element 0 to all others along the specified
+dimension. This time, the temporary storage footprint is `3 * 5` values which is
+the same amount of data as the `3 * 5` values transferred. An additional `1`
+broadcast is required. On a GPU this broadcast could be implemented using a
+warp-shuffle if loop `j` were mapped to `threadIdx.x`.
+
+#### `vector_transfer_write` operation {#'vector_transfer_write'-operation}
+
+Syntax:
+
+``` {.ebnf}
+operation ::= `vector_transfer_write` ssa-use-list `{` attribute-entry `} :` vector-type ', ' memref-type ', ' index-type-list
+```
+
+Examples:
+
+```mlir {.mlir}
+// write vector<16x32x64xf32> into the slice `%A[%i0, %i1:%i1+32, %i2:%i2+64, %i3:%i3+16]`:
+for %i0 = 0 to %0 {
+ for %i1 = 0 to %1 step 32 {
+ for %i2 = 0 to %2 step 64 {
+ for %i3 = 0 to %3 step 16 {
+ %val = `ssa-value` : vector<16x32x64xf32>
+ vector_transfer_write %val, %A, %i0, %i1, %i2, %i3
+ {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
+ vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index
+}}}}
+```
+
+The `vector_transfer_write` performs a blocking write from a
+[vector](#vector-type), supplied as its first operand, into a slice within a
+scalar [MemRef](#memref-type) of the same elemental type, supplied as its second
+operand. The slice is further defined by a full-rank index within the MemRef,
+supplied as the operands `3 .. 2 + rank(memref)`. The permutation_map
+[attribute](#attributes) is an [affine-map](#affine-maps) which specifies the
+transposition on the slice to match the vector shape. The size of the slice is
+specified by the size of the vector. This operation is called 'write' by
+opposition to 'store' because the super-vector granularity is generally not
+representable with a single hardware register. A `vector_transfer_write` is thus
+a mid-level abstraction that supports super-vectorization with non-effecting
+padding for full-tile-only code. It is the responsibility of
+`vector_transfer_write`'s implementation to ensure the memory writes are valid.
+Different lowerings may be pertinent depending on the hardware support.
diff --git a/g3doc/LangRef.md b/g3doc/LangRef.md
index 3548b1f..4737a03 100644
--- a/g3doc/LangRef.md
+++ b/g3doc/LangRef.md
@@ -1965,156 +1965,20 @@
They must either have the same rank, or one may be an unknown rank. The
operation is invalid if converting to a mismatching constant dimension.
-#### 'vector_transfer_read' operation {#'vector_transfer_read'-operation}
+## Dialects
-Syntax:
+MLIR supports multiple dialects containing a set of operations and types defined
+together, potentially outside of the main tree. Dialects are produced and
+consumed by certain passes. MLIR can be converted between different dialects by
+a conversion pass.
-``` {.ebnf}
-operation ::= ssa-id `=` `vector_transfer_read` ssa-use-list `{` attribute-entry `} :` function-type
-```
+Currently, MLIR supports the following dialects:
-Examples:
+* [Standard dialect](#standard-operations)
+* [SuperVector dialect](Dialects/SuperVector.md)
+* [TensorFlow dialect](#tensorflow-operations)
-```mlir {.mlir}
-// Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32> and
-// pad with %f0 to handle the boundary case:
-%f0 = constant 0.0f : f32
-for %i0 = 0 to %0 {
- for %i1 = 0 to %1 step 256 {
- for %i2 = 0 to %2 step 32 {
- %v = vector_transfer_read %A, %i0, %i1, %i2, %f0
- {permutation_map: (d0, d1, d2) -> (d2, d1)} :
- (memref<?x?x?xf32>, index, index, f32) -> vector<32x256xf32>
-}}}
-
-// Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into
-// vector<128xf32>. The underlying implementation will require a 1-D vector
-// broadcast:
-for %i0 = 0 to %0 {
- for %i1 = 0 to %1 {
- %3 = vector_transfer_read %A, %i0, %i1
- {permutation_map: (d0, d1) -> (0)} :
- (memref<?x?xf32>, index, index) -> vector<128xf32>
- }
-}
-```
-
-The `vector_transfer_read` performs a blocking read from a slice within a scalar
-[MemRef](#memref-type) supplied as its first operand into a
-[vector](#vector-type) of the same elemental type. The slice is further defined
-by a full-rank index within the MemRef, supplied as the operands `2 .. 1 +
-rank(memref)`. The permutation_map [attribute](#attributes) is an
-[affine-map](#affine-maps) which specifies the transposition on the slice to
-match the vector shape. The size of the slice is specified by the size of the
-vector, given as the return type. Optionally, an `ssa-value` of the same
-elemental type as the MemRef is provided as the last operand to specify padding
-in the case of out-of-bounds accesses. Absence of the optional padding value
-signifies the `vector_transfer_read` is statically guaranteed to remain within
-the MemRef bounds. This operation is called 'read' by opposition to 'load'
-because the super-vector granularity is generally not representable with a
-single hardware register. A `vector_transfer_read` is thus a mid-level
-abstraction that supports super-vectorization with non-effecting padding for
-full-tile-only code.
-
-More precisely, let's dive deeper into the permutation_map for the following :
-
-```mlir {.mlir}
-vector_transfer_read %A, %expr1, %expr2, %expr3, %expr4
- { permutation_map : (d0,d1,d2,d3) -> (d2,0,d0) } :
- (memref<?x?x?x?xf32>, index, index, index, index) -> vector<3x4x5xf32>
-```
-
-This operation always reads a slice starting at
-`%A[%expr1, %expr2, %expr3, %expr4]`.
-The size of the slice is 3 along d2 and 5 along d0, so the slice is:
-`%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]`
-
-That slice needs to be read into a `vector<3x4x5xf32>`.
-Since the permutation map is not full rank, there must be a broadcast along
-vector dimension `1`.
-
-A notional lowering of vector_transfer_read could generate code resembling:
-
-```mlir {.mlir}
-// %expr1, %expr2, %expr3, %expr4 defined before this point
-%tmp = alloc() : vector<3x4x5xf32>
-%view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>>
-for %i = 0 to 3 {
- for %j = 0 to 4 {
- for %k = 0 to 5 {
- %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
- store %tmp[%i, %j, %k] : vector<3x4x5xf32>
-}}}
-%c0 = constant 0 : index
-%vec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
-```
-
-On a GPU one could then map `i`, `j`, `k` to blocks and threads. Notice that the
-temporary storage footprint is `3 * 5` values but `3 * 4 * 5` values are
-actually transferred betwen `%A` and `%tmp`.
-
-Alternatively, if a notional vector broadcast instruction were available, the
-lowered code would resemble:
-
-```mlir {.mlir}
-// %expr1, %expr2, %expr3, %expr4 defined before this point
-%tmp = alloc() : vector<3x4x5xf32>
-%view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>>
-for %i = 0 to 3 {
- for %k = 0 to 5 {
- %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
- store %tmp[%i, 0, %k] : vector<3x4x5xf32>
-}}
-%c0 = constant 0 : index
-%tmpvec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
-%vec = broadcast %tmpvec, 1 : vector<3x4x5xf32>
-```
-
-where `broadcast` broadcasts from element 0 to all others along the specified
-dimension. This time, the temporary storage footprint is `3 * 5` values which
-is the same amount of data as the `3 * 5` values transferred. An additional
-`1` broadcast is required. On a GPU this broadcast could be implemented
-using a warp-shuffle if loop `j` were mapped to `threadIdx.x`.
-
-#### 'vector_transfer_write' operation {#'vector_transfer_write'-operation}
-
-Syntax:
-
-``` {.ebnf}
-operation ::= `vector_transfer_write` ssa-use-list `{` attribute-entry `} :` vector-type ', ' memref-type ', ' index-type-list
-```
-
-Examples:
-
-```mlir {.mlir}
-// write vector<16x32x64xf32> into the slice `%A[%i0, %i1:%i1+32, %i2:%i2+64, %i3:%i3+16]`:
-for %i0 = 0 to %0 {
- for %i1 = 0 to %1 step 32 {
- for %i2 = 0 to %2 step 64 {
- for %i3 = 0 to %3 step 16 {
- %val = `ssa-value` : vector<16x32x64xf32>
- vector_transfer_write %val, %A, %i0, %i1, %i2, %i3
- {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
- vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index
-}}}}
-```
-
-The `vector_transfer_write` performs a blocking write from a
-[vector](#vector-type), supplied as its first operand, into a slice within a
-scalar [MemRef](#memref-type) of the same elemental type, supplied as its second
-operand. The slice is further defined by a full-rank index within the MemRef,
-supplied as the operands `3 .. 2 + rank(memref)`. The permutation_map
-[attribute](#attributes) is an [affine-map](#affine-maps) which specifies the
-transposition on the slice to match the vector shape. The size of the slice is
-specified by the size of the vector. This operation is called 'write' by
-opposition to 'store' because the super-vector granularity is generally not
-representable with a single hardware register. A `vector_transfer_write` is thus
-a mid-level abstraction that supports super-vectorization with non-effecting
-padding for full-tile-only code. It is the responsibility of
-`vector_transfer_write`'s implementation to ensure the memory writes are valid.
-Different lowerings may be pertinent depending on the hardware support.
-
-## TensorFlow operations {#tensorflow-operations}
+### TensorFlow operations {#tensorflow-operations}
MLIR operations can represent arbitrary TensorFlow operations with a reversible
mapping. Switch and merge nodes are represented with the MLIR control flow
@@ -2143,7 +2007,7 @@
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
```
-## Target specific operations {#target-specific-operations}
+### Target specific operations {#target-specific-operations}
We expect to expose many target-specific (such as TPU-specific) operations
directly through to MLIR.
diff --git a/include/mlir/StandardOps/StandardOps.h b/include/mlir/StandardOps/StandardOps.h
index 21fa01c..6073eba 100644
--- a/include/mlir/StandardOps/StandardOps.h
+++ b/include/mlir/StandardOps/StandardOps.h
@@ -846,170 +846,6 @@
explicit TensorCastOp(const Operation *state) : CastOp(state) {}
};
-/// VectorTransferReadOp performs a blocking read from a scalar memref
-/// location into a super-vector of the same elemental type. This operation is
-/// called 'read' by opposition to 'load' because the super-vector granularity
-/// is generally not representable with a single hardware register. As a
-/// consequence, memory transfers will generally be required when lowering
-/// VectorTransferReadOp. A VectorTransferReadOp is thus a mid-level abstraction
-/// that supports super-vectorization with non-effecting padding for full-tile
-/// only code.
-//
-/// A vector transfer read has semantics similar to a vector load, with
-/// additional support for:
-/// 1. an optional value of the elemental type of the MemRef. This value
-/// supports non-effecting padding and is inserted in places where the
-/// vector read exceeds the MemRef bounds. If the value is not specified,
-/// the access is statically guaranteed to be within bounds;
-/// 2. an attribute of type AffineMap to specify a slice of the original
-/// MemRef access and its transposition into the super-vector shape.
-/// The permutation_map is an unbounded AffineMap that must
-/// represent a permutation from the MemRef dim space projected onto the
-/// vector dim space.
-/// This permutation_map has as many output dimensions as the vector rank.
-/// However, it is not necessarily full rank on the target space to signify
-/// that broadcast operations will be needed along certain vector
-/// dimensions.
-/// In the limit, one may load a 0-D slice of a memref (i.e. a single
-/// value) into a vector, which corresponds to broadcasting that value in
-/// the whole vector (i.e. a non-constant splat).
-///
-/// Example with full rank permutation_map:
-/// ```mlir
-/// %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
-/// ...
-/// %val = `ssa-value` : f32
-/// // let %i, %j, %k, %l be ssa-values of type index
-/// %v0 = vector_transfer_read %src, %i, %j, %k, %l
-/// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
-/// (memref<?x?x?x?xf32>, index, index, index, index) ->
-/// vector<16x32x64xf32>
-/// %v1 = vector_transfer_read %src, %i, %j, %k, %l, %val
-/// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
-/// (memref<?x?x?x?xf32>, index, index, index, index, f32) ->
-/// vector<16x32x64xf32>
-/// ```
-///
-/// Example with partial rank permutation_map:
-/// ```mlir
-/// %c0 = constant 0 : index
-/// %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
-/// ...
-/// // let %i, %j be ssa-values of type index
-/// %v0 = vector_transfer_read %src, %i, %c0, %c0, %c0
-/// {permutation_map: (d0, d1, d2, d3) -> (0, d1, 0)} :
-/// (memref<?x?x?x?xf32>, index, index, index, index) ->
-/// vector<16x32x64xf32>
-class VectorTransferReadOp
- : public Op<VectorTransferReadOp, OpTrait::VariadicOperands,
- OpTrait::OneResult> {
- enum Offsets : unsigned { MemRefOffset = 0, FirstIndexOffset = 1 };
-
-public:
- static StringRef getOperationName() { return "vector_transfer_read"; }
- static StringRef getPermutationMapAttrName() { return "permutation_map"; }
- static void build(Builder *builder, OperationState *result,
- VectorType vectorType, SSAValue *srcMemRef,
- ArrayRef<SSAValue *> srcIndices, AffineMap permutationMap,
- Optional<SSAValue *> paddingValue = None);
- VectorType getResultType() const {
- return getResult()->getType().cast<VectorType>();
- }
- SSAValue *getVector() { return getResult(); }
- const SSAValue *getVector() const { return getResult(); }
- SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); }
- const SSAValue *getMemRef() const {
- return getOperand(Offsets::MemRefOffset);
- }
- VectorType getVectorType() const { return getResultType(); }
- MemRefType getMemRefType() const {
- return getMemRef()->getType().cast<MemRefType>();
- }
- llvm::iterator_range<Operation::operand_iterator> getIndices();
- llvm::iterator_range<Operation::const_operand_iterator> getIndices() const;
- Optional<SSAValue *> getPaddingValue();
- Optional<const SSAValue *> getPaddingValue() const;
- AffineMap getPermutationMap() const;
-
- static bool parse(OpAsmParser *parser, OperationState *result);
- void print(OpAsmPrinter *p) const;
- bool verify() const;
-
-private:
- friend class Operation;
- explicit VectorTransferReadOp(const Operation *state) : Op(state) {}
-};
-
-/// VectorTransferWriteOp performs a blocking write from a super-vector to
-/// a scalar memref of the same elemental type. This operation is
-/// called 'write' by opposition to 'store' because the super-vector granularity
-/// is generally not representable with a single hardware register. As a
-/// consequence, memory transfers will generally be required when lowering
-/// VectorTransferWriteOp. A VectorTransferWriteOp is thus a mid-level
-/// abstraction that supports super-vectorization with non-effecting padding for
-/// full-tile only code.
-///
-/// A vector transfer write has semantics similar to a vector store, with
-/// additional support for handling out-of-bounds situations. It is the
-/// responsibility of vector_transfer_write's implementation to ensure the
-/// memory writes are valid. Different implementations may be pertinent
-/// depending on the hardware support including:
-/// 1. predication;
-/// 2. explicit control-flow;
-/// 3. Read-Modify-Write;
-/// 4. writing out of bounds of the memref when the allocation allows it.
-///
-/// Example:
-/// ```mlir
-/// %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>.
-/// %val = `ssa-value` : vector<16x32x64xf32>
-/// // let %i, %j, %k, %l be ssa-values of type index
-/// vector_transfer_write %val, %src, %i, %j, %k, %l
-/// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
-/// vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index
-/// ```
-class VectorTransferWriteOp
- : public Op<VectorTransferWriteOp, OpTrait::VariadicOperands,
- OpTrait::ZeroResult> {
- enum Offsets : unsigned {
- VectorOffset = 0,
- MemRefOffset = 1,
- FirstIndexOffset = 2
- };
-
-public:
- static StringRef getOperationName() { return "vector_transfer_write"; }
- static StringRef getPermutationMapAttrName() { return "permutation_map"; }
- static void build(Builder *builder, OperationState *result,
- SSAValue *srcVector, SSAValue *dstMemRef,
- ArrayRef<SSAValue *> dstIndices, AffineMap permutationMap);
- SSAValue *getVector() { return getOperand(Offsets::VectorOffset); }
- const SSAValue *getVector() const {
- return getOperand(Offsets::VectorOffset);
- }
- VectorType getVectorType() const {
- return getVector()->getType().cast<VectorType>();
- }
- SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); }
- const SSAValue *getMemRef() const {
- return getOperand(Offsets::MemRefOffset);
- }
- MemRefType getMemRefType() const {
- return getMemRef()->getType().cast<MemRefType>();
- }
- llvm::iterator_range<Operation::operand_iterator> getIndices();
- llvm::iterator_range<Operation::const_operand_iterator> getIndices() const;
- AffineMap getPermutationMap() const;
-
- static bool parse(OpAsmParser *parser, OperationState *result);
- void print(OpAsmPrinter *p) const;
- bool verify() const;
-
-private:
- friend class Operation;
- explicit VectorTransferWriteOp(const Operation *state) : Op(state) {}
-};
-
} // end namespace mlir
#endif
diff --git a/include/mlir/SuperVectorOps/SuperVectorOps.h b/include/mlir/SuperVectorOps/SuperVectorOps.h
new file mode 100644
index 0000000..5cd0a1c
--- /dev/null
+++ b/include/mlir/SuperVectorOps/SuperVectorOps.h
@@ -0,0 +1,204 @@
+//===- SuperVectorOps.h - MLIR Super Vectorizer Operations ------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines convenience types for working with super-vectorization
+// operations, in particular super-vector loads and stores.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INCLUDE_MLIR_SUPERVECTOROPS_SUPERVECTOROPS_H
+#define MLIR_INCLUDE_MLIR_SUPERVECTOROPS_SUPERVECTOROPS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+
+/// Dialect for super-vectorization Ops.
+class SuperVectorOpsDialect : public Dialect {
+public:
+ SuperVectorOpsDialect(MLIRContext *context);
+};
+
+/// VectorTransferReadOp performs a blocking read from a scalar memref
+/// location into a super-vector of the same elemental type. This operation is
+/// called 'read' by opposition to 'load' because the super-vector granularity
+/// is generally not representable with a single hardware register. As a
+/// consequence, memory transfers will generally be required when lowering
+/// VectorTransferReadOp. A VectorTransferReadOp is thus a mid-level abstraction
+/// that supports super-vectorization with non-effecting padding for full-tile
+/// only code.
+//
+/// A vector transfer read has semantics similar to a vector load, with
+/// additional support for:
+/// 1. an optional value of the elemental type of the MemRef. This value
+/// supports non-effecting padding and is inserted in places where the
+/// vector read exceeds the MemRef bounds. If the value is not specified,
+/// the access is statically guaranteed to be within bounds;
+/// 2. an attribute of type AffineMap to specify a slice of the original
+/// MemRef access and its transposition into the super-vector shape.
+/// The permutation_map is an unbounded AffineMap that must
+/// represent a permutation from the MemRef dim space projected onto the
+/// vector dim space.
+/// This permutation_map has as many output dimensions as the vector rank.
+/// However, it is not necessarily full rank on the target space to signify
+/// that broadcast operations will be needed along certain vector
+/// dimensions.
+/// In the limit, one may load a 0-D slice of a memref (i.e. a single
+/// value) into a vector, which corresponds to broadcasting that value in
+/// the whole vector (i.e. a non-constant splat).
+///
+/// Example with full rank permutation_map:
+/// ```mlir
+/// %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
+/// ...
+/// %val = `ssa-value` : f32
+/// // let %i, %j, %k, %l be ssa-values of type index
+/// %v0 = vector_transfer_read %src, %i, %j, %k, %l
+/// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
+/// (memref<?x?x?x?xf32>, index, index, index, index) ->
+/// vector<16x32x64xf32>
+/// %v1 = vector_transfer_read %src, %i, %j, %k, %l, %val
+/// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
+/// (memref<?x?x?x?xf32>, index, index, index, index, f32) ->
+/// vector<16x32x64xf32>
+/// ```
+///
+/// Example with partial rank permutation_map:
+/// ```mlir
+/// %c0 = constant 0 : index
+/// %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
+/// ...
+/// // let %i, %j be ssa-values of type index
+/// %v0 = vector_transfer_read %src, %i, %c0, %c0, %c0
+/// {permutation_map: (d0, d1, d2, d3) -> (0, d1, 0)} :
+/// (memref<?x?x?x?xf32>, index, index, index, index) ->
+/// vector<16x32x64xf32>
+class VectorTransferReadOp
+ : public Op<VectorTransferReadOp, OpTrait::VariadicOperands,
+ OpTrait::OneResult> {
+ enum Offsets : unsigned { MemRefOffset = 0, FirstIndexOffset = 1 };
+
+public:
+ static StringRef getOperationName() { return "vector_transfer_read"; }
+ static StringRef getPermutationMapAttrName() { return "permutation_map"; }
+ static void build(Builder *builder, OperationState *result,
+ VectorType vectorType, SSAValue *srcMemRef,
+ ArrayRef<SSAValue *> srcIndices, AffineMap permutationMap,
+ Optional<SSAValue *> paddingValue = None);
+ VectorType getResultType() const {
+ return getResult()->getType().cast<VectorType>();
+ }
+ SSAValue *getVector() { return getResult(); }
+ const SSAValue *getVector() const { return getResult(); }
+ SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); }
+ const SSAValue *getMemRef() const {
+ return getOperand(Offsets::MemRefOffset);
+ }
+ VectorType getVectorType() const { return getResultType(); }
+ MemRefType getMemRefType() const {
+ return getMemRef()->getType().cast<MemRefType>();
+ }
+ llvm::iterator_range<Operation::operand_iterator> getIndices();
+ llvm::iterator_range<Operation::const_operand_iterator> getIndices() const;
+ Optional<SSAValue *> getPaddingValue();
+ Optional<const SSAValue *> getPaddingValue() const;
+ AffineMap getPermutationMap() const;
+
+ static bool parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p) const;
+ bool verify() const;
+
+private:
+ friend class Operation;
+ explicit VectorTransferReadOp(const Operation *state) : Op(state) {}
+};
+
+/// VectorTransferWriteOp performs a blocking write from a super-vector to
+/// a scalar memref of the same elemental type. This operation is
+/// called 'write' by opposition to 'store' because the super-vector granularity
+/// is generally not representable with a single hardware register. As a
+/// consequence, memory transfers will generally be required when lowering
+/// VectorTransferWriteOp. A VectorTransferWriteOp is thus a mid-level
+/// abstraction that supports super-vectorization with non-effecting padding for
+/// full-tile only code.
+///
+/// A vector transfer write has semantics similar to a vector store, with
+/// additional support for handling out-of-bounds situations. It is the
+/// responsibility of vector_transfer_write's implementation to ensure the
+/// memory writes are valid. Different implementations may be pertinent
+/// depending on the hardware support including:
+/// 1. predication;
+/// 2. explicit control-flow;
+/// 3. Read-Modify-Write;
+/// 4. writing out of bounds of the memref when the allocation allows it.
+///
+/// Example:
+/// ```mlir
+/// %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>.
+/// %val = `ssa-value` : vector<16x32x64xf32>
+/// // let %i, %j, %k, %l be ssa-values of type index
+/// vector_transfer_write %val, %src, %i, %j, %k, %l
+/// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
+/// vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index
+/// ```
+class VectorTransferWriteOp
+ : public Op<VectorTransferWriteOp, OpTrait::VariadicOperands,
+ OpTrait::ZeroResult> {
+ enum Offsets : unsigned {
+ VectorOffset = 0,
+ MemRefOffset = 1,
+ FirstIndexOffset = 2
+ };
+
+public:
+ static StringRef getOperationName() { return "vector_transfer_write"; }
+ static StringRef getPermutationMapAttrName() { return "permutation_map"; }
+ static void build(Builder *builder, OperationState *result,
+ SSAValue *srcVector, SSAValue *dstMemRef,
+ ArrayRef<SSAValue *> dstIndices, AffineMap permutationMap);
+ SSAValue *getVector() { return getOperand(Offsets::VectorOffset); }
+ const SSAValue *getVector() const {
+ return getOperand(Offsets::VectorOffset);
+ }
+ VectorType getVectorType() const {
+ return getVector()->getType().cast<VectorType>();
+ }
+ SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); }
+ const SSAValue *getMemRef() const {
+ return getOperand(Offsets::MemRefOffset);
+ }
+ MemRefType getMemRefType() const {
+ return getMemRef()->getType().cast<MemRefType>();
+ }
+ llvm::iterator_range<Operation::operand_iterator> getIndices();
+ llvm::iterator_range<Operation::const_operand_iterator> getIndices() const;
+ AffineMap getPermutationMap() const;
+
+ static bool parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p) const;
+ bool verify() const;
+
+private:
+ friend class Operation;
+ explicit VectorTransferWriteOp(const Operation *state) : Op(state) {}
+};
+
+} // end namespace mlir
+
+#endif // MLIR_INCLUDE_MLIR_SUPERVECTOROPS_SUPERVECTOROPS_H
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 4bde64d..5e6bd7f 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -29,6 +29,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/MathExtras.h"
@@ -229,7 +230,6 @@
return memRefType.getElementType().template isa<VectorType>();
}
-// TODO(ntv): make the following into MLIR instructions, then use isa<>.
static bool isVectorTransferReadOrWrite(const Statement &stmt) {
const auto *opStmt = cast<OperationStmt>(&stmt);
return opStmt->isa<VectorTransferReadOp>() ||
diff --git a/lib/Analysis/VectorAnalysis.cpp b/lib/Analysis/VectorAnalysis.cpp
index 7bbe2d6..bfef98d 100644
--- a/lib/Analysis/VectorAnalysis.cpp
+++ b/lib/Analysis/VectorAnalysis.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/STLExtras.h"
diff --git a/lib/StandardOps/StandardOps.cpp b/lib/StandardOps/StandardOps.cpp
index 7304223..c1ad976 100644
--- a/lib/StandardOps/StandardOps.cpp
+++ b/lib/StandardOps/StandardOps.cpp
@@ -40,8 +40,7 @@
addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, CmpIOp,
DeallocOp, DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp,
LoadOp, MemRefCastOp, MulFOp, MulIOp, SelectOp, StoreOp, SubFOp,
- SubIOp, TensorCastOp, VectorTransferReadOp,
- VectorTransferWriteOp>();
+ SubIOp, TensorCastOp>();
}
//===----------------------------------------------------------------------===//
@@ -1371,416 +1370,3 @@
return false;
}
-//===----------------------------------------------------------------------===//
-// VectorTransferReadOp
-//===----------------------------------------------------------------------===//
-template <typename EmitFun>
-static bool verifyPermutationMap(AffineMap permutationMap,
- EmitFun emitOpError) {
- SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
- for (auto expr : permutationMap.getResults()) {
- auto dim = expr.dyn_cast<AffineDimExpr>();
- auto zero = expr.dyn_cast<AffineConstantExpr>();
- if (zero) {
- if (zero.getValue() != 0) {
- return emitOpError(
- "requires a projected permutation_map (at most one dim or the zero "
- "constant can appear in each result)");
- }
- continue;
- }
- if (!dim) {
- return emitOpError("requires a projected permutation_map (at most one "
- "dim or the zero constant can appear in each result)");
- }
- if (seen[dim.getPosition()]) {
- return emitOpError(
- "requires a permutation_map that is a permutation (found one dim "
- "used more than once)");
- }
- seen[dim.getPosition()] = true;
- }
- return false;
-}
-
-void VectorTransferReadOp::build(Builder *builder, OperationState *result,
- VectorType vectorType, SSAValue *srcMemRef,
- ArrayRef<SSAValue *> srcIndices,
- AffineMap permutationMap,
- Optional<SSAValue *> paddingValue) {
- result->addOperands(srcMemRef);
- result->addOperands(srcIndices);
- if (paddingValue) {
- result->addOperands({*paddingValue});
- }
- result->addAttribute(getPermutationMapAttrName(),
- builder->getAffineMapAttr(permutationMap));
- result->addTypes(vectorType);
-}
-
-llvm::iterator_range<Operation::operand_iterator>
-VectorTransferReadOp::getIndices() {
- auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
- auto end = begin + getMemRefType().getRank();
- return {begin, end};
-}
-
-llvm::iterator_range<Operation::const_operand_iterator>
-VectorTransferReadOp::getIndices() const {
- auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
- auto end = begin + getMemRefType().getRank();
- return {begin, end};
-}
-
-Optional<SSAValue *> VectorTransferReadOp::getPaddingValue() {
- auto memRefRank = getMemRefType().getRank();
- if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
- return None;
- }
- return Optional<SSAValue *>(
- getOperand(Offsets::FirstIndexOffset + memRefRank));
-}
-
-Optional<const SSAValue *> VectorTransferReadOp::getPaddingValue() const {
- auto memRefRank = getMemRefType().getRank();
- if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
- return None;
- }
- return Optional<const SSAValue *>(
- getOperand(Offsets::FirstIndexOffset + memRefRank));
-}
-
-AffineMap VectorTransferReadOp::getPermutationMap() const {
- return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
-}
-
-void VectorTransferReadOp::print(OpAsmPrinter *p) const {
- *p << getOperationName() << " ";
- p->printOperand(getMemRef());
- *p << ", ";
- p->printOperands(getIndices());
- auto optionalPaddingValue = getPaddingValue();
- if (optionalPaddingValue) {
- *p << ", ";
- p->printOperand(*optionalPaddingValue);
- }
- p->printOptionalAttrDict(getAttrs());
- // Construct the FunctionType and print it.
- llvm::SmallVector<Type, 8> inputs{getMemRefType()};
- // Must have at least one actual index, see verify.
- const SSAValue *firstIndex = *(getIndices().begin());
- Type indexType = firstIndex->getType();
- inputs.append(getMemRefType().getRank(), indexType);
- if (optionalPaddingValue) {
- inputs.push_back((*optionalPaddingValue)->getType());
- }
- *p << " : "
- << FunctionType::get(inputs, {getResultType()}, indexType.getContext());
-}
-
-bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) {
- SmallVector<OpAsmParser::OperandType, 8> parsedOperands;
- Type type;
-
- // Parsing with support for optional paddingValue.
- auto fail = parser->parseOperandList(parsedOperands) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonType(type);
- if (fail) {
- return true;
- }
-
- // Resolution.
- auto funType = type.dyn_cast<FunctionType>();
- if (!funType)
- return parser->emitError(parser->getNameLoc(), "Function type expected");
- if (funType.getNumInputs() < 1)
- return parser->emitError(parser->getNameLoc(),
- "Function type expects at least one input");
- MemRefType memrefType =
- funType.getInput(Offsets::MemRefOffset).dyn_cast<MemRefType>();
- if (!memrefType)
- return parser->emitError(parser->getNameLoc(),
- "MemRef type expected for first input");
- if (funType.getNumResults() < 1)
- return parser->emitError(parser->getNameLoc(),
- "Function type expects exactly one vector result");
- VectorType vectorType = funType.getResult(0).dyn_cast<VectorType>();
- if (!vectorType)
- return parser->emitError(parser->getNameLoc(),
- "Vector type expected for first result");
- if (parsedOperands.size() != funType.getNumInputs())
- return parser->emitError(parser->getNameLoc(),
- "requires " + Twine(funType.getNumInputs()) +
- " operands");
-
- // Extract optional paddingValue.
- OpAsmParser::OperandType memrefInfo = parsedOperands[0];
- // At this point, indexInfo may contain the optional paddingValue, pop it out.
- SmallVector<OpAsmParser::OperandType, 8> indexInfo{
- parsedOperands.begin() + Offsets::FirstIndexOffset, parsedOperands.end()};
- Type paddingType;
- OpAsmParser::OperandType paddingValue;
- bool hasPaddingValue = indexInfo.size() > memrefType.getRank();
- unsigned expectedNumOperands = Offsets::FirstIndexOffset +
- memrefType.getRank() +
- (hasPaddingValue ? 1 : 0);
- if (hasPaddingValue) {
- paddingType = funType.getInputs().back();
- paddingValue = indexInfo.pop_back_val();
- }
- if (funType.getNumInputs() != expectedNumOperands)
- return parser->emitError(
- parser->getNameLoc(),
- "requires actual number of operands to match function type");
-
- auto indexType = parser->getBuilder().getIndexType();
- return parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
- parser->resolveOperands(indexInfo, indexType, result->operands) ||
- (hasPaddingValue && parser->resolveOperand(paddingValue, paddingType,
- result->operands)) ||
- parser->addTypeToList(vectorType, result->types);
-}
-
-bool VectorTransferReadOp::verify() const {
- // Consistency of memref type in function type.
- if (llvm::empty(getOperands())) {
- return emitOpError(
- "requires at least a memref operand followed by 'rank' indices");
- }
- if (!getMemRef()->getType().isa<MemRefType>()) {
- return emitOpError("requires a memref as first operand");
- }
- // Consistency of vector type in function type.
- if (!getResult()->getType().isa<VectorType>()) {
- return emitOpError("should have a vector result type in function type: "
- "(memref_type [, elemental_type]) -> vector_type");
- }
- // Consistency of elemental types in memref and vector.
- MemRefType memrefType = getMemRefType();
- VectorType vectorType = getResultType();
- if (memrefType.getElementType() != vectorType.getElementType())
- return emitOpError(
- "requires memref and vector types of the same elemental type");
- // Consistency of number of input types.
- auto optionalPaddingValue = getPaddingValue();
- unsigned expectedNumOperands = Offsets::FirstIndexOffset +
- memrefType.getRank() +
- (optionalPaddingValue ? 1 : 0);
- // Checks on the actual operands and their types.
- if (getNumOperands() != expectedNumOperands) {
- return emitOpError("expects " + Twine(expectedNumOperands) +
- " operands to match the types");
- }
- // Consistency of padding value with vector type.
- if (optionalPaddingValue) {
- auto paddingValue = *optionalPaddingValue;
- auto elementalType = paddingValue->getType();
- if (!VectorType::isValidElementType(elementalType)) {
- return emitOpError("requires valid padding vector elemental type");
- }
- if (elementalType != vectorType.getElementType()) {
- return emitOpError(
- "requires formal padding and vector of the same elemental type");
- }
- }
- // Consistency of indices types.
- unsigned numIndices = 0;
- for (auto *idx : getIndices()) {
- if (!idx->getType().isIndex()) {
- return emitOpError(
- "index to vector_transfer_read must have 'index' type");
- }
- ++numIndices;
- }
- if (numIndices != memrefType.getRank()) {
- return emitOpError("requires at least a memref operand followed by " +
- Twine(memrefType.getRank()) + " indices");
- }
-
- // Consistency of AffineMap attribute.
- if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
- return emitOpError("requires an AffineMapAttr named 'permutation_map'");
- }
- auto permutationMap = getPermutationMap();
- if (!permutationMap.getRangeSizes().empty()) {
- return emitOpError("requires an unbounded permutation_map");
- }
- if (permutationMap.getNumSymbols() != 0) {
- return emitOpError("requires a permutation_map without symbols");
- }
- if (permutationMap.getNumInputs() != memrefType.getRank()) {
- return emitOpError("requires a permutation_map with input dims of the "
- "same rank as the memref type");
- }
- if (permutationMap.getNumResults() != vectorType.getRank()) {
- return emitOpError("requires a permutation_map with result dims of the "
- "same rank as the vector type (" +
- Twine(permutationMap.getNumResults()) + " vs " +
- Twine(vectorType.getRank()));
- }
- return verifyPermutationMap(permutationMap,
- [this](Twine t) { return emitOpError(t); });
-}
-
-//===----------------------------------------------------------------------===//
-// VectorTransferWriteOp
-//===----------------------------------------------------------------------===//
-void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
- SSAValue *srcVector, SSAValue *dstMemRef,
- ArrayRef<SSAValue *> dstIndices,
- AffineMap permutationMap) {
- result->addOperands({srcVector, dstMemRef});
- result->addOperands(dstIndices);
- result->addAttribute(getPermutationMapAttrName(),
- builder->getAffineMapAttr(permutationMap));
-}
-
-llvm::iterator_range<Operation::operand_iterator>
-VectorTransferWriteOp::getIndices() {
- auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
- auto end = begin + getMemRefType().getRank();
- return {begin, end};
-}
-
-llvm::iterator_range<Operation::const_operand_iterator>
-VectorTransferWriteOp::getIndices() const {
- auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
- auto end = begin + getMemRefType().getRank();
- return {begin, end};
-}
-
-AffineMap VectorTransferWriteOp::getPermutationMap() const {
- return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
-}
-
-void VectorTransferWriteOp::print(OpAsmPrinter *p) const {
- *p << getOperationName();
- *p << " " << *getVector();
- *p << ", " << *getMemRef();
- *p << ", ";
- p->printOperands(getIndices());
- p->printOptionalAttrDict(getAttrs());
- Type indexType = (*getIndices().begin())->getType();
- *p << " : ";
- p->printType(getVectorType());
- *p << ", ";
- p->printType(getMemRefType());
- for (unsigned r = 0, n = getMemRefType().getRank(); r < n; ++r) {
- *p << ", ";
- p->printType(indexType);
- }
-}
-
-bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) {
- SmallVector<OpAsmParser::OperandType, 8> parsedOperands;
- SmallVector<Type, 8> types;
-
- // Parsing with support for optional paddingValue.
- auto fail = parser->parseOperandList(parsedOperands) ||
- parser->parseOptionalAttributeDict(result->attributes) ||
- parser->parseColonTypeList(types);
- if (fail) {
- return true;
- }
-
- // Resolution.
- if (parsedOperands.size() != types.size())
- return parser->emitError(
- parser->getNameLoc(),
- "requires number of operands and input types to match");
- if (parsedOperands.size() < Offsets::FirstIndexOffset)
- return parser->emitError(parser->getNameLoc(),
- "requires at least vector and memref operands");
- VectorType vectorType = types[Offsets::VectorOffset].dyn_cast<VectorType>();
- if (!vectorType)
- return parser->emitError(parser->getNameLoc(),
- "Vector type expected for first input type");
- MemRefType memrefType = types[Offsets::MemRefOffset].dyn_cast<MemRefType>();
- if (!memrefType)
- return parser->emitError(parser->getNameLoc(),
- "MemRef type expected for second input type");
-
- unsigned expectedNumOperands =
- Offsets::FirstIndexOffset + memrefType.getRank();
- if (parsedOperands.size() != expectedNumOperands)
- return parser->emitError(parser->getNameLoc(),
- "requires " + Twine(expectedNumOperands) +
- " operands");
-
- OpAsmParser::OperandType vectorInfo = parsedOperands[Offsets::VectorOffset];
- OpAsmParser::OperandType memrefInfo = parsedOperands[Offsets::MemRefOffset];
- SmallVector<OpAsmParser::OperandType, 8> indexInfo{
- parsedOperands.begin() + Offsets::FirstIndexOffset, parsedOperands.end()};
- auto indexType = parser->getBuilder().getIndexType();
- return parser->resolveOperand(vectorInfo, vectorType, result->operands) ||
- parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
- parser->resolveOperands(indexInfo, indexType, result->operands);
-}
-
-bool VectorTransferWriteOp::verify() const {
- // Consistency of memref type in function type.
- if (llvm::empty(getOperands())) {
- return emitOpError(
- "requires at least a memref operand followed by 'rank' indices");
- }
- if (!getMemRef()->getType().isa<MemRefType>()) {
- return emitOpError("requires a memref first operand");
- }
- // Consistency of vector type in function type.
- if (!getVector()->getType().isa<VectorType>()) {
- return emitOpError("should have a vector input type in function type: "
- "(vector_type, memref_type [, elemental_type]) -> ()");
- }
- // Consistency of elemental types in memref and vector.
- MemRefType memrefType = getMemRefType();
- VectorType vectorType = getVectorType();
- if (memrefType.getElementType() != vectorType.getElementType())
- return emitOpError(
- "requires memref and vector types of the same elemental type");
- // Consistency of number of input types.
- unsigned expectedNumOperands =
- Offsets::FirstIndexOffset + memrefType.getRank();
- // Checks on the actual operands and their types.
- if (getNumOperands() != expectedNumOperands) {
- return emitOpError("expects " + Twine(expectedNumOperands) +
- " operands to match the types");
- }
- // Consistency of indices types.
- unsigned numIndices = 0;
- for (auto *idx : getIndices()) {
- if (!idx->getType().isIndex()) {
- return emitOpError(
- "index to vector_transfer_write must have 'index' type");
- }
- numIndices++;
- }
- if (numIndices != memrefType.getRank()) {
- return emitOpError("requires at least a memref operand followed by " +
- Twine(memrefType.getRank()) + " indices");
- }
-
- // Consistency of AffineMap attribute.
- if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
- return emitOpError("requires an AffineMapAttr named 'permutation_map'");
- }
- auto permutationMap = getPermutationMap();
- if (!permutationMap.getRangeSizes().empty()) {
- return emitOpError("requires an unbounded permutation_map");
- }
- if (permutationMap.getNumSymbols() != 0) {
- return emitOpError("requires a permutation_map without symbols");
- }
- if (permutationMap.getNumInputs() != memrefType.getRank()) {
- return emitOpError("requires a permutation_map with input dims of the "
- "same rank as the memref type");
- }
- if (permutationMap.getNumResults() != vectorType.getRank()) {
- return emitOpError("requires a permutation_map with result dims of the "
- "same rank as the vector type (" +
- Twine(permutationMap.getNumResults()) + " vs " +
- Twine(vectorType.getRank()));
- }
- return verifyPermutationMap(permutationMap,
- [this](Twine t) { return emitOpError(t); });
-}
diff --git a/lib/SuperVectorOps/DialectRegistration.cpp b/lib/SuperVectorOps/DialectRegistration.cpp
new file mode 100644
index 0000000..e8b8ee2
--- /dev/null
+++ b/lib/SuperVectorOps/DialectRegistration.cpp
@@ -0,0 +1,22 @@
+//===- DialectRegistration.cpp - Register super vectorization dialect -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/SuperVectorOps/SuperVectorOps.h"
+using namespace mlir;
+
+// Static initialization for SuperVectorOps dialect registration.
+static DialectRegistration<SuperVectorOpsDialect> SuperVectorOps;
diff --git a/lib/SuperVectorOps/SuperVectorOps.cpp b/lib/SuperVectorOps/SuperVectorOps.cpp
new file mode 100644
index 0000000..5e19d2b
--- /dev/null
+++ b/lib/SuperVectorOps/SuperVectorOps.cpp
@@ -0,0 +1,452 @@
+//===- SuperVectorOps.cpp - MLIR Super Vectorizer Operations---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements convenience types for working with super-vectorization
+// operations, in particular super-vector loads and stores.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/SuperVectorOps/SuperVectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Support/LLVM.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// SuperVectorOpsDialect
+//===----------------------------------------------------------------------===//
+
+SuperVectorOpsDialect::SuperVectorOpsDialect(MLIRContext *context)
+ : Dialect(/*opPrefix=*/"", context) {
+ addOperations<VectorTransferReadOp, VectorTransferWriteOp>();
+}
+
+//===----------------------------------------------------------------------===//
+// VectorTransferReadOp
+//===----------------------------------------------------------------------===//
+template <typename EmitFun>
+static bool verifyPermutationMap(AffineMap permutationMap,
+ EmitFun emitOpError) {
+ SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
+ for (auto expr : permutationMap.getResults()) {
+ auto dim = expr.dyn_cast<AffineDimExpr>();
+ auto zero = expr.dyn_cast<AffineConstantExpr>();
+ if (zero) {
+ if (zero.getValue() != 0) {
+ return emitOpError(
+ "requires a projected permutation_map (at most one dim or the zero "
+ "constant can appear in each result)");
+ }
+ continue;
+ }
+ if (!dim) {
+ return emitOpError("requires a projected permutation_map (at most one "
+ "dim or the zero constant can appear in each result)");
+ }
+ if (seen[dim.getPosition()]) {
+ return emitOpError(
+ "requires a permutation_map that is a permutation (found one dim "
+ "used more than once)");
+ }
+ seen[dim.getPosition()] = true;
+ }
+ return false;
+}
+
+void VectorTransferReadOp::build(Builder *builder, OperationState *result,
+ VectorType vectorType, SSAValue *srcMemRef,
+ ArrayRef<SSAValue *> srcIndices,
+ AffineMap permutationMap,
+ Optional<SSAValue *> paddingValue) {
+ result->addOperands(srcMemRef);
+ result->addOperands(srcIndices);
+ if (paddingValue) {
+ result->addOperands({*paddingValue});
+ }
+ result->addAttribute(getPermutationMapAttrName(),
+ builder->getAffineMapAttr(permutationMap));
+ result->addTypes(vectorType);
+}
+
+llvm::iterator_range<Operation::operand_iterator>
+VectorTransferReadOp::getIndices() {
+ auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+ auto end = begin + getMemRefType().getRank();
+ return {begin, end};
+}
+
+llvm::iterator_range<Operation::const_operand_iterator>
+VectorTransferReadOp::getIndices() const {
+ auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+ auto end = begin + getMemRefType().getRank();
+ return {begin, end};
+}
+
+Optional<SSAValue *> VectorTransferReadOp::getPaddingValue() {
+ auto memRefRank = getMemRefType().getRank();
+ if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
+ return None;
+ }
+ return Optional<SSAValue *>(
+ getOperand(Offsets::FirstIndexOffset + memRefRank));
+}
+
+Optional<const SSAValue *> VectorTransferReadOp::getPaddingValue() const {
+ auto memRefRank = getMemRefType().getRank();
+ if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
+ return None;
+ }
+ return Optional<const SSAValue *>(
+ getOperand(Offsets::FirstIndexOffset + memRefRank));
+}
+
+AffineMap VectorTransferReadOp::getPermutationMap() const {
+ return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
+}
+
+void VectorTransferReadOp::print(OpAsmPrinter *p) const {
+ *p << getOperationName() << " ";
+ p->printOperand(getMemRef());
+ *p << ", ";
+ p->printOperands(getIndices());
+ auto optionalPaddingValue = getPaddingValue();
+ if (optionalPaddingValue) {
+ *p << ", ";
+ p->printOperand(*optionalPaddingValue);
+ }
+ p->printOptionalAttrDict(getAttrs());
+ // Construct the FunctionType and print it.
+ llvm::SmallVector<Type, 8> inputs{getMemRefType()};
+ // Must have at least one actual index, see verify.
+ const SSAValue *firstIndex = *(getIndices().begin());
+ Type indexType = firstIndex->getType();
+ inputs.append(getMemRefType().getRank(), indexType);
+ if (optionalPaddingValue) {
+ inputs.push_back((*optionalPaddingValue)->getType());
+ }
+ *p << " : "
+ << FunctionType::get(inputs, {getResultType()}, indexType.getContext());
+}
+
+bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) {
+ SmallVector<OpAsmParser::OperandType, 8> parsedOperands;
+ Type type;
+
+ // Parsing with support for optional paddingValue.
+ auto fail = parser->parseOperandList(parsedOperands) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type);
+ if (fail) {
+ return true;
+ }
+
+ // Resolution.
+ auto funType = type.dyn_cast<FunctionType>();
+ if (!funType)
+ return parser->emitError(parser->getNameLoc(), "Function type expected");
+ if (funType.getNumInputs() < 1)
+ return parser->emitError(parser->getNameLoc(),
+ "Function type expects at least one input");
+ MemRefType memrefType =
+ funType.getInput(Offsets::MemRefOffset).dyn_cast<MemRefType>();
+ if (!memrefType)
+ return parser->emitError(parser->getNameLoc(),
+ "MemRef type expected for first input");
+ if (funType.getNumResults() < 1)
+ return parser->emitError(parser->getNameLoc(),
+ "Function type expects exactly one vector result");
+ VectorType vectorType = funType.getResult(0).dyn_cast<VectorType>();
+ if (!vectorType)
+ return parser->emitError(parser->getNameLoc(),
+ "Vector type expected for first result");
+ if (parsedOperands.size() != funType.getNumInputs())
+ return parser->emitError(parser->getNameLoc(),
+ "requires " + Twine(funType.getNumInputs()) +
+ " operands");
+
+ // Extract optional paddingValue.
+ OpAsmParser::OperandType memrefInfo = parsedOperands[0];
+ // At this point, indexInfo may contain the optional paddingValue, pop it out.
+ SmallVector<OpAsmParser::OperandType, 8> indexInfo{
+ parsedOperands.begin() + Offsets::FirstIndexOffset, parsedOperands.end()};
+ Type paddingType;
+ OpAsmParser::OperandType paddingValue;
+ bool hasPaddingValue = indexInfo.size() > memrefType.getRank();
+ unsigned expectedNumOperands = Offsets::FirstIndexOffset +
+ memrefType.getRank() +
+ (hasPaddingValue ? 1 : 0);
+ if (hasPaddingValue) {
+ paddingType = funType.getInputs().back();
+ paddingValue = indexInfo.pop_back_val();
+ }
+ if (funType.getNumInputs() != expectedNumOperands)
+ return parser->emitError(
+ parser->getNameLoc(),
+ "requires actual number of operands to match function type");
+
+ auto indexType = parser->getBuilder().getIndexType();
+ return parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
+ parser->resolveOperands(indexInfo, indexType, result->operands) ||
+ (hasPaddingValue && parser->resolveOperand(paddingValue, paddingType,
+ result->operands)) ||
+ parser->addTypeToList(vectorType, result->types);
+}
+
+bool VectorTransferReadOp::verify() const {
+ // Consistency of memref type in function type.
+ if (llvm::empty(getOperands())) {
+ return emitOpError(
+ "requires at least a memref operand followed by 'rank' indices");
+ }
+ if (!getMemRef()->getType().isa<MemRefType>()) {
+ return emitOpError("requires a memref as first operand");
+ }
+ // Consistency of vector type in function type.
+ if (!getResult()->getType().isa<VectorType>()) {
+ return emitOpError("should have a vector result type in function type: "
+ "(memref_type [, elemental_type]) -> vector_type");
+ }
+ // Consistency of elemental types in memref and vector.
+ MemRefType memrefType = getMemRefType();
+ VectorType vectorType = getResultType();
+ if (memrefType.getElementType() != vectorType.getElementType())
+ return emitOpError(
+ "requires memref and vector types of the same elemental type");
+ // Consistency of number of input types.
+ auto optionalPaddingValue = getPaddingValue();
+ unsigned expectedNumOperands = Offsets::FirstIndexOffset +
+ memrefType.getRank() +
+ (optionalPaddingValue ? 1 : 0);
+ // Checks on the actual operands and their types.
+ if (getNumOperands() != expectedNumOperands) {
+ return emitOpError("expects " + Twine(expectedNumOperands) +
+ " operands to match the types");
+ }
+ // Consistency of padding value with vector type.
+ if (optionalPaddingValue) {
+ auto paddingValue = *optionalPaddingValue;
+ auto elementalType = paddingValue->getType();
+ if (!VectorType::isValidElementType(elementalType)) {
+ return emitOpError("requires valid padding vector elemental type");
+ }
+ if (elementalType != vectorType.getElementType()) {
+ return emitOpError(
+ "requires formal padding and vector of the same elemental type");
+ }
+ }
+ // Consistency of indices types.
+ unsigned numIndices = 0;
+ for (auto *idx : getIndices()) {
+ if (!idx->getType().isIndex()) {
+ return emitOpError(
+ "index to vector_transfer_read must have 'index' type");
+ }
+ ++numIndices;
+ }
+ if (numIndices != memrefType.getRank()) {
+ return emitOpError("requires at least a memref operand followed by " +
+ Twine(memrefType.getRank()) + " indices");
+ }
+
+ // Consistency of AffineMap attribute.
+ if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
+ return emitOpError("requires an AffineMapAttr named 'permutation_map'");
+ }
+ auto permutationMap = getPermutationMap();
+ if (!permutationMap.getRangeSizes().empty()) {
+ return emitOpError("requires an unbounded permutation_map");
+ }
+ if (permutationMap.getNumSymbols() != 0) {
+ return emitOpError("requires a permutation_map without symbols");
+ }
+ if (permutationMap.getNumInputs() != memrefType.getRank()) {
+ return emitOpError("requires a permutation_map with input dims of the "
+ "same rank as the memref type");
+ }
+ if (permutationMap.getNumResults() != vectorType.getRank()) {
+ return emitOpError("requires a permutation_map with result dims of the "
+ "same rank as the vector type (" +
+ Twine(permutationMap.getNumResults()) + " vs " +
+ Twine(vectorType.getRank()));
+ }
+ return verifyPermutationMap(permutationMap,
+ [this](Twine t) { return emitOpError(t); });
+}
+
+//===----------------------------------------------------------------------===//
+// VectorTransferWriteOp
+//===----------------------------------------------------------------------===//
+void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
+ SSAValue *srcVector, SSAValue *dstMemRef,
+ ArrayRef<SSAValue *> dstIndices,
+ AffineMap permutationMap) {
+ result->addOperands({srcVector, dstMemRef});
+ result->addOperands(dstIndices);
+ result->addAttribute(getPermutationMapAttrName(),
+ builder->getAffineMapAttr(permutationMap));
+}
+
+llvm::iterator_range<Operation::operand_iterator>
+VectorTransferWriteOp::getIndices() {
+ auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+ auto end = begin + getMemRefType().getRank();
+ return {begin, end};
+}
+
+llvm::iterator_range<Operation::const_operand_iterator>
+VectorTransferWriteOp::getIndices() const {
+ auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+ auto end = begin + getMemRefType().getRank();
+ return {begin, end};
+}
+
+AffineMap VectorTransferWriteOp::getPermutationMap() const {
+ return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
+}
+
+void VectorTransferWriteOp::print(OpAsmPrinter *p) const {
+ *p << getOperationName();
+ *p << " " << *getVector();
+ *p << ", " << *getMemRef();
+ *p << ", ";
+ p->printOperands(getIndices());
+ p->printOptionalAttrDict(getAttrs());
+ Type indexType = (*getIndices().begin())->getType();
+ *p << " : ";
+ p->printType(getVectorType());
+ *p << ", ";
+ p->printType(getMemRefType());
+ for (unsigned r = 0, n = getMemRefType().getRank(); r < n; ++r) {
+ *p << ", ";
+ p->printType(indexType);
+ }
+}
+
+bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) {
+ SmallVector<OpAsmParser::OperandType, 8> parsedOperands;
+ SmallVector<Type, 8> types;
+
+ // Parsing with support for optional paddingValue.
+ auto fail = parser->parseOperandList(parsedOperands) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonTypeList(types);
+ if (fail) {
+ return true;
+ }
+
+ // Resolution.
+ if (parsedOperands.size() != types.size())
+ return parser->emitError(
+ parser->getNameLoc(),
+ "requires number of operands and input types to match");
+ if (parsedOperands.size() < Offsets::FirstIndexOffset)
+ return parser->emitError(parser->getNameLoc(),
+ "requires at least vector and memref operands");
+ VectorType vectorType = types[Offsets::VectorOffset].dyn_cast<VectorType>();
+ if (!vectorType)
+ return parser->emitError(parser->getNameLoc(),
+ "Vector type expected for first input type");
+ MemRefType memrefType = types[Offsets::MemRefOffset].dyn_cast<MemRefType>();
+ if (!memrefType)
+ return parser->emitError(parser->getNameLoc(),
+ "MemRef type expected for second input type");
+
+ unsigned expectedNumOperands =
+ Offsets::FirstIndexOffset + memrefType.getRank();
+ if (parsedOperands.size() != expectedNumOperands)
+ return parser->emitError(parser->getNameLoc(),
+ "requires " + Twine(expectedNumOperands) +
+ " operands");
+
+ OpAsmParser::OperandType vectorInfo = parsedOperands[Offsets::VectorOffset];
+ OpAsmParser::OperandType memrefInfo = parsedOperands[Offsets::MemRefOffset];
+ SmallVector<OpAsmParser::OperandType, 8> indexInfo{
+ parsedOperands.begin() + Offsets::FirstIndexOffset, parsedOperands.end()};
+ auto indexType = parser->getBuilder().getIndexType();
+ return parser->resolveOperand(vectorInfo, vectorType, result->operands) ||
+ parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
+ parser->resolveOperands(indexInfo, indexType, result->operands);
+}
+
+bool VectorTransferWriteOp::verify() const {
+ // Consistency of memref type in function type.
+ if (llvm::empty(getOperands())) {
+ return emitOpError(
+ "requires at least a memref operand followed by 'rank' indices");
+ }
+ if (!getMemRef()->getType().isa<MemRefType>()) {
+ return emitOpError("requires a memref first operand");
+ }
+ // Consistency of vector type in function type.
+ if (!getVector()->getType().isa<VectorType>()) {
+ return emitOpError("should have a vector input type in function type: "
+ "(vector_type, memref_type [, elemental_type]) -> ()");
+ }
+ // Consistency of elemental types in memref and vector.
+ MemRefType memrefType = getMemRefType();
+ VectorType vectorType = getVectorType();
+ if (memrefType.getElementType() != vectorType.getElementType())
+ return emitOpError(
+ "requires memref and vector types of the same elemental type");
+ // Consistency of number of input types.
+ unsigned expectedNumOperands =
+ Offsets::FirstIndexOffset + memrefType.getRank();
+ // Checks on the actual operands and their types.
+ if (getNumOperands() != expectedNumOperands) {
+ return emitOpError("expects " + Twine(expectedNumOperands) +
+ " operands to match the types");
+ }
+ // Consistency of indices types.
+ unsigned numIndices = 0;
+ for (auto *idx : getIndices()) {
+ if (!idx->getType().isIndex()) {
+ return emitOpError(
+ "index to vector_transfer_write must have 'index' type");
+ }
+ numIndices++;
+ }
+ if (numIndices != memrefType.getRank()) {
+ return emitOpError("requires at least a memref operand followed by " +
+ Twine(memrefType.getRank()) + " indices");
+ }
+
+ // Consistency of AffineMap attribute.
+ if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
+ return emitOpError("requires an AffineMapAttr named 'permutation_map'");
+ }
+ auto permutationMap = getPermutationMap();
+ if (!permutationMap.getRangeSizes().empty()) {
+ return emitOpError("requires an unbounded permutation_map");
+ }
+ if (permutationMap.getNumSymbols() != 0) {
+ return emitOpError("requires a permutation_map without symbols");
+ }
+ if (permutationMap.getNumInputs() != memrefType.getRank()) {
+ return emitOpError("requires a permutation_map with input dims of the "
+ "same rank as the memref type");
+ }
+ if (permutationMap.getNumResults() != vectorType.getRank()) {
+ return emitOpError("requires a permutation_map with result dims of the "
+ "same rank as the vector type (" +
+ Twine(permutationMap.getNumResults()) + " vs " +
+ Twine(vectorType.getRank()));
+ }
+ return verifyPermutationMap(permutationMap,
+ [this](Twine t) { return emitOpError(t); });
+}
diff --git a/lib/Transforms/LowerVectorTransfers.cpp b/lib/Transforms/LowerVectorTransfers.cpp
index f2f6716..a49a5c5 100644
--- a/lib/Transforms/LowerVectorTransfers.cpp
+++ b/lib/Transforms/LowerVectorTransfers.cpp
@@ -35,6 +35,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/Passes.h"
diff --git a/lib/Transforms/MaterializeVectors.cpp b/lib/Transforms/MaterializeVectors.cpp
index 4c81eb9..1199e97 100644
--- a/lib/Transforms/MaterializeVectors.cpp
+++ b/lib/Transforms/MaterializeVectors.cpp
@@ -38,6 +38,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/Passes.h"
diff --git a/lib/Transforms/Vectorize.cpp b/lib/Transforms/Vectorize.cpp
index 053fd3c..2f74e56 100644
--- a/lib/Transforms/Vectorize.cpp
+++ b/lib/Transforms/Vectorize.cpp
@@ -32,6 +32,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
+#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/Passes.h"