Adds ModuleState to support printing outlined AffineMaps.
PiperOrigin-RevId: 204999887
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 9a86b85..be6573f 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -44,27 +44,253 @@
}
//===----------------------------------------------------------------------===//
+// Module printing
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Module-wide printing state.
+///
+/// Affine maps recorded here are given dense integer ids (0, 1, 2, ...) in the
+/// order they are first recorded. When printing a module, the maps are emitted
+/// once at the top as "#mapN = ..." definitions, and memref types refer to them
+/// by id. Maps without an id are printed inline.
+class ModuleState {
+public:
+  explicit ModuleState(raw_ostream &os);
+
+  /// Walks the functions of `module` and records the affine maps referenced by
+  /// their types (function bodies are not visited yet).
+  void initialize(const Module *module);
+
+  void print(const Module *module);
+  void print(const Type *type) const;
+  void print(const Function *fn);
+
+  /// Assigns the next free id to `affineMap` unless it already has one.
+  void recordAffineMapReference(const AffineMap *affineMap) {
+    // One hash lookup: insert() leaves an existing entry untouched and reports
+    // whether a new entry was created.
+    if (affineMapIds.insert(std::make_pair(affineMap, nextAffineMapId)).second)
+      ++nextAffineMapId;
+  }
+
+  /// Returns the id recorded for `affineMap`, or -1 if it has none.
+  int getAffineMapId(const AffineMap *affineMap) const {
+    auto it = affineMapIds.find(affineMap);
+    if (it == affineMapIds.end())
+      return -1;
+    return it->second;
+  }
+
+private:
+  // Visit functions.
+  void visitFunction(const Function *fn);
+  void visitExtFunction(const ExtFunction *fn);
+  void visitCFGFunction(const CFGFunction *fn);
+  void visitMLFunction(const MLFunction *fn);
+  void visitType(const Type *type);
+
+  raw_ostream &os;
+  // Affine map -> dense id, assigned in first-recorded order.
+  DenseMap<const AffineMap *, int> affineMapIds;
+  int nextAffineMapId = 0;
+};
+} // end anonymous namespace
+
+// Binds the state to `os`. No affine maps are recorded until initialize().
+ModuleState::ModuleState(raw_ostream &os) : os(os) {}
+
+// Populates the affine map ids by visiting every function of the module in
+// order, so ids follow the order in which maps are first referenced.
+void ModuleState::initialize(const Module *module) {
+  for (auto *function : module->functionList)
+    visitFunction(function);
+}
+
+// Records the affine maps reachable from `type`. Function types are walked
+// through their inputs and results; memref types contribute the affine maps
+// attached to them. All other kinds are ignored for now.
+// TODO Support visiting other types/instructions when implemented.
+void ModuleState::visitType(const Type *type) {
+  switch (type->getKind()) {
+  case Type::Kind::Function: {
+    auto *fnType = cast<FunctionType>(type);
+    for (auto *inputType : fnType->getInputs())
+      visitType(inputType);
+    for (auto *resultType : fnType->getResults())
+      visitType(resultType);
+    return;
+  }
+  case Type::Kind::MemRef:
+    for (AffineMap *layoutMap : cast<MemRefType>(type)->getAffineMaps())
+      recordAffineMapReference(layoutMap);
+    return;
+  default:
+    return;
+  }
+}
+
+// Per-kind visitors. For now only each function's type is inspected; bodies
+// are not walked yet.
+void ModuleState::visitExtFunction(const ExtFunction *fn) {
+  visitType(fn->getType());
+}
+
+void ModuleState::visitCFGFunction(const CFGFunction *fn) {
+  // TODO Visit function body instructions.
+  visitType(fn->getType());
+}
+
+void ModuleState::visitMLFunction(const MLFunction *fn) {
+  // TODO Visit function body statements.
+  visitType(fn->getType());
+}
+
+// Dispatches to the visitor that matches the dynamic kind of `fn`.
+void ModuleState::visitFunction(const Function *fn) {
+  const auto kind = fn->getKind();
+  if (kind == Function::Kind::ExtFunc)
+    return visitExtFunction(cast<ExtFunction>(fn));
+  if (kind == Function::Kind::CFGFunc)
+    return visitCFGFunction(cast<CFGFunction>(fn));
+  if (kind == Function::Kind::MLFunc)
+    return visitMLFunction(cast<MLFunction>(fn));
+}
+
+// Per-kind function printers, defined after the function-state classes below.
+static void printExtFunction(const ExtFunction *fn,
+                             const ModuleState *moduleState, raw_ostream &os);
+static void printCFGFunction(const CFGFunction *fn,
+                             const ModuleState *moduleState, raw_ostream &os);
+static void printMLFunction(const MLFunction *fn,
+                            const ModuleState *moduleState, raw_ostream &os);
+
+// Prints `fn` using this module state. Affine maps recorded by initialize()
+// are referenced by id.
+void ModuleState::print(const Function *fn) {
+  switch (fn->getKind()) {
+  case Function::Kind::ExtFunc:
+    printExtFunction(cast<ExtFunction>(fn), this, os);
+    break;
+  case Function::Kind::CFGFunc:
+    printCFGFunction(cast<CFGFunction>(fn), this, os);
+    break;
+  case Function::Kind::MLFunc:
+    printMLFunction(cast<MLFunction>(fn), this, os);
+    break;
+  }
+}
+
+// Prints the identifier that references the affine map with id `affineMapId`,
+// e.g. "#map0".
+static void printAffineMapId(int affineMapId, raw_ostream &os) {
+  os << "#map" << affineMapId;
+}
+
+// Prints the module. First comes one "#mapN = <map>" line per recorded affine
+// map, in increasing id order, then every function in module order.
+void ModuleState::print(const Module *module) {
+  // affineMapIds is keyed by pointer, so iterating it directly would emit the
+  // definitions in a hash-dependent order that can change from run to run.
+  // Ids are dense (0..N-1, handed out by recordAffineMapReference), so bucket
+  // the maps by id and print them in that order.
+  std::vector<const AffineMap *> mapsById(affineMapIds.size());
+  for (const auto &mapAndId : affineMapIds)
+    mapsById[mapAndId.second] = mapAndId.first;
+
+  for (int id = 0, e = static_cast<int>(mapsById.size()); id != e; ++id) {
+    printAffineMapId(id, os);
+    os << " = ";
+    mapsById[id]->print(os);
+    os << '\n';
+  }
+  for (auto *fn : module->functionList)
+    print(fn);
+}
+
+// Prints `type` to this state's stream.
+//
+// Component types are printed by recursing into this method: function inputs
+// and results, and vector, tensor and memref element types. The alternative,
+// `operator<<`, presumably goes through Type::print and its fresh, empty
+// ModuleState, so any "#mapN" ids recorded here would be ignored. For a memref,
+// each affine map with a recorded id is printed as its "#mapN" reference, and
+// each map without one is printed inline.
+void ModuleState::print(const Type *type) const {
+  switch (type->getKind()) {
+  case Type::Kind::AffineInt: os << "affineint"; return;
+  case Type::Kind::BF16: os << "bf16"; return;
+  case Type::Kind::F16: os << "f16"; return;
+  case Type::Kind::F32: os << "f32"; return;
+  case Type::Kind::F64: os << "f64"; return;
+
+  case Type::Kind::Integer: {
+    auto *integer = cast<IntegerType>(type);
+    os << 'i' << integer->getWidth();
+    return;
+  }
+  case Type::Kind::Function: {
+    auto *func = cast<FunctionType>(type);
+    os << '(';
+    interleave(func->getInputs(),
+               [&](Type *input) { print(input); },
+               [&]() { os << ", "; });
+    os << ") -> ";
+    auto results = func->getResults();
+    if (results.size() == 1) {
+      print(results[0]);
+    } else {
+      os << '(';
+      interleave(results,
+                 [&](Type *result) { print(result); },
+                 [&]() { os << ", "; });
+      os << ')';
+    }
+    return;
+  }
+  case Type::Kind::Vector: {
+    auto *v = cast<VectorType>(type);
+    os << "vector<";
+    for (auto dim : v->getShape())
+      os << dim << 'x';
+    print(v->getElementType());
+    os << '>';
+    return;
+  }
+  case Type::Kind::RankedTensor: {
+    auto *v = cast<RankedTensorType>(type);
+    os << "tensor<";
+    for (auto dim : v->getShape()) {
+      // Negative extents are printed as '?'.
+      if (dim < 0)
+        os << '?';
+      else
+        os << dim;
+      os << 'x';
+    }
+    print(v->getElementType());
+    os << '>';
+    return;
+  }
+  case Type::Kind::UnrankedTensor: {
+    auto *v = cast<UnrankedTensorType>(type);
+    os << "tensor<??";
+    print(v->getElementType());
+    os << '>';
+    return;
+  }
+  case Type::Kind::MemRef: {
+    auto *v = cast<MemRefType>(type);
+    os << "memref<";
+    for (auto dim : v->getShape()) {
+      if (dim < 0)
+        os << '?';
+      else
+        os << dim;
+      os << 'x';
+    }
+    print(v->getElementType());
+    for (auto *map : v->getAffineMaps()) {
+      os << ", ";
+      const int mapId = getAffineMapId(map);
+      if (mapId >= 0) {
+        // Map will be printed at top of module so print reference to its id.
+        printAffineMapId(mapId, os);
+      } else {
+        // Map not in module state so print inline.
+        map->print(os);
+      }
+    }
+    os << ", " << v->getMemorySpace();
+    os << '>';
+    return;
+  }
+  }
+}
+
+//===----------------------------------------------------------------------===//
// Function printing
//===----------------------------------------------------------------------===//
-static void printFunctionSignature(const Function *fn, raw_ostream &os) {
+// Prints "@name(inputs)" for `fn`, followed by " -> result" for a single
+// result or " -> (results...)" for several. Nothing is printed after the
+// inputs when there are no results. Every type is printed through
+// `moduleState`, so memref affine maps that it has assigned an id appear as
+// "#mapN" references instead of inline.
+static void printFunctionSignature(const Function *fn,
+ const ModuleState *moduleState,
+ raw_ostream &os) {
auto type = fn->getType();
os << "@" << fn->getName() << '(';
interleave(type->getInputs(),
- [&](Type *eltType) { os << *eltType; },
+ [&](Type *eltType) { moduleState->print(eltType); },
[&]() { os << ", "; });
os << ')';
switch (type->getResults().size()) {
case 0: break;
case 1:
- os << " -> " << *type->getResults()[0];
+ os << " -> ";
+ moduleState->print(type->getResults()[0]);
break;
default:
os << " -> (";
interleave(type->getResults(),
- [&](Type *eltType) { os << *eltType; },
+ [&](Type *eltType) { moduleState->print(eltType); },
[&]() { os << ", "; });
os << ')';
break;
@@ -72,8 +298,9 @@
}
void ExtFunction::print(raw_ostream &os) const {
+ // Standalone printing: this ModuleState is never initialized, so it holds no
+ // affine map ids and memref maps in the signature are printed inline.
+ ModuleState moduleState(os);
os << "extfunc ";
- printFunctionSignature(this, os);
+ printFunctionSignature(this, &moduleState, os);
os << "\n";
}
@@ -83,18 +310,23 @@
// CFG and ML functions.
class FunctionState {
public:
- FunctionState(MLIRContext *context, raw_ostream &os);
+ FunctionState(MLIRContext *context, const ModuleState *moduleState,
+ raw_ostream &os);
void printOperation(const Operation *op);
protected:
raw_ostream &os;
+ // Module-level printing state (e.g. outlined affine map ids). It is held as a
+ // non-owning pointer, so the ModuleState must outlive this object.
+ const ModuleState *moduleState;
const OperationSet &operationSet;
};
} // end anonymous namespace
-FunctionState::FunctionState(MLIRContext *context, raw_ostream &os)
- : os(os), operationSet(OperationSet::get(context)) {}
+FunctionState::FunctionState(MLIRContext *context,
+ const ModuleState *moduleState,
+ raw_ostream &os)
+ : os(os), moduleState(moduleState),
+ operationSet(OperationSet::get(context)) {}
void FunctionState::printOperation(const Operation *op) {
// Check to see if this is a known operation. If so, use the registered
@@ -126,7 +358,8 @@
namespace {
class CFGFunctionState : public FunctionState {
public:
- CFGFunctionState(const CFGFunction *function, raw_ostream &os);
+ // `moduleState` is forwarded to FunctionState, which stores it without taking
+ // ownership and uses it when printing the signature; it must outlive this.
+ CFGFunctionState(const CFGFunction *function, const ModuleState *moduleState,
+ raw_ostream &os);
const CFGFunction *getFunction() const { return function; }
@@ -150,8 +383,11 @@
};
} // end anonymous namespace
-CFGFunctionState::CFGFunctionState(const CFGFunction *function, raw_ostream &os)
- : FunctionState(function->getContext(), os), function(function) {
+CFGFunctionState::CFGFunctionState(const CFGFunction *function,
+ const ModuleState *moduleState,
+ raw_ostream &os)
+ : FunctionState(function->getContext(), moduleState, os),
+ function(function) {
// Each basic block gets a unique ID per function.
unsigned blockID = 0;
for (auto &block : *function)
@@ -160,7 +396,7 @@
void CFGFunctionState::print() {
os << "cfgfunc ";
- printFunctionSignature(this->getFunction(), os);
+ printFunctionSignature(this->getFunction(), moduleState, os);
os << " {\n";
for (auto &block : *function)
@@ -210,7 +446,8 @@
namespace {
class MLFunctionState : public FunctionState {
public:
- MLFunctionState(const MLFunction *function, raw_ostream &os);
+ // `moduleState` is forwarded to FunctionState, which stores it without taking
+ // ownership and uses it when printing the signature; it must outlive this.
+ MLFunctionState(const MLFunction *function, const ModuleState *moduleState,
+ raw_ostream &os);
const MLFunction *getFunction() const { return function; }
@@ -233,14 +470,16 @@
};
} // end anonymous namespace
-MLFunctionState::MLFunctionState(const MLFunction *function, raw_ostream &os)
- : FunctionState(function->getContext(), os), function(function),
- numSpaces(0) {}
+MLFunctionState::MLFunctionState(const MLFunction *function,
+ const ModuleState *moduleState,
+ raw_ostream &os)
+ : FunctionState(function->getContext(), moduleState, os),
+ function(function), numSpaces(0) {}
void MLFunctionState::print() {
os << "mlfunc ";
// FIXME: should print argument names rather than just signature
- printFunctionSignature(function, os);
+ printFunctionSignature(function, moduleState, os);
os << " {\n";
print(function);
os << " return\n";
@@ -288,12 +527,41 @@
}
}
+// Prints an external function declaration: "extfunc ", the signature, and a
+// newline.
+static void printExtFunction(const ExtFunction *fn,
+                             const ModuleState *moduleState, raw_ostream &os) {
+  os << "extfunc ";
+  printFunctionSignature(fn, moduleState, os);
+  os << '\n';
+}
+
+// Prints a CFG function through a CFGFunctionState bound to `moduleState`.
+static void printCFGFunction(const CFGFunction *fn,
+                             const ModuleState *moduleState, raw_ostream &os) {
+  CFGFunctionState(fn, moduleState, os).print();
+}
+
+// Prints an ML function through an MLFunctionState bound to `moduleState`.
+static void printMLFunction(const MLFunction *fn,
+                            const ModuleState *moduleState, raw_ostream &os) {
+  MLFunctionState(fn, moduleState, os).print();
+}
+
//===----------------------------------------------------------------------===//
// print and dump methods
//===----------------------------------------------------------------------===//
+// Prints the type with no module context. The temporary ModuleState records no
+// affine maps, so memref maps are printed inline.
+void Type::print(raw_ostream &os) const {
+  ModuleState(os).print(this);
+}
+
+// Prints the type to llvm::errs().
+void Type::dump() const { this->print(llvm::errs()); }
+
void Instruction::print(raw_ostream &os) const {
- CFGFunctionState state(getFunction(), os);
+ // Standalone printing: this ModuleState is never initialized, so it holds no
+ // affine map ids and any map printed through it is emitted inline.
+ ModuleState moduleState(os);
+ CFGFunctionState state(getFunction(), &moduleState, os);
state.print(this);
}
@@ -406,7 +674,8 @@
}
void BasicBlock::print(raw_ostream &os) const {
- CFGFunctionState state(getFunction(), os);
+ // Uninitialized ModuleState: affine maps print inline.
+ ModuleState moduleState(os);
+ CFGFunctionState state(getFunction(), &moduleState, os);
state.print();
}
@@ -415,7 +684,8 @@
}
void Statement::print(raw_ostream &os) const {
- MLFunctionState state(getFunction(), os);
+ // Uninitialized ModuleState: affine maps print inline.
+ ModuleState moduleState(os);
+ MLFunctionState state(getFunction(), &moduleState, os);
state.print(this);
}
@@ -436,24 +706,21 @@
}
void CFGFunction::print(raw_ostream &os) const {
- CFGFunctionState state(this, os);
+ // Standalone function printing: the uninitialized ModuleState records no
+ // affine maps, so they are printed inline.
+ ModuleState moduleState(os);
+ CFGFunctionState state(this, &moduleState, os);
state.print();
}
void MLFunction::print(raw_ostream &os) const {
- MLFunctionState state(this, os);
+ // Uninitialized ModuleState: affine maps print inline.
+ ModuleState moduleState(os);
+ MLFunctionState state(this, &moduleState, os);
state.print();
}
void Module::print(raw_ostream &os) const {
- unsigned id = 0;
- for (auto *map : affineMapList) {
- os << "#" << id++ << " = ";
- map->print(os);
- os << '\n';
- }
- for (auto *fn : functionList)
- fn->print(os);
+ // Two passes. First, initialize() records the affine maps referenced by the
+ // functions' types. Then print() emits them as "#mapN = ..." definitions
+ // ahead of the functions, which reference them by id.
+ // NOTE(review): maps in affineMapList that no function type references are
+ // no longer printed (the removed loop printed all of them); confirm intended.
+ ModuleState moduleState(os);
+ moduleState.initialize(this);
+ moduleState.print(this);
}
void Module::dump() const {