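Address follow-ups to the AsmPrinter changes from the last CL: clang-format AsmPrinter.cpp and replace the static printExtFunction/printCFGFunction/printMLFunction helpers with ModuleState::print overloads.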
PiperOrigin-RevId: 205096519
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index be6573f..baddac9 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -34,37 +34,35 @@
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
+void Identifier::print(raw_ostream &os) const { os << str(); }
-void Identifier::print(raw_ostream &os) const {
- os << str();
-}
-
-void Identifier::dump() const {
- print(llvm::errs());
-}
+void Identifier::dump() const { print(llvm::errs()); }
//===----------------------------------------------------------------------===//
// Module printing
//===----------------------------------------------------------------------===//
namespace {
-class ModuleState {
- public:
+class ModuleState {
+public:
ModuleState(raw_ostream &os);
void initialize(const Module *module);
void print(const Module *module);
void print(const Type *type) const;
- void print(const Function* fn);
+ void print(const Function *fn);
+ void print(const ExtFunction *fn);
+ void print(const CFGFunction *fn);
+ void print(const MLFunction *fn);
- void recordAffineMapReference(const AffineMap* affineMap) {
+ void recordAffineMapReference(const AffineMap *affineMap) {
if (affineMapIds.count(affineMap) == 0) {
affineMapIds[affineMap] = nextAffineMapId++;
}
}
- int getAffineMapId(const AffineMap* affineMap) const {
+ int getAffineMapId(const AffineMap *affineMap) const {
auto it = affineMapIds.find(affineMap);
if (it == affineMapIds.end()) {
return -1;
@@ -72,7 +70,7 @@
return it->second;
}
- private:
+private:
// Visit functions.
void visitFunction(const Function *fn);
void visitExtFunction(const ExtFunction *fn);
@@ -81,13 +79,12 @@
void visitType(const Type *type);
raw_ostream &os;
- DenseMap<const AffineMap*, int> affineMapIds;
+ DenseMap<const AffineMap *, int> affineMapIds;
int nextAffineMapId = 0;
};
-} // end anonymous namespace
+} // end anonymous namespace
-ModuleState::ModuleState(raw_ostream &os) : os(os) {
-}
+ModuleState::ModuleState(raw_ostream &os) : os(os) {}
// Initializes module state, populating affine map state.
void ModuleState::initialize(const Module *module) {
@@ -101,16 +98,16 @@
if (type->getKind() == Type::Kind::Function) {
// Visit input and result types for functions.
auto *funcType = cast<FunctionType>(type);
- for (auto* input : funcType->getInputs()) {
+ for (auto *input : funcType->getInputs()) {
visitType(input);
}
- for (auto* result : funcType->getResults()) {
+ for (auto *result : funcType->getResults()) {
visitType(result);
}
} else if (type->getKind() == Type::Kind::MemRef) {
// Visit affine maps in memref type.
auto *memref = cast<MemRefType>(type);
- for (AffineMap* map : memref->getAffineMaps()) {
+ for (AffineMap *map : memref->getAffineMaps()) {
recordAffineMapReference(map);
}
}
@@ -132,34 +129,24 @@
void ModuleState::visitFunction(const Function *fn) {
switch (fn->getKind()) {
- case Function::Kind::ExtFunc:
- return visitExtFunction(cast<ExtFunction>(fn));
- case Function::Kind::CFGFunc:
- return visitCFGFunction(cast<CFGFunction>(fn));
- case Function::Kind::MLFunc:
- return visitMLFunction(cast<MLFunction>(fn));
+ case Function::Kind::ExtFunc:
+ return visitExtFunction(cast<ExtFunction>(fn));
+ case Function::Kind::CFGFunc:
+ return visitCFGFunction(cast<CFGFunction>(fn));
+ case Function::Kind::MLFunc:
+ return visitMLFunction(cast<MLFunction>(fn));
}
}
-static void printExtFunction(const ExtFunction* fn,
- const ModuleState* moduleState, raw_ostream &os);
-
-
-static void printCFGFunction(const CFGFunction* fn,
- const ModuleState* moduleState, raw_ostream &os);
-
-static void printMLFunction(const MLFunction* fn,
- const ModuleState* moduleState, raw_ostream &os);
-
// Prints function with initialized module state.
-void ModuleState::print(const Function* fn) {
+void ModuleState::print(const Function *fn) {
switch (fn->getKind()) {
- case Function::Kind::ExtFunc:
- return printExtFunction(cast<ExtFunction>(fn), this, os);
- case Function::Kind::CFGFunc:
- return printCFGFunction(cast<CFGFunction>(fn), this, os);
- case Function::Kind::MLFunc:
- return printMLFunction(cast<MLFunction>(fn), this, os);
+ case Function::Kind::ExtFunc:
+ return print(cast<ExtFunction>(fn));
+ case Function::Kind::CFGFunc:
+ return print(cast<CFGFunction>(fn));
+ case Function::Kind::MLFunc:
+ return print(cast<MLFunction>(fn));
}
}
@@ -169,23 +156,32 @@
}
void ModuleState::print(const Module *module) {
- for (const auto& mapAndId : affineMapIds) {
+ for (const auto &mapAndId : affineMapIds) {
printAffineMapId(mapAndId.second, os);
os << " = ";
mapAndId.first->print(os);
os << '\n';
}
- for (auto *fn : module->functionList)
- print(fn);
+ for (auto *fn : module->functionList) print(fn);
}
void ModuleState::print(const Type *type) const {
switch (type->getKind()) {
- case Type::Kind::AffineInt: os << "affineint"; return;
- case Type::Kind::BF16: os << "bf16"; return;
- case Type::Kind::F16: os << "f16"; return;
- case Type::Kind::F32: os << "f32"; return;
- case Type::Kind::F64: os << "f64"; return;
+ case Type::Kind::AffineInt:
+ os << "affineint";
+ return;
+ case Type::Kind::BF16:
+ os << "bf16";
+ return;
+ case Type::Kind::F16:
+ os << "f16";
+ return;
+ case Type::Kind::F32:
+ os << "f32";
+ return;
+ case Type::Kind::F64:
+ os << "f64";
+ return;
case Type::Kind::Integer: {
auto *integer = cast<IntegerType>(type);
@@ -195,8 +191,7 @@
case Type::Kind::Function: {
auto *func = cast<FunctionType>(type);
os << '(';
- interleave(func->getInputs(),
- [&](Type *type) { os << *type; },
+ interleave(func->getInputs(), [&](Type *type) { os << *type; },
[&]() { os << ", "; });
os << ") -> ";
auto results = func->getResults();
@@ -204,8 +199,7 @@
os << *results[0];
else {
os << '(';
- interleave(results,
- [&](Type *type) { os << *type; },
+ interleave(results, [&](Type *type) { os << *type; },
[&]() { os << ", "; });
os << ')';
}
@@ -214,8 +208,7 @@
case Type::Kind::Vector: {
auto *v = cast<VectorType>(type);
os << "vector<";
- for (auto dim : v->getShape())
- os << dim << 'x';
+ for (auto dim : v->getShape()) os << dim << 'x';
os << *v->getElementType() << '>';
return;
}
@@ -282,7 +275,8 @@
os << ')';
switch (type->getResults().size()) {
- case 0: break;
+ case 0:
+ break;
case 1:
os << " -> ";
moduleState->print(type->getResults()[0]);
@@ -297,11 +291,10 @@
}
}
-void ExtFunction::print(raw_ostream &os) const {
- ModuleState moduleState(os);
+void ModuleState::print(const ExtFunction *fn) {
os << "extfunc ";
- printFunctionSignature(this, &moduleState, os);
- os << "\n";
+ printFunctionSignature(fn, this, os);
+ os << '\n';
}
namespace {
@@ -320,12 +313,12 @@
const ModuleState *moduleState;
const OperationSet &operationSet;
};
-} // end anonymous namespace
+} // end anonymous namespace
FunctionState::FunctionState(MLIRContext *context,
- const ModuleState *moduleState,
- raw_ostream &os)
- : os(os), moduleState(moduleState),
+ const ModuleState *moduleState, raw_ostream &os)
+ : os(os),
+ moduleState(moduleState),
operationSet(OperationSet::get(context)) {}
void FunctionState::printOperation(const Operation *op) {
@@ -379,9 +372,9 @@
private:
const CFGFunction *function;
- DenseMap<const BasicBlock*, unsigned> basicBlockIDs;
+ DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
};
-} // end anonymous namespace
+} // end anonymous namespace
CFGFunctionState::CFGFunctionState(const CFGFunction *function,
const ModuleState *moduleState,
@@ -390,8 +383,7 @@
function(function) {
// Each basic block gets a unique ID per function.
unsigned blockID = 0;
- for (auto &block : *function)
- basicBlockIDs[&block] = blockID++;
+ for (auto &block : *function) basicBlockIDs[&block] = blockID++;
}
void CFGFunctionState::print() {
@@ -399,8 +391,7 @@
printFunctionSignature(this->getFunction(), moduleState, os);
os << " {\n";
- for (auto &block : *function)
- print(&block);
+ for (auto &block : *function) print(&block);
os << "}\n\n";
}
@@ -435,8 +426,11 @@
void CFGFunctionState::print(const BranchInst *inst) {
os << " br bb" << getBBID(inst->getDest());
}
-void CFGFunctionState::print(const ReturnInst *inst) {
- os << " return";
+void CFGFunctionState::print(const ReturnInst *inst) { os << " return"; }
+
+void ModuleState::print(const CFGFunction *fn) {
+ CFGFunctionState state(fn, this, os);
+ state.print();
}
//===----------------------------------------------------------------------===//
@@ -468,13 +462,14 @@
const MLFunction *function;
int numSpaces;
};
-} // end anonymous namespace
+} // end anonymous namespace
MLFunctionState::MLFunctionState(const MLFunction *function,
const ModuleState *moduleState,
raw_ostream &os)
: FunctionState(function->getContext(), moduleState, os),
- function(function), numSpaces(0) {}
+ function(function),
+ numSpaces(0) {}
void MLFunctionState::print() {
os << "mlfunc ";
@@ -506,9 +501,7 @@
}
}
-void MLFunctionState::print(const OperationStmt *stmt) {
- printOperation(stmt);
-}
+void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
void MLFunctionState::print(const ForStmt *stmt) {
os.indent(numSpaces) << "for {\n";
@@ -527,22 +520,8 @@
}
}
-void printExtFunction(const ExtFunction* fn, const ModuleState* moduleState,
- raw_ostream &os) {
- os << "extfunc ";
- printFunctionSignature(fn, moduleState, os);
- os << '\n';
-}
-
-void printCFGFunction(const CFGFunction* fn, const ModuleState* moduleState,
- raw_ostream &os) {
- CFGFunctionState state(fn, moduleState, os);
- state.print();
-}
-
-void printMLFunction(const MLFunction* fn, const ModuleState* moduleState,
- raw_ostream &os) {
- MLFunctionState state(fn, moduleState, os);
+void ModuleState::print(const MLFunction *fn) {
+ MLFunctionState state(fn, this, os);
state.print();
}
@@ -555,9 +534,7 @@
moduleState.print(this);
}
-void Type::dump() const {
- print(llvm::errs());
-}
+void Type::dump() const { print(llvm::errs()); }
void Instruction::print(raw_ostream &os) const {
ModuleState moduleState(os);
@@ -638,19 +615,15 @@
void AffineMap::print(raw_ostream &os) const {
// Dimension identifiers.
os << "(";
- for (int i = 0; i < (int)getNumDims() - 1; i++)
- os << "d" << i << ", ";
- if (getNumDims() >= 1)
- os << "d" << getNumDims() - 1;
+ for (int i = 0; i < (int)getNumDims() - 1; i++) os << "d" << i << ", ";
+ if (getNumDims() >= 1) os << "d" << getNumDims() - 1;
os << ")";
// Symbolic identifiers.
if (getNumSymbols() >= 1) {
os << " [";
- for (int i = 0; i < (int)getNumSymbols() - 1; i++)
- os << "s" << i << ", ";
- if (getNumSymbols() >= 1)
- os << "s" << getNumSymbols() - 1;
+ for (int i = 0; i < (int)getNumSymbols() - 1; i++) os << "s" << i << ", ";
+ if (getNumSymbols() >= 1) os << "s" << getNumSymbols() - 1;
os << "]";
}
@@ -679,9 +652,7 @@
state.print();
}
-void BasicBlock::dump() const {
- print(llvm::errs());
-}
+void BasicBlock::dump() const { print(llvm::errs()); }
void Statement::print(raw_ostream &os) const {
ModuleState moduleState(os);
@@ -689,20 +660,26 @@
state.print(this);
}
-void Statement::dump() const {
- print(llvm::errs());
-}
+void Statement::dump() const { print(llvm::errs()); }
void Function::print(raw_ostream &os) const {
switch (getKind()) {
- case Kind::ExtFunc: return cast<ExtFunction>(this)->print(os);
- case Kind::CFGFunc: return cast<CFGFunction>(this)->print(os);
- case Kind::MLFunc: return cast<MLFunction>(this)->print(os);
+ case Kind::ExtFunc:
+ return cast<ExtFunction>(this)->print(os);
+ case Kind::CFGFunc:
+ return cast<CFGFunction>(this)->print(os);
+ case Kind::MLFunc:
+ return cast<MLFunction>(this)->print(os);
}
}
-void Function::dump() const {
- print(llvm::errs());
+void Function::dump() const { print(llvm::errs()); }
+
+void ExtFunction::print(raw_ostream &os) const {
+ ModuleState moduleState(os);
+ os << "extfunc ";
+ printFunctionSignature(this, &moduleState, os);
+ os << "\n";
}
void CFGFunction::print(raw_ostream &os) const {
@@ -723,6 +700,4 @@
moduleState.print(this);
}
-void Module::dump() const {
- print(llvm::errs());
-}
+void Module::dump() const { print(llvm::errs()); }