Refactor the implementation of the Statement class hierarchy to use statement blocks.
Use an LLVM intrusive doubly-linked list (iplist) with parent pointers to store the statements within a block.
PiperOrigin-RevId: 204515541
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 27ed3e4..0103e6a 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Statements.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
@@ -212,18 +213,20 @@
const MLFunction *getFunction() const { return function; }
+ // Prints the whole ML function: signature, body statements, return line
+ // and closing brace.
void print();
+ // Methods to print ML function statements
void print(const Statement *stmt);
void print(const OperationStmt *stmt);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
- void print(const ElseClause *stmt, bool isLast);
+ // Prints the statements of a block one indentation level deeper.
+ void print(const StmtBlock *block);
+
+ // Number of spaces added to the indentation for each nested block.
+ const static unsigned indentWidth = 2;
private:
- // Print statements nested within this node statement.
- void printNestedStatements(const NodeStmt *stmt);
-
+ // The function being printed.
const MLFunction *function;
+ // Indentation, in spaces, applied to the statement currently being printed.
int numSpaces;
};
@@ -231,21 +234,26 @@
+// Indentation starts at zero: print(const StmtBlock *) adds one level for
+// every block it prints, including the function body itself.
MLFunctionState::MLFunctionState(const MLFunction *function, raw_ostream &os)
: FunctionState(function->getContext(), os), function(function),
- numSpaces(2) {}
+ numSpaces(0) {}
+/// Prints the ML function: the `mlfunc` keyword and signature, its top-level
+/// statements, a `return` line, and the closing brace.
void MLFunctionState::print() {
os << "mlfunc ";
// FIXME: should print argument names rather than just signature
printFunctionSignature(function, os);
os << " {\n";
- for (auto *stmt : function->stmtList)
- print(stmt);
+ // MLFunction is itself a StmtBlock, so this prints the function body.
+ print(function);
os << " return\n";
os << "}\n\n";
}
+/// Prints each statement of `block`, nested one indentation level deeper than
+/// the construct that owns the block. The caller's indentation is restored
+/// afterwards.
+void MLFunctionState::print(const StmtBlock *block) {
+  const int outerIndent = numSpaces;
+  numSpaces = outerIndent + indentWidth;
+  for (const Statement &nested : block->getStatements())
+    print(&nested);
+  numSpaces = outerIndent;
+}
+
void MLFunctionState::print(const Statement *stmt) {
- os.indent(numSpaces);
switch (stmt->getKind()) {
case Statement::Kind::Operation: // TODO
llvm_unreachable("Operation statement is not yet implemented");
@@ -253,45 +261,31 @@
return print(cast<ForStmt>(stmt));
case Statement::Kind::If:
return print(cast<IfStmt>(stmt));
- case Statement::Kind::Else:
- return print(cast<ElseClause>(stmt));
}
}
-void MLFunctionState::printNestedStatements(const NodeStmt *stmt) {
- os << "{\n";
- numSpaces += 2;
- for (auto * nestedStmt : stmt->children)
- print(nestedStmt);
- numSpaces -= 2;
- os.indent(numSpaces) << "}";
+/// Prints an operation statement at the current indentation level.
+void MLFunctionState::print(const OperationStmt *stmt) {
+  // print(const Statement *) no longer indents before dispatching, so every
+  // statement printer is responsible for its own indentation.
+  os.indent(numSpaces);
+  printOperation(stmt);
}
-void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
-
+/// Prints a for statement, with its body statements one level deeper.
void MLFunctionState::print(const ForStmt *stmt) {
- os << "for ";
- printNestedStatements(stmt);
- os << "\n";
+ os.indent(numSpaces) << "for {\n";
+ // A ForStmt is also a StmtBlock holding its body. The cast makes overload
+ // resolution pick print(const StmtBlock *) instead of recursing here.
+ print(static_cast<const StmtBlock *>(stmt));
+ os.indent(numSpaces) << "}\n";
}
+/// Prints an if statement: its then clause and, when present, its else clause,
+/// each with its statements one indentation level deeper.
void MLFunctionState::print(const IfStmt *stmt) {
- os << "if ";
- printNestedStatements(stmt);
-
- int numClauses = stmt->elseClauses.size();
- for (auto e : stmt->elseClauses)
- print(e, e->getClauseNumber() == numClauses - 1);
+ // NOTE(review): the condition is printed as an empty "()" for now.
+ os.indent(numSpaces) << "if () {\n";
+ print(stmt->getThenClause());
+ os.indent(numSpaces) << "}";
+ // The else clause continues on the line that closes the then clause.
+ if (stmt->hasElseClause()) {
+ os << " else {\n";
+ print(stmt->getElseClause());
+ os.indent(numSpaces) << "}";
+ }
os << "\n";
}
-void MLFunctionState::print(const ElseClause *stmt, bool isLast) {
- if (!isLast)
- os << " if";
- os << " else ";;
- printNestedStatements(stmt);
-}
-
//===----------------------------------------------------------------------===//
// print and dump methods
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 92c3981..453797f 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -48,5 +48,5 @@
//===----------------------------------------------------------------------===//
+// MLFunction is both a Function and a StmtBlock; the StmtBlock base holds the
+// function's top-level statements and is initialized explicitly here.
MLFunction::MLFunction(StringRef name, FunctionType *type)
- : Function(name, type, Kind::MLFunc) {
+ : Function(name, type, Kind::MLFunc), StmtBlock() {
}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
new file mode 100644
index 0000000..336cf7a
--- /dev/null
+++ b/lib/IR/Statement.cpp
@@ -0,0 +1,119 @@
+//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Statements.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Statement
+//===----------------------------------------------------------------------===//
+
+// Statement has no virtual destructor, so concrete statements must be deleted
+// through Statement::destroy(). A statement has to be unlinked from its block
+// before it is destroyed.
+Statement::~Statement() {
+  assert(!block && "statement destroyed but still in a block");
+}
+
+/// Deletes `stmt` through its concrete subclass type.
+///
+/// Statement has no virtual destructor, so the dynamic type is recovered from
+/// the kind tag and the statement is deleted as that subclass.
+void Statement::destroy(Statement *stmt) {
+  switch (stmt->getKind()) {
+  case Kind::Operation:
+    return delete cast<OperationStmt>(stmt);
+  case Kind::For:
+    return delete cast<ForStmt>(stmt);
+  case Kind::If:
+    return delete cast<IfStmt>(stmt);
+  }
+}
+
+/// Returns the MLFunction that (transitively) contains this statement.
+/// The statement must be attached to a StmtBlock.
+MLFunction *Statement::getFunction() const {
+  auto *parentBlock = getBlock();
+  assert(parentBlock && "statement is not attached to a statement block");
+  return parentBlock->getFunction();
+}
+
+//===----------------------------------------------------------------------===//
+// ilist_traits for Statement
+//===----------------------------------------------------------------------===//
+
+/// Recovers the StmtBlock that embeds this statement list.
+///
+/// The list is a member of StmtBlock (reached through
+/// StmtBlock::getSublistAccess), so the owning block's address is the list's
+/// address minus that member's byte offset. This is the same offset
+/// computation LLVM's SymbolTableListTraits uses to find a list's owner.
+/// NOTE(review): applying ->* to a null pointer to obtain the offset is
+/// formally undefined; it relies on the same compiler behavior LLVM does.
+StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
+ // Byte offset of the statement list member within StmtBlock.
+ size_t Offset(
+ size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
+ // These traits are a base of the list itself, so downcast to the list.
+ iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
+ return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
+ Offset);
+}
+
+/// List callback run when `stmt` is linked into a block's statement list.
+/// Records the owning block in the statement.
+void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
+  assert(stmt->getBlock() == nullptr && "already in a statement block!");
+  StmtBlock *owner = getContainingBlock();
+  stmt->block = owner;
+}
+
+/// List callback run when `stmt` is unlinked from a block's statement list.
+/// Clears the statement's owning-block pointer.
+void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
+    Statement *stmt) {
+  assert(stmt->block != nullptr && "not already in a statement block!");
+  stmt->block = nullptr;
+}
+
+/// List callback run when the statements in [first, last) are moved into this
+/// list from `otherList`. Re-points each moved statement at its new block.
+void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
+    ilist_traits<Statement> &otherList, stmt_iterator first,
+    stmt_iterator last) {
+  StmtBlock *destBlock = getContainingBlock();
+  // A move within a single block leaves every statement's owner unchanged.
+  if (otherList.getContainingBlock() != destBlock) {
+    while (first != last) {
+      first->block = destBlock;
+      ++first;
+    }
+  }
+}
+
+/// Unlinks this statement from its StmtBlock and deletes it.
+/// NOTE(review): deletion happens in the list's deleteNode trait. Statement
+/// has no virtual destructor, so that trait presumably has to route through
+/// Statement::destroy(); confirm in the header.
+void Statement::eraseFromBlock() {
+  auto *owner = getBlock();
+  assert(owner && "Statement has no block");
+  owner->getStatements().erase(this);
+}
+
+//===----------------------------------------------------------------------===//
+// IfClause
+//===----------------------------------------------------------------------===//
+
+/// Creates a clause block of the if statement `stmt`, which becomes the
+/// block's parent.
+IfClause::IfClause(IfStmt *stmt) : StmtBlock(stmt) {
+  assert(stmt && "If clause must have non-null parent");
+}
+
+/// Returns the IfStmt owning this clause (stored as the block's parent).
+IfStmt *IfClause::getIf() const {
+  return static_cast<IfStmt *>(parent);
+}
+
+//===----------------------------------------------------------------------===//
+// IfStmt
+//===----------------------------------------------------------------------===//
+
+/// Deletes the then clause and, if present, the else clause.
+/// NOTE(review): IfStmt owns these clauses through raw pointers; make sure
+/// copying IfStmt is disabled in the header (or use std::unique_ptr) to avoid
+/// a double delete.
+IfStmt::~IfStmt() {
+  // TODO: correctly delete StmtBlocks under then and else clauses
+  delete thenClause;
+  // Deleting a null pointer is a no-op, so a missing else clause needs no
+  // special handling.
+  delete elseClause;
+}
diff --git a/lib/IR/Statements.cpp b/lib/IR/Statements.cpp
deleted file mode 100644
index ab3f8fc..0000000
--- a/lib/IR/Statements.cpp
+++ /dev/null
@@ -1,40 +0,0 @@
-//===- Statements.cpp - MLIR Statement Instruction Classes ------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Statements.h"
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// Statement
-//===----------------------------------------------------------------------===//
-
-MLFunction *Statement::getFunction() const {
- ParentType p = parent;
- while (!p.is<MLFunction *>())
- p = p.get<NodeStmt *>()->getParent();
- return p.get<MLFunction *>();
-}
-
-//===----------------------------------------------------------------------===//
-// ElseClause
-//===----------------------------------------------------------------------===//
-
-ElseClause::ElseClause(IfStmt *ifStmt, int clauseNum)
- : NodeStmt(Kind::Else, ifStmt), clauseNum(clauseNum) {
- ifStmt->elseClauses.push_back(this);
-}
diff --git a/lib/IR/StmtBlock.cpp b/lib/IR/StmtBlock.cpp
new file mode 100644
index 0000000..83e0412
--- /dev/null
+++ b/lib/IR/StmtBlock.cpp
@@ -0,0 +1,32 @@
+//===- StmtBlock.cpp - MLIR Statement Block Class -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/StmtBlock.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Statement block
+//===----------------------------------------------------------------------===//
+
+/// Returns the MLFunction containing this block.
+///
+/// Walks upwards through each block's parent statement and that statement's
+/// enclosing block until it reaches a block with no parent statement, which is
+/// taken to be the MLFunction itself. Returns nullptr if some enclosing
+/// statement is not attached to a block (previously this dereferenced null).
+MLFunction *StmtBlock::getFunction() const {
+  StmtBlock *block = const_cast<StmtBlock *>(this);
+
+  while (auto *parentStmt = block->getParent()) {
+    block = parentStmt->getBlock();
+    // A detached statement subtree has no enclosing function.
+    if (!block)
+      return nullptr;
+  }
+  // NOTE(review): this cast is only valid if the outermost block really is an
+  // MLFunction; a parentless block created some other way would yield a bogus
+  // pointer here.
+  return static_cast<MLFunction *>(block);
+}