Refactor the implementation of the Statement class hierarchy to use statement blocks.
Use an LLVM intrusive doubly-linked list (ilist) with parent pointers to store the statements within a block.

PiperOrigin-RevId: 204515541
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 27ed3e4..0103e6a 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Statements.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/DenseMap.h"
@@ -212,18 +213,20 @@
 
   const MLFunction *getFunction() const { return function; }
 
+  // Prints the ML function signature followed by its body.
   void print();
 
+  // Print individual statements and nested statement blocks.
   void print(const Statement *stmt);
   void print(const OperationStmt *stmt);
   void print(const ForStmt *stmt);
   void print(const IfStmt *stmt);
-  void print(const ElseClause *stmt, bool isLast);
+  void print(const StmtBlock *block);
+
+  // Number of spaces added to the indentation for each level of nesting.
+  const static unsigned indentWidth = 2;
 
 private:
-  // Print statements nested within this node statement.
-  void printNestedStatements(const NodeStmt *stmt);
-
   const MLFunction *function;
   int numSpaces;
 };
@@ -231,21 +234,27 @@
 
 MLFunctionState::MLFunctionState(const MLFunction *function, raw_ostream &os)
     : FunctionState(function->getContext(), os), function(function),
-      numSpaces(2) {}
+      numSpaces(0) {}
 
 void MLFunctionState::print() {
   os << "mlfunc ";
   // FIXME: should print argument names rather than just signature
   printFunctionSignature(function, os);
   os << " {\n";
-  for (auto *stmt : function->stmtList)
-    print(stmt);
+  print(static_cast<const StmtBlock *>(function));
   os << "  return\n";
   os << "}\n\n";
 }
 
+/// Prints every statement in the block, one indentation level deeper.
+void MLFunctionState::print(const StmtBlock *block) {
+  numSpaces += indentWidth;
+  for (auto &stmt : block->getStatements())
+    print(&stmt);
+  numSpaces -= indentWidth;
+}
+
 void MLFunctionState::print(const Statement *stmt) {
-  os.indent(numSpaces);
   switch (stmt->getKind()) {
   case Statement::Kind::Operation: // TODO
     llvm_unreachable("Operation statement is not yet implemented");
@@ -253,45 +262,32 @@
     return print(cast<ForStmt>(stmt));
   case Statement::Kind::If:
     return print(cast<IfStmt>(stmt));
-  case Statement::Kind::Else:
-    return print(cast<ElseClause>(stmt));
   }
 }
 
