Refactor the AsmPrinter to follow the pattern established in the parser:
there is now an explicit state class (ModuleState), of which there is only one
instance per top-level FooThing::print call. The function printers
(CFGFunctionPrinter and MLFunctionPrinter, via FunctionState) now subclass
ModulePrinter, so they can call print directly on types, attributes, and other
module-level entities. This also makes the contract explicit: the global
FooThing::print methods are the public entry points, and the printer
implementation is otherwise self-contained.
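No functionality change, except that BasicBlock::print now prints only the
block itself rather than its entire enclosing CFG function.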
PiperOrigin-RevId: 205409317
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 31fd05c..b56c775 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -38,37 +38,23 @@
void Identifier::dump() const { print(llvm::errs()); }
-template <typename Container, typename UnaryFunctor>
-inline void interleaveComma(raw_ostream &os, const Container &c,
- UnaryFunctor each_fn) {
- interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
-}
-
//===----------------------------------------------------------------------===//
-// Module printing
+// ModuleState
//===----------------------------------------------------------------------===//
namespace {
class ModuleState {
public:
- ModuleState(raw_ostream &os);
+ /// This is the operation set for the current context if it is knowable (a
+ /// context could be determined), otherwise this is null.
+ OperationSet *const operationSet;
+ explicit ModuleState(MLIRContext *context)
+ : operationSet(context ? &OperationSet::get(context) : nullptr) {}
+
+ // Initializes module state, populating affine map state.
void initialize(const Module *module);
- void print(const Module *module);
- void print(const Attribute *attr) const;
- void print(const Type *type) const;
- void print(const Function *fn);
- void print(const ExtFunction *fn);
- void print(const CFGFunction *fn);
- void print(const MLFunction *fn);
-
- void recordAffineMapReference(const AffineMap *affineMap) {
- if (affineMapIds.count(affineMap) == 0) {
- affineMapIds[affineMap] = nextAffineMapId++;
- }
- }
-
int getAffineMapId(const AffineMap *affineMap) const {
auto it = affineMapIds.find(affineMap);
if (it == affineMapIds.end()) {
@@ -77,7 +63,17 @@
return it->second;
}
+ const DenseMap<const AffineMap *, int> &getAffineMapIds() const {
+ return affineMapIds;
+ }
+
private:
+ void recordAffineMapReference(const AffineMap *affineMap) {
+ if (affineMapIds.count(affineMap) == 0) {
+ affineMapIds[affineMap] = nextAffineMapId++;
+ }
+ }
+
// Visit functions.
void visitFunction(const Function *fn);
void visitExtFunction(const ExtFunction *fn);
@@ -87,23 +83,11 @@
void visitAttribute(const Attribute *attr);
void visitOperation(const Operation *op);
- void printAffineMapId(int affineMapId) const;
- void printAffineMapReference(const AffineMap* affineMap) const;
-
- raw_ostream &os;
DenseMap<const AffineMap *, int> affineMapIds;
int nextAffineMapId = 0;
};
} // end anonymous namespace
-ModuleState::ModuleState(raw_ostream &os) : os(os) {}
-
-// Initializes module state, populating affine map state.
-void ModuleState::initialize(const Module *module) {
- for (auto fn : module->functionList) {
- visitFunction(fn);
- }
-}
// TODO Support visiting other types/instructions when implemented.
void ModuleState::visitType(const Type *type) {
@@ -171,8 +155,54 @@
}
}
+// Initializes module state, populating affine map state.
+void ModuleState::initialize(const Module *module) {
+ for (auto fn : module->functionList) {
+ visitFunction(fn);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// ModulePrinter
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ModulePrinter {
+public:
+ ModulePrinter(raw_ostream &os, ModuleState &state) : os(os), state(state) {}
+ explicit ModulePrinter(const ModulePrinter &printer)
+ : os(printer.os), state(printer.state) {}
+
+ template <typename Container, typename UnaryFunctor>
+ inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
+ interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
+ }
+
+ void print(const Module *module);
+ void print(const Attribute *attr) const;
+ void print(const Type *type) const;
+ void print(const Function *fn);
+ void print(const ExtFunction *fn);
+ void print(const CFGFunction *fn);
+ void print(const MLFunction *fn);
+
+ void print(const AffineMap *map);
+ void print(const AffineExpr *expr) const;
+
+protected:
+ raw_ostream &os;
+ ModuleState &state;
+
+ void printFunctionSignature(const Function *fn);
+ void printAffineMapId(int affineMapId) const;
+ void printAffineMapReference(const AffineMap *affineMap) const;
+
+ void print(const AffineBinaryOpExpr *expr) const;
+};
+} // end anonymous namespace
+
// Prints function with initialized module state.
-void ModuleState::print(const Function *fn) {
+void ModulePrinter::print(const Function *fn) {
switch (fn->getKind()) {
case Function::Kind::ExtFunc:
return print(cast<ExtFunction>(fn));
@@ -184,12 +214,12 @@
}
// Prints affine map identifier.
-void ModuleState::printAffineMapId(int affineMapId) const {
+void ModulePrinter::printAffineMapId(int affineMapId) const {
os << "#map" << affineMapId;
}
-void ModuleState::printAffineMapReference(const AffineMap* affineMap) const {
- const int mapId = getAffineMapId(affineMap);
+void ModulePrinter::printAffineMapReference(const AffineMap *affineMap) const {
+ int mapId = state.getAffineMapId(affineMap);
if (mapId >= 0) {
// Map will be printed at top of module so print reference to its id.
printAffineMapId(mapId);
@@ -199,8 +229,8 @@
}
}
-void ModuleState::print(const Module *module) {
- for (const auto &mapAndId : affineMapIds) {
+void ModulePrinter::print(const Module *module) {
+ for (const auto &mapAndId : state.getAffineMapIds()) {
printAffineMapId(mapAndId.second);
os << " = ";
mapAndId.first->print(os);
@@ -209,7 +239,7 @@
for (auto *fn : module->functionList) print(fn);
}
-void ModuleState::print(const Attribute *attr) const {
+void ModulePrinter::print(const Attribute *attr) const {
switch (attr->getKind()) {
case Attribute::Kind::Bool:
os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
@@ -228,7 +258,7 @@
case Attribute::Kind::Array: {
auto elts = cast<ArrayAttr>(attr)->getValue();
os << '[';
- interleaveComma(os, elts, [&](Attribute *attr) { print(attr); });
+ interleaveComma(elts, [&](Attribute *attr) { print(attr); });
os << ']';
break;
}
@@ -238,7 +268,7 @@
}
}
-void ModuleState::print(const Type *type) const {
+void ModulePrinter::print(const Type *type) const {
switch (type->getKind()) {
case Type::Kind::AffineInt:
os << "affineint";
@@ -264,14 +294,14 @@
case Type::Kind::Function: {
auto *func = cast<FunctionType>(type);
os << '(';
- interleaveComma(os, func->getInputs(), [&](Type *type) { os << *type; });
+ interleaveComma(func->getInputs(), [&](Type *type) { print(type); });
os << ") -> ";
auto results = func->getResults();
if (results.size() == 1)
- os << *results[0];
+ print(results[0]);
else {
os << '(';
- interleaveComma(os, results, [&](Type *type) { os << *type; });
+ interleaveComma(results, [&](Type *type) { print(type); });
os << ')';
}
return;
@@ -324,17 +354,132 @@
}
//===----------------------------------------------------------------------===//
+// Affine expressions and maps
+//===----------------------------------------------------------------------===//
+
+void ModulePrinter::print(const AffineExpr *expr) const {
+ switch (expr->getKind()) {
+ case AffineExpr::Kind::SymbolId:
+ os << 's' << cast<AffineSymbolExpr>(expr)->getPosition();
+ return;
+ case AffineExpr::Kind::DimId:
+ os << 'd' << cast<AffineDimExpr>(expr)->getPosition();
+ return;
+ case AffineExpr::Kind::Constant:
+ os << cast<AffineConstantExpr>(expr)->getValue();
+ return;
+ case AffineExpr::Kind::Add:
+ case AffineExpr::Kind::Mul:
+ case AffineExpr::Kind::FloorDiv:
+ case AffineExpr::Kind::CeilDiv:
+ case AffineExpr::Kind::Mod:
+ return print(cast<AffineBinaryOpExpr>(expr));
+ }
+}
+
+void ModulePrinter::print(const AffineBinaryOpExpr *expr) const {
+ if (expr->getKind() != AffineExpr::Kind::Add) {
+ os << '(';
+ print(expr->getLHS());
+ switch (expr->getKind()) {
+ case AffineExpr::Kind::Mul:
+ os << " * ";
+ break;
+ case AffineExpr::Kind::FloorDiv:
+ os << " floordiv ";
+ break;
+ case AffineExpr::Kind::CeilDiv:
+ os << " ceildiv ";
+ break;
+ case AffineExpr::Kind::Mod:
+ os << " mod ";
+ break;
+ default:
+ llvm_unreachable("unexpected affine binary op expression");
+ }
+
+ print(expr->getRHS());
+ os << ')';
+ return;
+ }
+
+ // Print out special "pretty" forms for add.
+ os << '(';
+ print(expr->getLHS());
+
+ // Pretty print addition to a product that has a negative operand as a
+ // subtraction.
+ if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(expr->getRHS())) {
+ if (rhs->getKind() == AffineExpr::Kind::Mul) {
+ if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) {
+ if (rrhs->getValue() < 0) {
+ os << " - (";
+ print(rhs->getLHS());
+ os << " * " << -rrhs->getValue() << "))";
+ return;
+ }
+ }
+ }
+ }
+
+ // Pretty print addition to a negative number as a subtraction.
+ if (auto *rhs = dyn_cast<AffineConstantExpr>(expr->getRHS())) {
+ if (rhs->getValue() < 0) {
+ os << " - " << -rhs->getValue() << ")";
+ return;
+ }
+ }
+
+ os << " + ";
+ print(expr->getRHS());
+ os << ')';
+}
+
+void ModulePrinter::print(const AffineMap *map) {
+ // Dimension identifiers.
+ os << '(';
+ for (int i = 0; i < (int)map->getNumDims() - 1; i++)
+ os << "d" << i << ", ";
+ if (map->getNumDims() >= 1)
+ os << "d" << map->getNumDims() - 1;
+ os << ")";
+
+ // Symbolic identifiers.
+ if (map->getNumSymbols() >= 1) {
+ os << " [";
+ for (int i = 0; i < (int)map->getNumSymbols() - 1; i++)
+ os << "s" << i << ", ";
+ if (map->getNumSymbols() >= 1)
+ os << "s" << map->getNumSymbols() - 1;
+ os << "]";
+ }
+
+ // AffineMap should have at least one result.
+ assert(!map->getResults().empty());
+ // Result affine expressions.
+ os << " -> (";
+ interleaveComma(map->getResults(), [&](AffineExpr *expr) { print(expr); });
+ os << ")";
+
+ if (!map->isBounded()) {
+ return;
+ }
+
+ // Print range sizes for bounded affine maps.
+ os << " size (";
+ interleaveComma(map->getRangeSizes(), [&](AffineExpr *expr) { print(expr); });
+ os << ")";
+}
+
+//===----------------------------------------------------------------------===//
// Function printing
//===----------------------------------------------------------------------===//
-static void printFunctionSignature(const Function *fn,
- const ModuleState *moduleState,
- raw_ostream &os) {
+void ModulePrinter::printFunctionSignature(const Function *fn) {
auto type = fn->getType();
os << "@" << fn->getName() << '(';
- interleaveComma(os, type->getInputs(),
- [&](Type *eltType) { moduleState->print(eltType); });
+ interleaveComma(type->getInputs(), [&](Type *eltType) { print(eltType); });
os << ')';
switch (type->getResults().size()) {
@@ -342,20 +487,19 @@
break;
case 1:
os << " -> ";
- moduleState->print(type->getResults()[0]);
+ print(type->getResults()[0]);
break;
default:
os << " -> (";
- interleaveComma(os, type->getResults(),
- [&](Type *eltType) { moduleState->print(eltType); });
+ interleaveComma(type->getResults(), [&](Type *eltType) { print(eltType); });
os << ')';
break;
}
}
-void ModuleState::print(const ExtFunction *fn) {
+void ModulePrinter::print(const ExtFunction *fn) {
os << "extfunc ";
- printFunctionSignature(fn, this, os);
+ printFunctionSignature(fn);
os << '\n';
}
@@ -363,18 +507,13 @@
// FunctionState contains common functionality for printing
// CFG and ML functions.
-class FunctionState {
+class FunctionState : public ModulePrinter {
public:
- FunctionState(MLIRContext *context, const ModuleState *moduleState,
- raw_ostream &os);
+ FunctionState(const ModulePrinter &other) : ModulePrinter(other) {}
void printOperation(const Operation *op);
protected:
- raw_ostream &os;
- const ModuleState *moduleState;
- const OperationSet &operationSet;
-
void numberValueID(const SSAValue *value) {
assert(!valueIDs.count(value) && "Value numbered multiple times");
valueIDs[value] = nextValueID++;
@@ -397,12 +536,6 @@
};
} // end anonymous namespace
-FunctionState::FunctionState(MLIRContext *context,
- const ModuleState *moduleState, raw_ostream &os)
- : os(os),
- moduleState(moduleState),
- operationSet(OperationSet::get(context)) {}
-
void FunctionState::printOperation(const Operation *op) {
os << " ";
@@ -417,7 +550,7 @@
// Check to see if this is a known operation. If so, use the registered
// custom printer hook.
- if (auto opInfo = operationSet.lookup(op->getName().str())) {
+ if (auto opInfo = state.operationSet ? state.operationSet->lookup(op->getName().str()) : nullptr) {
opInfo->printAssembly(op, os);
return;
}
@@ -431,18 +564,18 @@
// Operation this check can go away.
if (auto *inst = dyn_cast<OperationInst>(op)) {
// TODO: Use getOperands() when we have it.
- interleaveComma(
- os, inst->getInstOperands(),
- [&](const InstOperand &operand) { printValueID(operand.get()); });
+ interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
+ printValueID(operand.get());
+ });
}
os << ')';
auto attrs = op->getAttrs();
if (!attrs.empty()) {
os << '{';
- interleaveComma(os, attrs, [&](NamedAttribute attr) {
+ interleaveComma(attrs, [&](NamedAttribute attr) {
os << attr.first << ": ";
- moduleState->print(attr.second);
+ print(attr.second);
});
os << '}';
}
@@ -453,20 +586,18 @@
// Print the type signature of the operation.
os << " : (";
// TODO: Switch to getOperands() when we have it.
- interleaveComma(os, inst->getInstOperands(), [&](const InstOperand &op) {
- moduleState->print(op.get()->getType());
- });
+ interleaveComma(inst->getInstOperands(),
+ [&](const InstOperand &op) { print(op.get()->getType()); });
os << ") -> ";
// TODO: Switch to getResults() when we have it.
if (inst->getNumResults() == 1) {
- moduleState->print(inst->getInstResult(0).getType());
+ print(inst->getInstResult(0).getType());
} else {
os << '(';
- interleaveComma(os, inst->getInstResults(),
- [&](const InstResult &result) {
- moduleState->print(result.getType());
- });
+ interleaveComma(inst->getInstResults(), [&](const InstResult &result) {
+ print(result.getType());
+ });
os << ')';
}
}
@@ -477,10 +608,9 @@
//===----------------------------------------------------------------------===//
namespace {
-class CFGFunctionState : public FunctionState {
+class CFGFunctionPrinter : public FunctionState {
public:
- CFGFunctionState(const CFGFunction *function, const ModuleState *moduleState,
- raw_ostream &os);
+ CFGFunctionPrinter(const CFGFunction *function, const ModulePrinter &other);
const CFGFunction *getFunction() const { return function; }
@@ -502,25 +632,23 @@
const CFGFunction *function;
DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
- void numberBlock(const BasicBlock *block);
+ void numberValuesInBlock(const BasicBlock *block);
};
} // end anonymous namespace
-CFGFunctionState::CFGFunctionState(const CFGFunction *function,
- const ModuleState *moduleState,
- raw_ostream &os)
- : FunctionState(function->getContext(), moduleState, os),
- function(function) {
+CFGFunctionPrinter::CFGFunctionPrinter(const CFGFunction *function,
+ const ModulePrinter &other)
+ : FunctionState(other), function(function) {
// Each basic block gets a unique ID per function.
unsigned blockID = 0;
for (auto &block : *function) {
basicBlockIDs[&block] = blockID++;
- numberBlock(&block);
+ numberValuesInBlock(&block);
}
}
/// Number all of the SSA values in the specified basic block.
-void CFGFunctionState::numberBlock(const BasicBlock *block) {
+void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
// TODO: basic block arguments.
for (auto &op : *block) {
// We number instruction that have results, and we only number the first
@@ -532,16 +660,16 @@
// Terminators do not define values.
}
-void CFGFunctionState::print() {
+void CFGFunctionPrinter::print() {
os << "cfgfunc ";
- printFunctionSignature(this->getFunction(), moduleState, os);
+ printFunctionSignature(getFunction());
os << " {\n";
for (auto &block : *function) print(&block);
os << "}\n\n";
}
-void CFGFunctionState::print(const BasicBlock *block) {
+void CFGFunctionPrinter::print(const BasicBlock *block) {
os << "bb" << getBBID(block) << ":\n";
// TODO Print arguments.
@@ -554,7 +682,7 @@
os << "\n";
}
-void CFGFunctionState::print(const Instruction *inst) {
+void CFGFunctionPrinter::print(const Instruction *inst) {
switch (inst->getKind()) {
case Instruction::Kind::Operation:
return print(cast<OperationInst>(inst));
@@ -565,17 +693,16 @@
}
}
-void CFGFunctionState::print(const OperationInst *inst) {
+void CFGFunctionPrinter::print(const OperationInst *inst) {
printOperation(inst);
}
-void CFGFunctionState::print(const BranchInst *inst) {
+void CFGFunctionPrinter::print(const BranchInst *inst) {
os << " br bb" << getBBID(inst->getDest());
}
-void CFGFunctionState::print(const ReturnInst *inst) { os << " return"; }
+void CFGFunctionPrinter::print(const ReturnInst *inst) { os << " return"; }
-void ModuleState::print(const CFGFunction *fn) {
- CFGFunctionState state(fn, this, os);
- state.print();
+void ModulePrinter::print(const CFGFunction *fn) {
+ CFGFunctionPrinter(fn, *this).print();
}
//===----------------------------------------------------------------------===//
@@ -583,10 +710,9 @@
//===----------------------------------------------------------------------===//
namespace {
-class MLFunctionState : public FunctionState {
+class MLFunctionPrinter : public FunctionState {
public:
- MLFunctionState(const MLFunction *function, const ModuleState *moduleState,
- raw_ostream &os);
+ MLFunctionPrinter(const MLFunction *function, const ModulePrinter &other);
const MLFunction *getFunction() const { return function; }
@@ -609,24 +735,21 @@
};
} // end anonymous namespace
-MLFunctionState::MLFunctionState(const MLFunction *function,
- const ModuleState *moduleState,
- raw_ostream &os)
- : FunctionState(function->getContext(), moduleState, os),
- function(function),
- numSpaces(0) {}
+MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
+ const ModulePrinter &other)
+ : FunctionState(other), function(function), numSpaces(0) {}
-void MLFunctionState::print() {
+void MLFunctionPrinter::print() {
os << "mlfunc ";
// FIXME: should print argument names rather than just signature
- printFunctionSignature(function, moduleState, os);
+ printFunctionSignature(function);
os << " {\n";
print(function);
os << " return\n";
os << "}\n\n";
}
-void MLFunctionState::print(const StmtBlock *block) {
+void MLFunctionPrinter::print(const StmtBlock *block) {
numSpaces += indentWidth;
for (auto &stmt : block->getStatements()) {
print(&stmt);
@@ -635,7 +758,7 @@
numSpaces -= indentWidth;
}
-void MLFunctionState::print(const Statement *stmt) {
+void MLFunctionPrinter::print(const Statement *stmt) {
switch (stmt->getKind()) {
case Statement::Kind::Operation:
return print(cast<OperationStmt>(stmt));
@@ -646,9 +769,11 @@
}
}
-void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
+void MLFunctionPrinter::print(const OperationStmt *stmt) {
+ printOperation(stmt);
+}
-void MLFunctionState::print(const ForStmt *stmt) {
+void MLFunctionPrinter::print(const ForStmt *stmt) {
os.indent(numSpaces) << "for x = " << *stmt->getLowerBound();
os << " to " << *stmt->getUpperBound();
if (stmt->getStep()->getValue() != 1)
@@ -659,7 +784,7 @@
os.indent(numSpaces) << "}";
}
-void MLFunctionState::print(const IfStmt *stmt) {
+void MLFunctionPrinter::print(const IfStmt *stmt) {
os.indent(numSpaces) << "if () {\n";
print(stmt->getThenClause());
os.indent(numSpaces) << "}";
@@ -670,9 +795,8 @@
}
}
-void ModuleState::print(const MLFunction *fn) {
- MLFunctionState state(fn, this, os);
- state.print();
+void ModulePrinter::print(const MLFunction *fn) {
+ MLFunctionPrinter(fn, *this).print();
}
//===----------------------------------------------------------------------===//
@@ -680,8 +804,8 @@
//===----------------------------------------------------------------------===//
void Attribute::print(raw_ostream &os) const {
- ModuleState moduleState(os);
- moduleState.print(this);
+ ModuleState state(/*no context is known*/ nullptr);
+ ModulePrinter(os, state).print(this);
}
void Attribute::dump() const {
@@ -689,23 +813,12 @@
}
void Type::print(raw_ostream &os) const {
- ModuleState moduleState(os);
- moduleState.print(this);
+ ModuleState state(getContext());
+ ModulePrinter(os, state).print(this);
}
void Type::dump() const { print(llvm::errs()); }
-void Instruction::print(raw_ostream &os) const {
- ModuleState moduleState(os);
- CFGFunctionState state(getFunction(), &moduleState, os);
- state.print(this);
-}
-
-void Instruction::dump() const {
- print(llvm::errs());
- llvm::errs() << "\n";
-}
-
void AffineMap::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
@@ -716,163 +829,54 @@
llvm::errs() << "\n";
}
-void AffineSymbolExpr::print(raw_ostream &os) const {
- os << 's' << getPosition();
-}
-
-void AffineDimExpr::print(raw_ostream &os) const { os << 'd' << getPosition(); }
-
-void AffineConstantExpr::print(raw_ostream &os) const { os << getValue(); }
-
-static void printAdd(const AffineBinaryOpExpr *addExpr, raw_ostream &os) {
- os << '(' << *addExpr->getLHS();
-
- // Pretty print addition to a product that has a negative operand as a
- // subtraction.
- if (auto *rhs = dyn_cast<AffineBinaryOpExpr>(addExpr->getRHS())) {
- if (rhs->getKind() == AffineExpr::Kind::Mul) {
- if (auto *rrhs = dyn_cast<AffineConstantExpr>(rhs->getRHS())) {
- if (rrhs->getValue() < 0) {
- os << " - (" << *rhs->getLHS() << " * " << -rrhs->getValue() << "))";
- return;
- }
- }
- }
- }
-
- // Pretty print addition to a negative number as a subtraction.
- if (auto *rhs = dyn_cast<AffineConstantExpr>(addExpr->getRHS())) {
- if (rhs->getValue() < 0) {
- os << " - " << -rhs->getValue() << ")";
- return;
- }
- }
-
- os << " + " << *addExpr->getRHS() << ")";
-}
-
-void AffineBinaryOpExpr::print(raw_ostream &os) const {
- switch (getKind()) {
- case Kind::Add:
- return printAdd(this, os);
- case Kind::Mul:
- os << "(" << *getLHS() << " * " << *getRHS() << ")";
- return;
- case Kind::FloorDiv:
- os << "(" << *getLHS() << " floordiv " << *getRHS() << ")";
- return;
- case Kind::CeilDiv:
- os << "(" << *getLHS() << " ceildiv " << *getRHS() << ")";
- return;
- case Kind::Mod:
- os << "(" << *getLHS() << " mod " << *getRHS() << ")";
- return;
- default:
- llvm_unreachable("unexpected affine binary op expression");
- }
-}
-
void AffineExpr::print(raw_ostream &os) const {
- switch (getKind()) {
- case Kind::SymbolId:
- return cast<AffineSymbolExpr>(this)->print(os);
- case Kind::DimId:
- return cast<AffineDimExpr>(this)->print(os);
- case Kind::Constant:
- return cast<AffineConstantExpr>(this)->print(os);
- case Kind::Add:
- case Kind::Mul:
- case Kind::FloorDiv:
- case Kind::CeilDiv:
- case Kind::Mod:
- return cast<AffineBinaryOpExpr>(this)->print(os);
- }
+ ModuleState state(/*no context is known*/ nullptr);
+ ModulePrinter(os, state).print(this);
}
void AffineMap::print(raw_ostream &os) const {
- // Dimension identifiers.
- os << "(";
- for (int i = 0; i < (int)getNumDims() - 1; i++) os << "d" << i << ", ";
- if (getNumDims() >= 1) os << "d" << getNumDims() - 1;
- os << ")";
+ ModuleState state(/*no context is known*/ nullptr);
+ ModulePrinter(os, state).print(this);
+}
- // Symbolic identifiers.
- if (getNumSymbols() >= 1) {
- os << " [";
- for (int i = 0; i < (int)getNumSymbols() - 1; i++) os << "s" << i << ", ";
- if (getNumSymbols() >= 1) os << "s" << getNumSymbols() - 1;
- os << "]";
- }
+void Instruction::print(raw_ostream &os) const {
+ ModuleState state(getFunction()->getContext());
+ ModulePrinter modulePrinter(os, state);
+ CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
+}
- // AffineMap should have at least one result.
- assert(!getResults().empty());
- // Result affine expressions.
- os << " -> (";
- interleaveComma(os, getResults(), [&](AffineExpr *expr) { os << *expr; });
- os << ")";
-
- if (!isBounded()) {
- return;
- }
-
- // Print range sizes for bounded affine maps.
- os << " size (";
- interleaveComma(os, getRangeSizes(), [&](AffineExpr *expr) { os << *expr; });
- os << ")";
+void Instruction::dump() const {
+ print(llvm::errs());
+ llvm::errs() << "\n";
}
void BasicBlock::print(raw_ostream &os) const {
- ModuleState moduleState(os);
- CFGFunctionState state(getFunction(), &moduleState, os);
- state.print();
+ ModuleState state(getFunction()->getContext());
+ ModulePrinter modulePrinter(os, state);
+ CFGFunctionPrinter(getFunction(), modulePrinter).print(this);
}
void BasicBlock::dump() const { print(llvm::errs()); }
void Statement::print(raw_ostream &os) const {
- ModuleState moduleState(os);
- MLFunctionState state(getFunction(), &moduleState, os);
- state.print(this);
+ ModuleState state(getFunction()->getContext());
+ ModulePrinter modulePrinter(os, state);
+ MLFunctionPrinter(getFunction(), modulePrinter).print(this);
}
void Statement::dump() const { print(llvm::errs()); }
void Function::print(raw_ostream &os) const {
- switch (getKind()) {
- case Kind::ExtFunc:
- return cast<ExtFunction>(this)->print(os);
- case Kind::CFGFunc:
- return cast<CFGFunction>(this)->print(os);
- case Kind::MLFunc:
- return cast<MLFunction>(this)->print(os);
- }
+ ModuleState state(getContext());
+ ModulePrinter(os, state).print(this);
}
void Function::dump() const { print(llvm::errs()); }
-void ExtFunction::print(raw_ostream &os) const {
- ModuleState moduleState(os);
- os << "extfunc ";
- printFunctionSignature(this, &moduleState, os);
- os << "\n";
-}
-
-void CFGFunction::print(raw_ostream &os) const {
- ModuleState moduleState(os);
- CFGFunctionState state(this, &moduleState, os);
- state.print();
-}
-
-void MLFunction::print(raw_ostream &os) const {
- ModuleState moduleState(os);
- MLFunctionState state(this, &moduleState, os);
- state.print();
-}
-
void Module::print(raw_ostream &os) const {
- ModuleState moduleState(os);
- moduleState.initialize(this);
- moduleState.print(this);
+ ModuleState state(getContext());
+ state.initialize(this);
+ ModulePrinter(os, state).print(this);
}
void Module::dump() const { print(llvm::errs()); }