Adds ModuleState to support printing outlined AffineMaps.
PiperOrigin-RevId: 204999887
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index a12b387..ccc832a 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -39,10 +39,6 @@
// TODO(someone): This should switch to llvm::iplist<Function>.
std::vector<Function*> functionList;
- // FIXME: wrong representation and API.
- // These affine maps are immutable
- std::vector<const AffineMap *> affineMapList;
-
/// Perform (potentially expensive) checks of invariants, used to detect
/// compiler bugs. This aborts on failure.
void verify() const;
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 9a86b85..be6573f 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -44,27 +44,253 @@
}
//===----------------------------------------------------------------------===//
+// Module printing
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ModuleState {
+ public:
+ ModuleState(raw_ostream &os);
+
+ void initialize(const Module *module);
+
+ void print(const Module *module);
+ void print(const Type *type) const;
+ void print(const Function* fn);
+
+ void recordAffineMapReference(const AffineMap* affineMap) {
+ if (affineMapIds.count(affineMap) == 0) {
+ affineMapIds[affineMap] = nextAffineMapId++;
+ }
+ }
+
+ int getAffineMapId(const AffineMap* affineMap) const {
+ auto it = affineMapIds.find(affineMap);
+ if (it == affineMapIds.end()) {
+ return -1;
+ }
+ return it->second;
+ }
+
+ private:
+ // Visit functions.
+ void visitFunction(const Function *fn);
+ void visitExtFunction(const ExtFunction *fn);
+ void visitCFGFunction(const CFGFunction *fn);
+ void visitMLFunction(const MLFunction *fn);
+ void visitType(const Type *type);
+
+ raw_ostream &os;
+ DenseMap<const AffineMap*, int> affineMapIds;
+ int nextAffineMapId = 0;
+};
+} // end anonymous namespace
+
+ModuleState::ModuleState(raw_ostream &os) : os(os) {
+}
+
+// Initializes module state, populating affine map state.
+void ModuleState::initialize(const Module *module) {
+ for (auto fn : module->functionList) {
+ visitFunction(fn);
+ }
+}
+
+// TODO Support visiting other types/instructions when implemented.
+void ModuleState::visitType(const Type *type) {
+ if (type->getKind() == Type::Kind::Function) {
+ // Visit input and result types for functions.
+ auto *funcType = cast<FunctionType>(type);
+ for (auto* input : funcType->getInputs()) {
+ visitType(input);
+ }
+ for (auto* result : funcType->getResults()) {
+ visitType(result);
+ }
+ } else if (type->getKind() == Type::Kind::MemRef) {
+ // Visit affine maps in memref type.
+ auto *memref = cast<MemRefType>(type);
+ for (AffineMap* map : memref->getAffineMaps()) {
+ recordAffineMapReference(map);
+ }
+ }
+}
+
+void ModuleState::visitExtFunction(const ExtFunction *fn) {
+ visitType(fn->getType());
+}
+
+void ModuleState::visitCFGFunction(const CFGFunction *fn) {
+ visitType(fn->getType());
+ // TODO Visit function body instructions.
+}
+
+void ModuleState::visitMLFunction(const MLFunction *fn) {
+ visitType(fn->getType());
+ // TODO Visit function body statements.
+}
+
+void ModuleState::visitFunction(const Function *fn) {
+ switch (fn->getKind()) {
+ case Function::Kind::ExtFunc:
+ return visitExtFunction(cast<ExtFunction>(fn));
+ case Function::Kind::CFGFunc:
+ return visitCFGFunction(cast<CFGFunction>(fn));
+ case Function::Kind::MLFunc:
+ return visitMLFunction(cast<MLFunction>(fn));
+ }
+}
+
+static void printExtFunction(const ExtFunction* fn,
+ const ModuleState* moduleState, raw_ostream &os);
+
+
+static void printCFGFunction(const CFGFunction* fn,
+ const ModuleState* moduleState, raw_ostream &os);
+
+static void printMLFunction(const MLFunction* fn,
+ const ModuleState* moduleState, raw_ostream &os);
+
+// Prints function with initialized module state.
+void ModuleState::print(const Function* fn) {
+ switch (fn->getKind()) {
+ case Function::Kind::ExtFunc:
+ return printExtFunction(cast<ExtFunction>(fn), this, os);
+ case Function::Kind::CFGFunc:
+ return printCFGFunction(cast<CFGFunction>(fn), this, os);
+ case Function::Kind::MLFunc:
+ return printMLFunction(cast<MLFunction>(fn), this, os);
+ }
+}
+
+// Prints affine map identifier.
+static void printAffineMapId(unsigned affineMapId, raw_ostream &os) {
+ os << "#map" << affineMapId;
+}
+
+void ModuleState::print(const Module *module) {
+ for (const auto& mapAndId : affineMapIds) {
+ printAffineMapId(mapAndId.second, os);
+ os << " = ";
+ mapAndId.first->print(os);
+ os << '\n';
+ }
+ for (auto *fn : module->functionList)
+ print(fn);
+}
+
+void ModuleState::print(const Type *type) const {
+ switch (type->getKind()) {
+ case Type::Kind::AffineInt: os << "affineint"; return;
+ case Type::Kind::BF16: os << "bf16"; return;
+ case Type::Kind::F16: os << "f16"; return;
+ case Type::Kind::F32: os << "f32"; return;
+ case Type::Kind::F64: os << "f64"; return;
+
+ case Type::Kind::Integer: {
+ auto *integer = cast<IntegerType>(type);
+ os << 'i' << integer->getWidth();
+ return;
+ }
+ case Type::Kind::Function: {
+ auto *func = cast<FunctionType>(type);
+ os << '(';
+ interleave(func->getInputs(),
+ [&](Type *type) { os << *type; },
+ [&]() { os << ", "; });
+ os << ") -> ";
+ auto results = func->getResults();
+ if (results.size() == 1)
+ os << *results[0];
+ else {
+ os << '(';
+ interleave(results,
+ [&](Type *type) { os << *type; },
+ [&]() { os << ", "; });
+ os << ')';
+ }
+ return;
+ }
+ case Type::Kind::Vector: {
+ auto *v = cast<VectorType>(type);
+ os << "vector<";
+ for (auto dim : v->getShape())
+ os << dim << 'x';
+ os << *v->getElementType() << '>';
+ return;
+ }
+ case Type::Kind::RankedTensor: {
+ auto *v = cast<RankedTensorType>(type);
+ os << "tensor<";
+ for (auto dim : v->getShape()) {
+ if (dim < 0)
+ os << '?';
+ else
+ os << dim;
+ os << 'x';
+ }
+ os << *v->getElementType() << '>';
+ return;
+ }
+ case Type::Kind::UnrankedTensor: {
+ auto *v = cast<UnrankedTensorType>(type);
+ os << "tensor<??" << *v->getElementType() << '>';
+ return;
+ }
+ case Type::Kind::MemRef: {
+ auto *v = cast<MemRefType>(type);
+ os << "memref<";
+ for (auto dim : v->getShape()) {
+ if (dim < 0)
+ os << '?';
+ else
+ os << dim;
+ os << 'x';
+ }
+ os << *v->getElementType();
+ for (auto map : v->getAffineMaps()) {
+ os << ", ";
+ const int mapId = getAffineMapId(map);
+ if (mapId >= 0) {
+ // Map will be printed at top of module so print reference to its id.
+ printAffineMapId(mapId, os);
+ } else {
+ // Map not in module state so print inline.
+ map->print(os);
+ }
+ }
+ os << ", " << v->getMemorySpace();
+ os << '>';
+ return;
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
// Function printing
//===----------------------------------------------------------------------===//
-static void printFunctionSignature(const Function *fn, raw_ostream &os) {
+static void printFunctionSignature(const Function *fn,
+ const ModuleState *moduleState,
+ raw_ostream &os) {
auto type = fn->getType();
os << "@" << fn->getName() << '(';
interleave(type->getInputs(),
- [&](Type *eltType) { os << *eltType; },
+ [&](Type *eltType) { moduleState->print(eltType); },
[&]() { os << ", "; });
os << ')';
switch (type->getResults().size()) {
case 0: break;
case 1:
- os << " -> " << *type->getResults()[0];
+ os << " -> ";
+ moduleState->print(type->getResults()[0]);
break;
default:
os << " -> (";
interleave(type->getResults(),
- [&](Type *eltType) { os << *eltType; },
+ [&](Type *eltType) { moduleState->print(eltType); },
[&]() { os << ", "; });
os << ')';
break;
@@ -72,8 +298,9 @@
}
void ExtFunction::print(raw_ostream &os) const {
+ ModuleState moduleState(os);
os << "extfunc ";
- printFunctionSignature(this, os);
+ printFunctionSignature(this, &moduleState, os);
os << "\n";
}
@@ -83,18 +310,23 @@
// CFG and ML functions.
class FunctionState {
public:
- FunctionState(MLIRContext *context, raw_ostream &os);
+ FunctionState(MLIRContext *context, const ModuleState *moduleState,
+ raw_ostream &os);
void printOperation(const Operation *op);
protected:
raw_ostream &os;
+ const ModuleState *moduleState;
const OperationSet &operationSet;
};
} // end anonymous namespace
-FunctionState::FunctionState(MLIRContext *context, raw_ostream &os)
- : os(os), operationSet(OperationSet::get(context)) {}
+FunctionState::FunctionState(MLIRContext *context,
+ const ModuleState *moduleState,
+ raw_ostream &os)
+ : os(os), moduleState(moduleState),
+ operationSet(OperationSet::get(context)) {}
void FunctionState::printOperation(const Operation *op) {
// Check to see if this is a known operation. If so, use the registered
@@ -126,7 +358,8 @@
namespace {
class CFGFunctionState : public FunctionState {
public:
- CFGFunctionState(const CFGFunction *function, raw_ostream &os);
+ CFGFunctionState(const CFGFunction *function, const ModuleState *moduleState,
+ raw_ostream &os);
const CFGFunction *getFunction() const { return function; }
@@ -150,8 +383,11 @@
};
} // end anonymous namespace
-CFGFunctionState::CFGFunctionState(const CFGFunction *function, raw_ostream &os)
- : FunctionState(function->getContext(), os), function(function) {
+CFGFunctionState::CFGFunctionState(const CFGFunction *function,
+ const ModuleState *moduleState,
+ raw_ostream &os)
+ : FunctionState(function->getContext(), moduleState, os),
+ function(function) {
// Each basic block gets a unique ID per function.
unsigned blockID = 0;
for (auto &block : *function)
@@ -160,7 +396,7 @@
void CFGFunctionState::print() {
os << "cfgfunc ";
- printFunctionSignature(this->getFunction(), os);
+ printFunctionSignature(this->getFunction(), moduleState, os);
os << " {\n";
for (auto &block : *function)
@@ -210,7 +446,8 @@
namespace {
class MLFunctionState : public FunctionState {
public:
- MLFunctionState(const MLFunction *function, raw_ostream &os);
+ MLFunctionState(const MLFunction *function, const ModuleState *moduleState,
+ raw_ostream &os);
const MLFunction *getFunction() const { return function; }
@@ -233,14 +470,16 @@
};
} // end anonymous namespace
-MLFunctionState::MLFunctionState(const MLFunction *function, raw_ostream &os)
- : FunctionState(function->getContext(), os), function(function),
- numSpaces(0) {}
+MLFunctionState::MLFunctionState(const MLFunction *function,
+ const ModuleState *moduleState,
+ raw_ostream &os)
+ : FunctionState(function->getContext(), moduleState, os),
+ function(function), numSpaces(0) {}
void MLFunctionState::print() {
os << "mlfunc ";
// FIXME: should print argument names rather than just signature
- printFunctionSignature(function, os);
+ printFunctionSignature(function, moduleState, os);
os << " {\n";
print(function);
os << " return\n";
@@ -288,12 +527,41 @@
}
}
+void printExtFunction(const ExtFunction* fn, const ModuleState* moduleState,
+ raw_ostream &os) {
+ os << "extfunc ";
+ printFunctionSignature(fn, moduleState, os);
+ os << '\n';
+}
+
+void printCFGFunction(const CFGFunction* fn, const ModuleState* moduleState,
+ raw_ostream &os) {
+ CFGFunctionState state(fn, moduleState, os);
+ state.print();
+}
+
+void printMLFunction(const MLFunction* fn, const ModuleState* moduleState,
+ raw_ostream &os) {
+ MLFunctionState state(fn, moduleState, os);
+ state.print();
+}
+
//===----------------------------------------------------------------------===//
// print and dump methods
//===----------------------------------------------------------------------===//
+void Type::print(raw_ostream &os) const {
+ ModuleState moduleState(os);
+ moduleState.print(this);
+}
+
+void Type::dump() const {
+ print(llvm::errs());
+}
+
void Instruction::print(raw_ostream &os) const {
- CFGFunctionState state(getFunction(), os);
+ ModuleState moduleState(os);
+ CFGFunctionState state(getFunction(), &moduleState, os);
state.print(this);
}
@@ -406,7 +674,8 @@
}
void BasicBlock::print(raw_ostream &os) const {
- CFGFunctionState state(getFunction(), os);
+ ModuleState moduleState(os);
+ CFGFunctionState state(getFunction(), &moduleState, os);
state.print();
}
@@ -415,7 +684,8 @@
}
void Statement::print(raw_ostream &os) const {
- MLFunctionState state(getFunction(), os);
+ ModuleState moduleState(os);
+ MLFunctionState state(getFunction(), &moduleState, os);
state.print(this);
}
@@ -436,24 +706,21 @@
}
void CFGFunction::print(raw_ostream &os) const {
- CFGFunctionState state(this, os);
+ ModuleState moduleState(os);
+ CFGFunctionState state(this, &moduleState, os);
state.print();
}
void MLFunction::print(raw_ostream &os) const {
- MLFunctionState state(this, os);
+ ModuleState moduleState(os);
+ MLFunctionState state(this, &moduleState, os);
state.print();
}
void Module::print(raw_ostream &os) const {
- unsigned id = 0;
- for (auto *map : affineMapList) {
- os << "#" << id++ << " = ";
- map->print(os);
- os << '\n';
- }
- for (auto *fn : functionList)
- fn->print(os);
+ ModuleState moduleState(os);
+ moduleState.initialize(this);
+ moduleState.print(this);
}
void Module::dump() const {
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index 1b7d1a6..fe8238d 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -64,87 +64,3 @@
ArrayRef<AffineMap*> MemRefType::getAffineMaps() const {
return ArrayRef<AffineMap*>(affineMapList, numAffineMaps);
}
-
-void Type::print(raw_ostream &os) const {
- switch (getKind()) {
- case Kind::AffineInt: os << "affineint"; return;
- case Kind::BF16: os << "bf16"; return;
- case Kind::F16: os << "f16"; return;
- case Kind::F32: os << "f32"; return;
- case Kind::F64: os << "f64"; return;
-
- case Kind::Integer: {
- auto *integer = cast<IntegerType>(this);
- os << 'i' << integer->getWidth();
- return;
- }
- case Kind::Function: {
- auto *func = cast<FunctionType>(this);
- os << '(';
- interleave(func->getInputs(),
- [&](Type *type) { os << *type; },
- [&]() { os << ", "; });
- os << ") -> ";
- auto results = func->getResults();
- if (results.size() == 1)
- os << *results[0];
- else {
- os << '(';
- interleave(results,
- [&](Type *type) { os << *type; },
- [&]() { os << ", "; });
- os << ")";
- }
- return;
- }
- case Kind::Vector: {
- auto *v = cast<VectorType>(this);
- os << "vector<";
- for (auto dim : v->getShape())
- os << dim << 'x';
- os << *v->getElementType() << '>';
- return;
- }
- case Kind::RankedTensor: {
- auto *v = cast<RankedTensorType>(this);
- os << "tensor<";
- for (auto dim : v->getShape()) {
- if (dim < 0)
- os << '?';
- else
- os << dim;
- os << 'x';
- }
- os << *v->getElementType() << '>';
- return;
- }
- case Kind::UnrankedTensor: {
- auto *v = cast<UnrankedTensorType>(this);
- os << "tensor<??" << *v->getElementType() << '>';
- return;
- }
- case Kind::MemRef: {
- auto *v = cast<MemRefType>(this);
- os << "memref<";
- for (auto dim : v->getShape()) {
- if (dim < 0)
- os << '?';
- else
- os << dim;
- os << 'x';
- }
- os << *v->getElementType();
- for (auto map : v->getAffineMaps()) {
- os << ", ";
- map->print(os);
- }
- os << ", " << v->getMemorySpace();
- os << '>';
- return;
- }
- }
-}
-
-void Type::dump() const {
- print(llvm::errs());
-}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 4b9133c..aabefbf 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1642,7 +1642,6 @@
if (!entry)
return ParseFailure;
- getModule()->affineMapList.push_back(entry);
return ParseSuccess;
}
diff --git a/test/IR/parser-affine-map.mlir b/test/IR/parser-affine-map.mlir
index a0b550f..50f2bd7 100644
--- a/test/IR/parser-affine-map.mlir
+++ b/test/IR/parser-affine-map.mlir
@@ -1,121 +1,241 @@
// RUN: %S/../../mlir-opt %s -o - | FileCheck %s
-// CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, d1)
-#hello_world0 = (i, j) -> (i, j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, d1)
+#map0 = (i, j) -> (i, j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (d0, d1)
-#hello_world1 = (i, j) [s0] -> (i, j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, d1)
+#map1 = (i, j) [s0] -> (i, j)
-// CHECK: #{{[0-9]+}} = () -> (0)
-#hello_world2 = () -> (0)
+// CHECK-DAG: #map{{[0-9]+}} = () -> (0)
+#map2 = () -> (0)
-// CHECK: #{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
-#hello_world3 = (i, j) -> (i+1, j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
+#map3 = (i, j) -> (i+1, j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), d1)
-#hello_world4 = (i, j) [s0] -> (i + s0, j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), d1)
+#map4 = (i, j) [s0] -> (i + s0, j)
-// CHECK: #{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
-#hello_world5 = (i, j) -> (1+i, j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> ((d0 + 1), d1)
+#map5 = (i, j) -> (1+i, j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), (d1 + 5))
-#hello_world6 = (i, j) [s0] -> (i + s0, j + 5)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + s0), (d1 + 5))
+#map6 = (i, j) [s0] -> (i + s0, j + 5)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + s0), d1)
-#hello_world7 = (i, j) [s0] -> (i + j + s0, j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + s0), d1)
+#map7 = (i, j) [s0] -> (i + j + s0, j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((((d0 + 5) + d1) + s0), d1)
-#hello_world8 = (i, j) [s0] -> (5 + i + j + s0, j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((((d0 + 5) + d1) + s0), d1)
+#map8 = (i, j) [s0] -> (5 + i + j + s0, j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), d1)
-#hello_world9 = (i, j) [s0] -> ((i + j) + 5, j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), d1)
+#map9 = (i, j) [s0] -> ((i + j) + 5, j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + (d1 + 5)), d1)
-#hello_world10 = (i, j) [s0] -> (i + (j + 5), j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + (d1 + 5)), d1)
+#map10 = (i, j) [s0] -> (i + (j + 5), j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((d0 * 2), (d1 * 3))
-#hello_world11 = (i, j) [s0] -> (2*i, 3*j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 * 2), (d1 * 3))
+#map11 = (i, j) [s0] -> (2*i, 3*j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + 12) + ((d1 + (s0 * 3)) * 5)), d1)
-#hello_world12 = (i, j) [s0] -> (i + 2*6 + 5*(j+s0*3), j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + 12) + ((d1 + (s0 * 3)) * 5)), d1)
+#map12 = (i, j) [s0] -> (i + 2*6 + 5*(j+s0*3), j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (((d0 * 5) + d1), d1)
-#hello_world13 = (i, j) [s0] -> (5*i + j, j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 * 5) + d1), d1)
+#map13 = (i, j) [s0] -> (5*i + j, j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + d1), d1)
-#hello_world14 = (i, j) [s0] -> ((i + j), (j))
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((d0 + d1), d1)
+#map14 = (i, j) [s0] -> ((i + j), (j))
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), (d1 + 3))
-#hello_world15 = (i, j) [s0] -> ((i + j)+5, (j)+3)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (((d0 + d1) + 5), (d1 + 3))
+#map15 = (i, j) [s0] -> ((i + j)+5, (j)+3)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (d0, 0)
-#hello_world16 = (i, j) [s1] -> (i, 0)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, 0)
+#map16 = (i, j) [s1] -> (i, 0)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (d0, (d1 * s0))
-#hello_world17 = (i, j) [s0] -> (i, s0*j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, (d1 * s0))
+#map17 = (i, j) [s0] -> (i, s0*j)
-// CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, ((d0 * 3) + d1))
-#hello_world19 = (i, j) -> (i, 3*i + j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, ((d0 * 3) + d1))
+#map19 = (i, j) -> (i, 3*i + j)
-// CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, (d0 + (d1 * 3)))
-#hello_world20 = (i, j) -> (i, i + 3*j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, (d0 + (d1 * 3)))
+#map20 = (i, j) -> (i, i + 3*j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> (d0, (((d0 * ((s0 * s0) * 9)) + 2) + 1))
-#hello_world18 = (i, j) [N] -> (i, 2 + N*N*9*i + 1)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> (d0, (((d0 * ((s0 * s0) * 9)) + 2) + 1))
+#map18 = (i, j) [N] -> (i, 2 + N*N*9*i + 1)
-// CHECK: #{{[0-9]+}} = (d0, d1) -> (1, ((d0 + (d1 * 3)) + 5))
-#hello_world21 = (i, j) -> (1, i + 3*j + 5)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (1, ((d0 + (d1 * 3)) + 5))
+#map21 = (i, j) -> (1, i + 3*j + 5)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0] -> ((s0 * 5), ((d0 + (d1 * 3)) + (d0 * 5)))
-#hello_world22 = (i, j) [s0] -> (5*s0, i + 3*j + 5*i)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0] -> ((s0 * 5), ((d0 + (d1 * 3)) + (d0 * 5)))
+#map22 = (i, j) [s0] -> (5*s0, i + 3*j + 5*i)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * (s0 * s1)), d1)
-#hello_world23 = (i, j) [s0, s1] -> (i*(s0*s1), j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * (s0 * s1)), d1)
+#map23 = (i, j) [s0, s1] -> (i*(s0*s1), j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 mod 5))
-#hello_world24 = (i, j) [s0, s1] -> (i, j mod 5)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 mod 5))
+#map24 = (i, j) [s0, s1] -> (i, j mod 5)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv 5))
-#hello_world25 = (i, j) [s0, s1] -> (i, j floordiv 5)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv 5))
+#map25 = (i, j) [s0, s1] -> (i, j floordiv 5)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 ceildiv 5))
-#hello_world26 = (i, j) [s0, s1] -> (i, j ceildiv 5)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 ceildiv 5))
+#map26 = (i, j) [s0, s1] -> (i, j ceildiv 5)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - d1) - 5))
-#hello_world29 = (i, j) [s0, s1] -> (i, i - j - 5)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - d1) - 5))
+#map29 = (i, j) [s0, s1] -> (i, i - j - 5)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - (d1 * s1)) + 2))
-#hello_world30 = (i, j) [M, N] -> (i, i - N*j + 2)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, ((d0 - (d1 * s1)) + 2))
+#map30 = (i, j) [M, N] -> (i, i - N*j + 2)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * -5), (d1 * -3), -2, ((d0 + d1) * -1), (s0 * -1))
-#hello_world32 = (i, j) [s0, s1] -> (-5*i, -3*j, -2, -1*(i+j), -1*s0)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * -5), (d1 * -3), -2, ((d0 + d1) * -1), (s0 * -1))
+#map32 = (i, j) [s0, s1] -> (-5*i, -3*j, -2, -1*(i+j), -1*s0)
-// CHECK: #{{[0-9]+}} = (d0, d1) -> (-4, (d0 * -1))
-#hello_world33 = (i, j) -> (-2+-5-(-3), -1*i)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (-4, (d0 * -1))
+#map33 = (i, j) -> (-2+-5-(-3), -1*i)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv s0), (d1 mod s0))
-#hello_world34 = (i, j) [s0, s1] -> (i, j floordiv s0, j mod s0)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, (d1 floordiv s0), (d1 mod s0))
+#map34 = (i, j) [s0, s1] -> (i, j floordiv s0, j mod s0)
-// CHECK: #{{[0-9]+}} = (d0, d1, d2) [s0, s1, s2] -> (((((d0 * s1) * s2) + (d1 * s1)) + d2))
-#hello_world35 = (i, j, k) [s0, s1, s2] -> (i*s1*s2 + j*s1 + k)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1, d2) [s0, s1, s2] -> (((((d0 * s1) * s2) + (d1 * s1)) + d2))
+#map35 = (i, j, k) [s0, s1, s2] -> (i*s1*s2 + j*s1 + k)
-// CHECK: #{{[0-9]+}} = (d0, d1) -> (8, 4, 1, 3, 2, 4)
-#hello_world36 = (i, j) -> (5+3, 2*2, 8-7, 100 floordiv 32, 5 mod 3, 10 ceildiv 3)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (8, 4, 1, 3, 2, 4)
+#map36 = (i, j) -> (5+3, 2*2, 8-7, 100 floordiv 32, 5 mod 3, 10 ceildiv 3)
-// CHECK: #{{[0-9]+}} = (d0, d1) -> (4, 11, 512, 15)
-#hello_world37 = (i, j) -> (5 mod 3 + 2, 5*3 - 4, 128 * (500 ceildiv 128), 40 floordiv 7 * 3)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (4, 11, 512, 15)
+#map37 = (i, j) -> (5 mod 3 + 2, 5*3 - 4, 128 * (500 ceildiv 128), 40 floordiv 7 * 3)
-// CHECK: #{{[0-9]+}} = (d0, d1) -> (((d0 * 2) + 1), (d1 + 2))
-#hello_world38 = (i, j) -> (1 + i*2, 2 + j)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (((d0 * 2) + 1), (d1 + 2))
+#map38 = (i, j) -> (1 + i*2, 2 + j)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * s0), (d0 + s0), (d0 + 2), (d1 * 2), (s1 * 2), (s0 + 2))
-#hello_world39 = (i, j) [M, N] -> (i*M, M + i, 2+i, j*2, N*2, 2 + M)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> ((d0 * s0), (d0 + s0), (d0 + 2), (d1 * 2), (s1 * 2), (s0 + 2))
+#map39 = (i, j) [M, N] -> (i*M, M + i, 2+i, j*2, N*2, 2 + M)
-// CHECK: #{{[0-9]+}} = (d0, d1) -> (d0, d1) size (10, 20)
-#hello_world40 = (i, j) -> (i, j) size (10, 20)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0, d1) size (10, 20)
+#map40 = (i, j) -> (i, j) size (10, 20)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (s0, (s1 + 10))
-#hello_world41 = (i, j) [N, M] -> (i, j) size (N, M+10)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (s0, (s1 + 10))
+#map41 = (i, j) [N, M] -> (i, j) size (N, M+10)
-// CHECK: #{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (128, (((s0 * 2) + 5) + s1))
-#hello_world42 = (i, j) [N, M] -> (i, j) size (64 + 64, 5 + 2*N + M)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) [s0, s1] -> (d0, d1) size (128, (((s0 * 2) + 5) + s1))
+#map42 = (i, j) [N, M] -> (i, j) size (64 + 64, 5 + 2*N + M)
+
+// CHECK: extfunc @f0(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f0(memref<2x4xi8, #map0, 1>)
+
+// CHECK: extfunc @f1(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f1(memref<2x4xi8, #map1, 1>)
+
+// CHECK: extfunc @f2(memref<2xi8, #map{{[0-9]+}}, 1>)
+extfunc @f2(memref<2xi8, #map2, 1>)
+
+// CHECK: extfunc @f3(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f3(memref<2x4xi8, #map3, 1>)
+
+// CHECK: extfunc @f4(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f4(memref<2x4xi8, #map4, 1>)
+
+// CHECK: extfunc @f5(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f5(memref<2x4xi8, #map5, 1>)
+
+// CHECK: extfunc @f6(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f6(memref<2x4xi8, #map6, 1>)
+
+// CHECK: extfunc @f7(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f7(memref<2x4xi8, #map7, 1>)
+
+// CHECK: extfunc @f8(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f8(memref<2x4xi8, #map8, 1>)
+
+// CHECK: extfunc @f9(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f9(memref<2x4xi8, #map9, 1>)
+
+// CHECK: extfunc @f10(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f10(memref<2x4xi8, #map10, 1>)
+
+// CHECK: extfunc @f11(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f11(memref<2x4xi8, #map11, 1>)
+
+// CHECK: extfunc @f12(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f12(memref<2x4xi8, #map12, 1>)
+
+// CHECK: extfunc @f13(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f13(memref<2x4xi8, #map13, 1>)
+
+// CHECK: extfunc @f14(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f14(memref<2x4xi8, #map14, 1>)
+
+// CHECK: extfunc @f15(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f15(memref<2x4xi8, #map15, 1>)
+
+// CHECK: extfunc @f16(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f16(memref<2x4xi8, #map16, 1>)
+
+// CHECK: extfunc @f17(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f17(memref<2x4xi8, #map17, 1>)
+
+// CHECK: extfunc @f19(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f19(memref<2x4xi8, #map19, 1>)
+
+// CHECK: extfunc @f20(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f20(memref<2x4xi8, #map20, 1>)
+
+// CHECK: extfunc @f18(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f18(memref<2x4xi8, #map18, 1>)
+
+// CHECK: extfunc @f21(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f21(memref<2x4xi8, #map21, 1>)
+
+// CHECK: extfunc @f22(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f22(memref<2x4xi8, #map22, 1>)
+
+// CHECK: extfunc @f23(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f23(memref<2x4xi8, #map23, 1>)
+
+// CHECK: extfunc @f24(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f24(memref<2x4xi8, #map24, 1>)
+
+// CHECK: extfunc @f25(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f25(memref<2x4xi8, #map25, 1>)
+
+// CHECK: extfunc @f26(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f26(memref<2x4xi8, #map26, 1>)
+
+// CHECK: extfunc @f29(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f29(memref<2x4xi8, #map29, 1>)
+
+// CHECK: extfunc @f30(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f30(memref<2x4xi8, #map30, 1>)
+
+// CHECK: extfunc @f32(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f32(memref<2x4xi8, #map32, 1>)
+
+// CHECK: extfunc @f33(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f33(memref<2x4xi8, #map33, 1>)
+
+// CHECK: extfunc @f34(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f34(memref<2x4xi8, #map34, 1>)
+
+// CHECK: extfunc @f35(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f35(memref<2x4xi8, #map35, 1>)
+
+// CHECK: extfunc @f36(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f36(memref<2x4xi8, #map36, 1>)
+
+// CHECK: extfunc @f37(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f37(memref<2x4xi8, #map37, 1>)
+
+// CHECK: extfunc @f38(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f38(memref<2x4xi8, #map38, 1>)
+
+// CHECK: extfunc @f39(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f39(memref<2x4xi8, #map39, 1>)
+
+// CHECK: extfunc @f40(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f40(memref<2x4xi8, #map40, 1>)
+
+// CHECK: extfunc @f41(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f41(memref<2x4xi8, #map41, 1>)
+
+// CHECK: extfunc @f42(memref<2x4xi8, #map{{[0-9]+}}, 1>)
+extfunc @f42(memref<2x4xi8, #map42, 1>)
\ No newline at end of file
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 00e5c8e..b7a3678 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -3,10 +3,19 @@
//
// RUN: %S/../../mlir-opt %s -o - | FileCheck %s
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1, d2, d3, d4) [s0] -> (d0, d1, d2, d3, d4)
#map0 = (d0, d1, d2, d3, d4) [s0] -> (d0, d1, d2, d3, d4)
+
+// CHECK-DAG: #map{{[0-9]+}} = (d0) -> (d0)
#map1 = (d0) -> (d0)
+
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1, d2) -> (d0, d1, d2)
#map2 = (d0, d1, d2) -> (d0, d1, d2)
+
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1, d2) -> (d1, d0, d2)
#map3 = (d0, d1, d2) -> (d1, d0, d2)
+
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1, d2) -> (d2, d1, d0)
#map4 = (d0, d1, d2) -> (d2, d1, d0)
// CHECK: extfunc @foo(i32, i64) -> f32
@@ -32,27 +41,26 @@
extfunc @tensors(tensor<?? f32>, tensor<?? vector<2x4xf32>>,
tensor<1x?x4x?x?xaffineint>, tensor<i8>)
-// TODO(andydavis) Add support to outline affine maps for these cases.
-// CHECK: extfunc @memrefs(memref<1x?x4x?x?xaffineint, (d0, d1, d2, d3, d4) [s0] -> (d0, d1, d2, d3, d4), 0>, memref<i8, (d0) -> (d0), 0>)
+// CHECK: extfunc @memrefs(memref<1x?x4x?x?xaffineint, #map{{[0-9]+}}, 0>, memref<i8, #map{{[0-9]+}}, 0>)
extfunc @memrefs(memref<1x?x4x?x?xaffineint, #map0, 0>, memref<i8, #map1, 0>)
// Test memref affine map compositions.
-// CHECK: extfunc @memrefs2(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), 1>)
+// CHECK: extfunc @memrefs2(memref<2x4x8xi8, #map{{[0-9]+}}, 1>)
extfunc @memrefs2(memref<2x4x8xi8, #map2, 1>)
-// CHECK: extfunc @memrefs23(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d1, d0, d2), 0>)
+// CHECK: extfunc @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 0>)
extfunc @memrefs23(memref<2x4x8xi8, #map2, #map3, 0>)
-// CHECK: extfunc @memrefs234(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d1, d0, d2), (d0, d1, d2) -> (d2, d1, d0), 3>)
+// CHECK: extfunc @memrefs234(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, #map{{[0-9]+}}, 3>)
extfunc @memrefs234(memref<2x4x8xi8, #map2, #map3, #map4, 3>)
// Test memref inline affine map compositions.
-// CHECK: extfunc @memrefs2(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), 0>)
+// CHECK: extfunc @memrefs2(memref<2x4x8xi8, #map{{[0-9]+}}, 0>)
extfunc @memrefs2(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), 0>)
-// CHECK: extfunc @memrefs23(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d1, d0, d2), 1>)
+// CHECK: extfunc @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 1>)
extfunc @memrefs23(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d1, d0, d2), 1>)
// CHECK: extfunc @functions((memref<1x?x4x?x?xaffineint, (d0, d1, d2, d3, d4) [s0] -> (d0, d1, d2, d3, d4), 0>, memref<i8, (d0) -> (d0), 0>) -> (), () -> ())