-void MLFunctionState::printNestedStatements(const NodeStmt *stmt) {
-  os << "{\n";
-  numSpaces += 2;
-  for (auto * nestedStmt : stmt->children)
-    print(nestedStmt);
-  numSpaces -= 2;
-  os.indent(numSpaces) << "}";
+void MLFunctionState::print(const OperationStmt *stmt) {
+  os.indent(numSpaces);
+  printOperation(stmt);
 }
 
-void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
-
 void MLFunctionState::print(const ForStmt *stmt) {
-  os << "for ";
-  printNestedStatements(stmt);
-  os << "\n";
+  os.indent(numSpaces) << "for {\n";
+  print(static_cast<const StmtBlock *>(stmt));
+  os.indent(numSpaces) << "}\n";
 }
 
 void MLFunctionState::print(const IfStmt *stmt) {
-  os << "if ";
-  printNestedStatements(stmt);
-
-  int numClauses = stmt->elseClauses.size();
-  for (auto e : stmt->elseClauses)
-    print(e, e->getClauseNumber() == numClauses - 1);
+  os.indent(numSpaces) << "if () {\n";
+  print(stmt->getThenClause());
+  os.indent(numSpaces) << "}";
+  if (stmt->hasElseClause()) {
+    os << " else {\n";
+    print(stmt->getElseClause());
+    os.indent(numSpaces) << "}";
+  }
   os << "\n";
 }
 
-void MLFunctionState::print(const ElseClause *stmt, bool isLast) {
-  if (!isLast)
-    os << " if";
-  os << " else ";;
-  printNestedStatements(stmt);
-}
-
 //===----------------------------------------------------------------------===//
 // print and dump methods
 //===----------------------------------------------------------------------===//
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 92c3981..453797f 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -48,5 +48,5 @@
 //===----------------------------------------------------------------------===//
 
 MLFunction::MLFunction(StringRef name, FunctionType *type)
-  : Function(name, type, Kind::MLFunc) {
+  : Function(name, type, Kind::MLFunc), StmtBlock() {
 }
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
new file mode 100644
index 0000000..336cf7a
--- /dev/null
+++ b/lib/IR/Statement.cpp
@@ -0,0 +1,119 @@
+//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Statements.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Statement
+//===----------------------------------------------------------------------===//
+
+// Statements are deleted through the destroy() member because we don't have
+// a virtual destructor.
+Statement::~Statement() {
+  assert(block == nullptr && "statement destroyed but still in a block");
+}
+
+/// Destroy this statement or one of its subclasses.
+void Statement::destroy(Statement *stmt) {
+  switch (stmt->getKind()) {
+  case Kind::Operation:
+    delete cast<OperationStmt>(stmt);
+    break;
+  case Kind::For:
+    delete cast<ForStmt>(stmt);
+    break;
+  case Kind::If:
+    delete cast<IfStmt>(stmt);
+    break;
+  }
+}
+
+/// Returns the enclosing function, or null if not inserted into a block.
+MLFunction *Statement::getFunction() const {
+  return block ? block->getFunction() : nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// ilist_traits for Statement
+//===----------------------------------------------------------------------===//
+
+StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
+  size_t Offset(
+      size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
+  iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
+  return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
+                                       Offset);
+}
+
+/// This is a trait method invoked when a statement is added to a block.  We
+/// keep the block pointer up to date.
+void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
+  assert(!stmt->getBlock() && "already in a statement block!");
+  stmt->block = getContainingBlock();
+}
+
+/// This is a trait method invoked when a statement is removed from a block.
+/// We keep the block pointer up to date.
+void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
+    Statement *stmt) {
+  assert(stmt->block && "not already in a statement block!");
+  stmt->block = nullptr;
+}
+
+/// This is a trait method invoked when a statement is moved from one block
+/// to another.  We keep the block pointer up to date.
+void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
+    ilist_traits<Statement> &otherList, stmt_iterator first,
+    stmt_iterator last) {
+  // If we are transferring statements within the same block, the block
+  // pointer doesn't need to be updated.
+  StmtBlock *curParent = getContainingBlock();
+  if (curParent == otherList.getContainingBlock())
+    return;
+
+  // Update the 'block' member of each statement.
+  for (; first != last; ++first)
+    first->block = curParent;
+}
+
+/// Remove this statement from its StmtBlock and delete it.
+void Statement::eraseFromBlock() {
+  assert(getBlock() && "Statement has no block");
+  getBlock()->getStatements().erase(this);
+}
+
+//===----------------------------------------------------------------------===//
+// IfClause
+//===----------------------------------------------------------------------===//
+
+IfClause::IfClause(IfStmt *stmt) : StmtBlock(stmt) {
+  assert(stmt != nullptr && "If clause must have non-null parent");
+}
+
+IfStmt *IfClause::getIf() const { return static_cast<IfStmt *>(parent); }
+
+//===----------------------------------------------------------------------===//
+// IfStmt
+//===----------------------------------------------------------------------===//
+
+IfStmt::~IfStmt() {
+  // TODO: correctly delete StmtBlocks under then and else clauses
+  delete thenClause;
+  delete elseClause; // No-op when there is no else clause (null pointer).
+}
diff --git a/lib/IR/Statements.cpp b/lib/IR/Statements.cpp
deleted file mode 100644
index ab3f8fc..0000000
--- a/lib/IR/Statements.cpp
+++ /dev/null
@@ -1,40 +0,0 @@
-//===- Statements.cpp - MLIR Statement Instruction Classes ------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Statements.h"
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// Statement
-//===----------------------------------------------------------------------===//
-
-MLFunction *Statement::getFunction() const {
-  ParentType p = parent;
-  while (!p.is<MLFunction *>())
-    p = p.get<NodeStmt *>()->getParent();
-  return p.get<MLFunction *>();
-}
-
-//===----------------------------------------------------------------------===//
-// ElseClause
-//===----------------------------------------------------------------------===//
-
-ElseClause::ElseClause(IfStmt *ifStmt, int clauseNum)
-      : NodeStmt(Kind::Else, ifStmt), clauseNum(clauseNum) {
-  ifStmt->elseClauses.push_back(this);
-}
diff --git a/lib/IR/StmtBlock.cpp b/lib/IR/StmtBlock.cpp
new file mode 100644
index 0000000..83e0412
--- /dev/null
+++ b/lib/IR/StmtBlock.cpp
@@ -0,0 +1,35 @@
+//===- StmtBlock.cpp - MLIR Statement Block Class -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/StmtBlock.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Statement block
+//===----------------------------------------------------------------------===//
+
+/// Returns the ML function this block belongs to, found by walking up through
+/// the enclosing statements. Returns null if an enclosing statement has not
+/// been inserted into a block yet.
+MLFunction *StmtBlock::getFunction() const {
+  StmtBlock *block = const_cast<StmtBlock *>(this);
+  while (block && block->getParent())
+    block = block->getParent()->getBlock();
+  // The outermost block (no parent statement) is assumed to be an MLFunction.
+  return static_cast<MLFunction *>(block);
+}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 5838f73..e1d184d 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Statements.h"
 #include "mlir/IR/Types.h"
 #include "llvm/Support/SourceMgr.h"
 using namespace mlir;
@@ -1381,10 +1382,13 @@
       : Parser(state), function(function) {}
 
   ParseResult parseFunctionBody();
