Refactor the implementation of the Statement class hierarchy to use statement blocks.
Use an LLVM doubly-linked list with parent links to store the statements within a block.
PiperOrigin-RevId: 204515541
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 27ed3e4..0103e6a 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Statements.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
@@ -212,18 +213,20 @@
const MLFunction *getFunction() const { return function; }
+ // Prints ML function
void print();
+ // Methods to print ML function statements
void print(const Statement *stmt);
void print(const OperationStmt *stmt);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
- void print(const ElseClause *stmt, bool isLast);
+ void print(const StmtBlock *block);
+
+ // Number of spaces used for indenting nested statements
+ const static unsigned indentWidth = 2;
private:
- // Print statements nested within this node statement.
- void printNestedStatements(const NodeStmt *stmt);
-
const MLFunction *function;
int numSpaces;
};
@@ -231,21 +234,26 @@
MLFunctionState::MLFunctionState(const MLFunction *function, raw_ostream &os)
: FunctionState(function->getContext(), os), function(function),
- numSpaces(2) {}
+ numSpaces(0) {}
void MLFunctionState::print() {
os << "mlfunc ";
// FIXME: should print argument names rather than just signature
printFunctionSignature(function, os);
os << " {\n";
- for (auto *stmt : function->stmtList)
- print(stmt);
+ print(function);
os << " return\n";
os << "}\n\n";
}
+void MLFunctionState::print(const StmtBlock *block) {
+ numSpaces += indentWidth;
+ for (auto &stmt : block->getStatements())
+ print(&stmt);
+ numSpaces -= indentWidth;
+}
+
void MLFunctionState::print(const Statement *stmt) {
- os.indent(numSpaces);
switch (stmt->getKind()) {
case Statement::Kind::Operation: // TODO
llvm_unreachable("Operation statement is not yet implemented");
@@ -253,45 +261,31 @@
return print(cast<ForStmt>(stmt));
case Statement::Kind::If:
return print(cast<IfStmt>(stmt));
- case Statement::Kind::Else:
- return print(cast<ElseClause>(stmt));
}
}
-void MLFunctionState::printNestedStatements(const NodeStmt *stmt) {
- os << "{\n";
- numSpaces += 2;
- for (auto * nestedStmt : stmt->children)
- print(nestedStmt);
- numSpaces -= 2;
- os.indent(numSpaces) << "}";
+void MLFunctionState::print(const OperationStmt *stmt) {
+ printOperation(stmt);
}
-void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
-
void MLFunctionState::print(const ForStmt *stmt) {
- os << "for ";
- printNestedStatements(stmt);
- os << "\n";
+ os.indent(numSpaces) << "for {\n";
+ print(static_cast<const StmtBlock *>(stmt));
+ os.indent(numSpaces) << "}\n";
}
void MLFunctionState::print(const IfStmt *stmt) {
- os << "if ";
- printNestedStatements(stmt);
-
- int numClauses = stmt->elseClauses.size();
- for (auto e : stmt->elseClauses)
- print(e, e->getClauseNumber() == numClauses - 1);
+ os.indent(numSpaces) << "if () {\n";
+ print(stmt->getThenClause());
+ os.indent(numSpaces) << "}";
+ if (stmt->hasElseClause()) {
+ os << " else {\n";
+ print(stmt->getElseClause());
+ os.indent(numSpaces) << "}";
+ }
os << "\n";
}
-void MLFunctionState::print(const ElseClause *stmt, bool isLast) {
- if (!isLast)
- os << " if";
- os << " else ";;
- printNestedStatements(stmt);
-}
-
//===----------------------------------------------------------------------===//
// print and dump methods
//===----------------------------------------------------------------------===//