Implement OperationStmt. Refactor function printing so that operation printing is handled by a new FunctionState class, the common base class of CFGFunctionState and MLFunctionState. No parsing support yet; it will be added once cl/203785893 lands.
PiperOrigin-RevId: 203862427
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index f2d1e4f..715459f 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -76,12 +76,56 @@
os << "\n";
}
+namespace {
+
+// FunctionState holds the state and logic shared by the CFG and ML function
+// printers (CFGFunctionState and MLFunctionState derive from it).
+class FunctionState {
+public:
+ FunctionState(MLIRContext *context, raw_ostream &os);
+
+ void printOperation(const Operation *op); // Custom form if registered, else generic.
+
+protected: // Derived printers use these directly instead of keeping their own copies.
+ raw_ostream &os; // Destination stream for all printed text.
+ const OperationSet &operationSet; // Looked up by operation name to find custom printers.
+};
+} // end anonymous namespace
+
+FunctionState::FunctionState(MLIRContext *context, raw_ostream &os)
+ : os(os), operationSet(OperationSet::get(context)) {} // OperationSet comes from the function's context.
+
+void FunctionState::printOperation(const Operation *op) {
+ // If the operation's name is registered in the context's OperationSet, defer
+ // to its custom printAssembly hook; otherwise use the generic form below.
+ if (auto opInfo = operationSet.lookup(op->getName().str())) {
+ os << " ";
+ opInfo->printAssembly(op, os);
+ return; // NOTE(review): no '\n' is written on this path, unlike the generic one -- confirm printAssembly emits it.
+ }
+
+ // Generic form: "name"(){attr: value, ...}. TODO: escape name if necessary.
+ os << " \"" << op->getName().str() << "\"()";
+
+ auto attrs = op->getAttrs();
+ if (!attrs.empty()) {
+ os << '{';
+ interleave(
+ attrs,
+ [&](NamedAttribute attr) { os << attr.first << ": " << *attr.second; },
+ [&]() { os << ", "; });
+ os << '}';
+ }
+
+ os << '\n'; // The generic form always terminates its own line.
+}
+
//===----------------------------------------------------------------------===//
// CFG Function printing
//===----------------------------------------------------------------------===//
namespace {
-class CFGFunctionState {
+class CFGFunctionState : public FunctionState {
public:
CFGFunctionState(const CFGFunction *function, raw_ostream &os);
@@ -103,16 +147,12 @@
private:
const CFGFunction *function;
- raw_ostream &os;
- const OperationSet &operationSet;
DenseMap<const BasicBlock*, unsigned> basicBlockIDs;
};
} // end anonymous namespace
CFGFunctionState::CFGFunctionState(const CFGFunction *function, raw_ostream &os)
- : function(function), os(os),
- operationSet(OperationSet::get(function->getContext())) {
-
+ : FunctionState(function->getContext(), os), function(function) {
// Each basic block gets a unique ID per function.
unsigned blockID = 0;
for (auto &block : *function)
@@ -151,27 +191,7 @@
}
void CFGFunctionState::print(const OperationInst *inst) {
- // Check to see if this is a known operation. If so, use the registered
- // custom printer hook.
- if (auto opInfo = operationSet.lookup(inst->getName().str())) {
- os << " ";
- opInfo->printAssembly(inst, os);
- return;
- }
-
- // TODO: escape name if necessary.
- os << " \"" << inst->getName().str() << "\"()";
-
- auto attrs = inst->getAttrs();
- if (!attrs.empty()) {
- os << '{';
- interleave(attrs, [&](NamedAttribute attr) {
- os << attr.first << ": " << *attr.second;
- }, [&]() { os << ", "; });
- os << '}';
- }
-
- os << '\n';
+ printOperation(inst); // Shared with MLFunctionState via the FunctionState base.
}
void CFGFunctionState::print(const BranchInst *inst) {
@@ -186,7 +206,7 @@
//===----------------------------------------------------------------------===//
namespace {
-class MLFunctionState {
+class MLFunctionState : public FunctionState {
public:
MLFunctionState(const MLFunction *function, raw_ostream &os);
@@ -195,6 +215,7 @@
void print();
void print(const Statement *stmt);
+ void print(const OperationStmt *stmt);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
void print(const ElseClause *stmt, bool isLast);
@@ -204,13 +225,13 @@
void printNestedStatements(const NodeStmt *stmt);
const MLFunction *function;
- raw_ostream &os;
int numSpaces;
};
} // end anonymous namespace
MLFunctionState::MLFunctionState(const MLFunction *function, raw_ostream &os)
- : function(function), os(os), numSpaces(2) {}
+ : FunctionState(function->getContext(), os), function(function),
+ numSpaces(2) {} // Starting indent width passed to os.indent(numSpaces).
void MLFunctionState::print() {
os << "mlfunc ";
@@ -246,6 +267,8 @@
os.indent(numSpaces) << "}";
}
+void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); } // NOTE(review): print(const Statement *) is not changed by this patch -- confirm it dispatches OperationStmt here.
+
void MLFunctionState::print(const ForStmt *stmt) {
os << "for ";
printNestedStatements(stmt);