-  Statement *parseStatement(ParentType parent);
-  ForStmt *parseForStmt(ParentType parent);
-  IfStmt *parseIfStmt(ParentType parent);
-  ParseResult parseNestedStatements(NodeStmt *parent);
+
+private:
+  Statement *parseStatement();
+  ForStmt *parseForStmt();
+  IfStmt *parseIfStmt();
+  ParseResult parseElseClause(IfClause *elseClause);
+  ParseResult parseStmtBlock(StmtBlock *block);
 };
 } // end anonymous namespace
 
@@ -1398,10 +1402,10 @@
 
   // Parse the list of instructions.
   while (!consumeIf(Token::kw_return)) {
-    auto *stmt = parseStatement(function);
+    auto *stmt = parseStatement();
     if (!stmt)
       return ParseFailure;
-    function->stmtList.push_back(stmt);
+    function->push_back(stmt);
   }
 
   // TODO: parse return statement operands
@@ -1420,17 +1424,17 @@
 /// TODO: fix terminology in MLSpec document. ML functions
 /// contain operation statements, not instructions.
 ///
-Statement *MLFunctionParser::parseStatement(ParentType parent) {
+Statement *MLFunctionParser::parseStatement() {
   switch (getToken().getKind()) {
   default:
     //TODO: parse OperationStmt
     return (emitError("expected statement"), nullptr);
 
   case Token::kw_for:
-    return parseForStmt(parent);
+    return parseForStmt();
 
   case Token::kw_if:
-    return parseIfStmt(parent);
+    return parseIfStmt();
   }
 }
 
@@ -1439,12 +1443,12 @@
 ///    ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
 ///                   (`step` integer-literal)? `{` ml-stmt* `}`
 ///
-ForStmt *MLFunctionParser::parseForStmt(ParentType parent) {
+ForStmt *MLFunctionParser::parseForStmt() {
   consumeToken(Token::kw_for);
 
   //TODO: parse loop header
-  ForStmt *stmt = new ForStmt(parent);
-  if (parseNestedStatements(stmt)) {
+  ForStmt *stmt = new ForStmt();
+  if (parseStmtBlock(static_cast<StmtBlock *>(stmt))) {
     delete stmt;
     return nullptr;
   }
@@ -1458,50 +1462,61 @@
 ///   ml-if-stmt ::= ml-if-head
 ///               | ml-if-head `else` `{` ml-stmt* `}`
 ///
-IfStmt *
-MLFunctionParser::parseIfStmt(PointerUnion<MLFunction *, NodeStmt *> parent) {
+IfStmt *MLFunctionParser::parseIfStmt() {
   consumeToken(Token::kw_if);
+  if (!consumeIf(Token::l_paren))
+    return (emitError("expected ("), nullptr);
 
   //TODO: parse condition
-  IfStmt *stmt = new IfStmt(parent);
-  if (parseNestedStatements(stmt)) {
-    delete stmt;
+
+  if (!consumeIf(Token::r_paren))
+    return (emitError("expected )"), nullptr);
+
+  IfStmt *ifStmt = new IfStmt();
+  IfClause *thenClause = ifStmt->getThenClause();
+  if (parseStmtBlock(thenClause)) {
+    delete ifStmt;
     return nullptr;
   }
 
-  int clauseNum = 0;
-  while (consumeIf(Token::kw_else)) {
-    if (consumeIf(Token::kw_if)) {
-       //TODO: parse condition
-    }
-    ElseClause * clause = new ElseClause(stmt, clauseNum);
-    ++clauseNum;
-    if (parseNestedStatements(clause)) {
-      delete clause;
+  if (consumeIf(Token::kw_else)) {
+    IfClause *elseClause = ifStmt->createElseClause();
+    if (parseElseClause(elseClause)) {
+      delete ifStmt;
       return nullptr;
     }
   }
 
-  return stmt;
+  return ifStmt;
+}
+
+///
+/// Parse an else clause body: a nested `if` statement or `{` ml-stmt* `}`.
+///
+ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) {
+  if (getToken().is(Token::kw_if)) {
+    IfStmt *nextIf = parseIfStmt();
+    if (!nextIf)
+      return ParseFailure;
+    elseClause->push_back(nextIf);
+    return ParseSuccess;
+  }
+
+  return parseStmtBlock(elseClause);
 }
 
 ///
 /// Parse `{` ml-stmt* `}`
 ///
-ParseResult MLFunctionParser::parseNestedStatements(NodeStmt *parent) {
+ParseResult MLFunctionParser::parseStmtBlock(StmtBlock *block) {
   if (!consumeIf(Token::l_brace))
     return emitError("expected '{' before statement list");
 
-  if (consumeIf(Token::r_brace)) {
-    // TODO: parse OperationStmt
-    return ParseSuccess;
-  }
-
   while (!consumeIf(Token::r_brace)) {
-    auto *stmt = parseStatement(parent);
+    auto *stmt = parseStatement();
     if (!stmt)
       return ParseFailure;
-    parent->children.push_back(stmt);
+    block->push_back(stmt);
   }
 
   return ParseSuccess;