Adds MemRef type, with parsing and printing support for memref affine map compositions and memory spaces.
PiperOrigin-RevId: 204756982
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 2dd2a3b..7347a81 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -303,7 +303,10 @@
llvm::errs() << "\n";
}
-void AffineMap::dump() const { print(llvm::errs()); }
+void AffineMap::dump() const {
+  print(llvm::errs());
+  llvm::errs() << "\n"; // print() does not end the line; dump() does.
+}
void AffineExpr::dump() const {
print(llvm::errs());
@@ -393,7 +396,6 @@
os << ")";
if (!isBounded()) {
- os << "\n";
return;
}
@@ -401,7 +403,7 @@
os << " size (";
interleave(getRangeSizes(), [&](AffineExpr *expr) { os << *expr; },
[&]() { os << ", "; });
- os << ")\n";
+ os << ")";
}
void BasicBlock::print(raw_ostream &os) const {
@@ -449,6 +451,7 @@
for (auto *map : affineMapList) {
os << "#" << id++ << " = ";
map->print(os);
+ os << '\n';
}
for (auto *fn : functionList)
fn->print(os);
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 4448b16..eb06e94 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -113,6 +113,30 @@
}
};
+struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType*> {
+  // Uniquing key: (element type, shape, affine map composition, memory space).
+  using KeyTy = std::tuple<Type*, ArrayRef<int>, ArrayRef<AffineMap*>,
+                           unsigned>;
+  using DenseMapInfo<MemRefType*>::getHashValue;
+  using DenseMapInfo<MemRefType*>::isEqual;
+
+  static unsigned getHashValue(KeyTy key) {
+    ArrayRef<int> dims = std::get<1>(key);
+    ArrayRef<AffineMap*> maps = std::get<2>(key);
+    return hash_combine(DenseMapInfo<Type*>::getHashValue(std::get<0>(key)),
+                        hash_combine_range(dims.begin(), dims.end()),
+                        hash_combine_range(maps.begin(), maps.end()),
+                        std::get<3>(key));
+  }
+
+  static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) {
+    bool isSentinel = rhs == getEmptyKey() || rhs == getTombstoneKey();
+    return !isSentinel &&
+           lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(),
+                                  rhs->getAffineMaps(), rhs->getMemorySpace());
+  }
+};
+
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr*> {
// Array attributes are uniqued based on their elements.
using KeyTy = ArrayRef<Attribute*>;
@@ -195,6 +219,10 @@
/// Unranked tensor type uniquing.
DenseMap<Type*, UnrankedTensorType*> unrankedTensors;
+  /// MemRef type uniquing; entries are keyed by MemRefTypeKeyInfo::KeyTy.
+  using MemRefTypeSet = DenseSet<MemRefType*, MemRefTypeKeyInfo>;
+  MemRefTypeSet memrefs;
+
// Attribute uniquing.
BoolAttr *boolAttrs[2] = { nullptr };
DenseMap<int64_t, IntegerAttr*> integerAttrs;
@@ -403,6 +431,39 @@
return result;
}
+/// Returns the uniqued memref type for the given shape, element type, affine
+/// map composition and memory space, creating it on first use.
+MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
+                            ArrayRef<AffineMap*> affineMapComposition,
+                            unsigned memorySpace) {
+  auto *context = elementType->getContext();
+  auto &impl = context->getImpl();
+
+  // Look to see if we already have this memref type.
+  auto key = std::make_tuple(elementType, shape, affineMapComposition,
+                             memorySpace);
+  auto existing = impl.memrefs.insert_as(nullptr, key);
+
+  // If we already have it, return that value.
+  if (!existing.second)
+    return *existing.first;
+
+  // On the first use, we allocate them into the bump pointer.
+  auto *result = impl.allocator.Allocate<MemRefType>();
+
+  // Copy the shape and the affine map composition into the bump pointer; the
+  // type only stores pointers to this storage.
+  // TODO(andydavis) Assert that the structure of the composition is valid.
+  shape = impl.copyInto(shape);
+  affineMapComposition = impl.copyInto(affineMapComposition);
+
+  // Initialize the memory using placement new.
+  new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace,
+                          context);
+  // Cache and return it.
+  return *existing.first = result;
+}
+
//===----------------------------------------------------------------------===//
// Attribute uniquing
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index 0860772..1b7d1a6 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -16,6 +16,7 @@
// =============================================================================
#include "mlir/IR/Types.h"
+#include "mlir/IR/AffineMap.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Support/STLExtras.h"
using namespace mlir;
@@ -51,6 +52,19 @@
: TensorType(Kind::UnrankedTensor, elementType, context) {
}
+// Keeps the `shape`/`affineMapList` array pointers; contents are not copied.
+MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
+                       ArrayRef<AffineMap*> affineMapList,
+                       unsigned memorySpace, MLIRContext *context)
+    : Type(Kind::MemRef, context, shape.size()),
+      elementType(elementType), shapeElements(shape.data()),
+      numAffineMaps(affineMapList.size()), affineMapList(affineMapList.data()),
+      memorySpace(memorySpace) {}
+
+ArrayRef<AffineMap*> MemRefType::getAffineMaps() const {
+  return ArrayRef<AffineMap*>(affineMapList, numAffineMaps);
+}
+
void Type::print(raw_ostream &os) const {
switch (getKind()) {
case Kind::AffineInt: os << "affineint"; return;
@@ -109,6 +123,25 @@
os << "tensor<??" << *v->getElementType() << '>';
return;
}
+  case Kind::MemRef: {
+    // Prints memref<NxMx...xelt, map, ..., memspace>; negative extents as '?'.
+    auto *memref = cast<MemRefType>(this);
+    os << "memref<";
+    for (auto extent : memref->getShape()) {
+      if (extent >= 0)
+        os << extent;
+      else
+        os << '?';
+      os << 'x';
+    }
+    os << *memref->getElementType();
+    for (auto *map : memref->getAffineMaps()) {
+      os << ", ";
+      map->print(os);
+    }
+    os << ", " << memref->getMemorySpace() << '>';
+    return;
+  }
}
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index e1d184d..626b92f 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -132,9 +132,10 @@
return true;
}
- ParseResult parseCommaSeparatedList(Token::Kind rightToken,
- const std::function<ParseResult()> &parseElement,
- bool allowEmptyList = true);
+ ParseResult parseCommaSeparatedList(
+ Token::Kind rightToken,
+ const std::function<ParseResult()> &parseElement,
+ bool allowEmptyList = true);
// We have two forms of parsing methods - those that return a non-null
// pointer on success, and those that return a ParseResult to indicate whether
@@ -158,6 +159,7 @@
// Polyhedral structures.
AffineMap *parseAffineMapInline();
+ AffineMap *parseAffineMapReference();
// SSA
ParseResult parseSSAUse();
@@ -414,14 +416,50 @@
if (!elementType)
return nullptr;
- // TODO: Parse semi-affine-map-composition.
- // TODO: Parse memory-space.
+  if (!consumeIf(Token::comma))
+    return (emitError("expected ',' in memref type"), nullptr);
- if (!consumeIf(Token::greater))
- return (emitError("expected '>' in memref type"), nullptr);
+  // Parse semi-affine-map-composition.
+  SmallVector<AffineMap*, 2> affineMapComposition;
+  unsigned memorySpace = 0; // Meaningful only once parsedMemorySpace is set.
+  bool parsedMemorySpace = false;
- // FIXME: Add an IR representation for memref types.
- return builder.getIntegerType(1);
+  auto parseElt = [&]() -> ParseResult {
+    if (getToken().is(Token::integer)) {
+      // Parse memory space.
+      if (parsedMemorySpace)
+        return emitError("multiple memory spaces specified in memref type");
+      auto v = getToken().getUnsignedIntegerValue();
+      if (!v.hasValue())
+        return emitError("invalid memory space in memref type");
+      memorySpace = v.getValue();
+      consumeToken(Token::integer);
+      parsedMemorySpace = true;
+    } else {
+      // Parse affine map.
+      if (parsedMemorySpace)
+        return emitError("affine map after memory space in memref type");
+      auto *affineMap = parseAffineMapReference();
+      if (!affineMap)
+        return ParseFailure;
+      affineMapComposition.push_back(affineMap);
+    }
+    return ParseSuccess;
+  };
+
+  // Parse comma separated list of affine maps, followed by memory space.
+  if (parseCommaSeparatedList(Token::greater, parseElt,
+                              /*allowEmptyList=*/false)) {
+    return nullptr;
+  }
+  // Check that MemRef type specifies at least one affine map in composition.
+  if (affineMapComposition.empty())
+    return (emitError("expected semi-affine-map in memref type"), nullptr);
+  if (!parsedMemorySpace)
+    return (emitError("expected memory space in memref type"), nullptr);
+
+  return MemRefType::get(dimensions, elementType, affineMapComposition,
+                         memorySpace);
}
/// Parse a function type.
@@ -1106,6 +1144,23 @@
return AffineMapParser(state).parseAffineMapInline();
}
+/// Parse an affine map reference: either '#' followed by the identifier of a
+/// previously defined affine map, or an inline affine map.
+AffineMap *Parser::parseAffineMapReference() {
+  if (getToken().is(Token::hash_identifier)) {
+    // Parse affine map identifier and verify that it exists.
+    StringRef affineMapId = getTokenSpelling().drop_front();
+    auto it = getState().affineMapDefinitions.find(affineMapId);
+    if (it == getState().affineMapDefinitions.end())
+      return (emitError("undefined affine map id '" + affineMapId + "'"),
+              nullptr);
+    consumeToken(Token::hash_identifier);
+    return it->second;
+  }
+  // Try to parse inline affine map.
+  return parseAffineMapInline();
+}
+
//===----------------------------------------------------------------------===//
// SSA
//===----------------------------------------------------------------------===//