Refactor implementation of Statement class heirarchy to use statement block.
Use LLVM double-link with parent list to store statements within a block.
PiperOrigin-RevId: 204515541
diff --git a/include/mlir/IR/MLFunction.h b/include/mlir/IR/MLFunction.h
index 0907b29..c8388ee 100644
--- a/include/mlir/IR/MLFunction.h
+++ b/include/mlir/IR/MLFunction.h
@@ -1,4 +1,4 @@
-//===- MLFunction.h - MLIR MLFunction Class -------------------*- C++ -*-===//
+//===- MLFunction.h - MLIR MLFunction Class ---------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -23,20 +23,16 @@
#define MLIR_IR_MLFUNCTION_H_
#include "mlir/IR/Function.h"
-#include "mlir/IR/Statements.h"
-#include <vector>
+#include "mlir/IR/StmtBlock.h"
namespace mlir {
// MLFunction is defined as a sequence of statements that may
-// include nested affine for loops, conditionals and instructions.
-class MLFunction : public Function {
+// include nested affine for loops, conditionals and operations.
+class MLFunction : public Function, public StmtBlock {
public:
MLFunction(StringRef name, FunctionType *type);
- // FIXME: wrong representation and API, leaks memory etc
- std::vector<Statement*> stmtList;
-
// TODO: add function arguments and return values once
// SSA values are implemented
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
new file mode 100644
index 0000000..5b855ad
--- /dev/null
+++ b/include/mlir/IR/Statement.h
@@ -0,0 +1,107 @@
+//===- Statement.h - MLIR ML Statement Class --------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the Statement class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_STATEMENT_H
+#define MLIR_IR_STATEMENT_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ilist.h"
+#include "llvm/ADT/ilist_node.h"
+
+namespace mlir {
+ class MLFunction;
+ class StmtBlock;
+
+/// Statement is a basic unit of execution within an ML function.
+/// Statements can be nested within for and if statements effectively
+/// forming a tree. Statements are organized into statement blocks
+/// represented by StmtBlock class.
+class Statement
+ : public llvm::ilist_node_with_parent<Statement, StmtBlock> {
+public:
+ enum class Kind {
+ Operation,
+ For,
+ If
+ };
+
+ Kind getKind() const { return kind; }
+
+ /// Returns the statement block that contains this statement.
+ StmtBlock *getBlock() const { return block; }
+
+ /// Returns the function that this statement is part of.
+ MLFunction *getFunction() const;
+
+ /// Destroys the argument statement or one of its subclasses
+ static void destroy(Statement *stmt);
+
+ void print(raw_ostream &os) const;
+ void dump() const;
+
+protected:
+ Statement(Kind kind) : kind(kind) {}
+ // Statements are deleted through the destroy() member because this class
+ // does not have a virtual destructor.
+ ~Statement();
+
+ /// Remove this statement from its block and delete it.
+ void eraseFromBlock();
+private:
+ Kind kind;
+ StmtBlock *block = nullptr;
+
+ // allow ilist_traits access to 'block' field.
+ friend struct llvm::ilist_traits<Statement>;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, const Statement &stmt) {
+ stmt.print(os);
+ return os;
+}
+} //end namespace mlir
+
+//===----------------------------------------------------------------------===//
+// ilist_traits for Statement
+//===----------------------------------------------------------------------===//
+
+namespace llvm {
+
+template <>
+struct ilist_traits<::mlir::Statement> {
+ using Statement = ::mlir::Statement;
+ using stmt_iterator = simple_ilist<Statement>::iterator;
+
+ static void deleteNode(Statement *stmt) {
+ Statement::destroy(stmt);
+ }
+
+ void addNodeToList(Statement *stmt);
+ void removeNodeFromList(Statement *stmt);
+ void transferNodesFromList(ilist_traits<Statement> &otherList,
+ stmt_iterator first, stmt_iterator last);
+private:
+ mlir::StmtBlock *getContainingBlock();
+};
+
+} // end namespace llvm
+
+#endif // MLIR_IR_STATEMENT_H
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 77d1a6a..92ab8b6 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -1,4 +1,4 @@
-//===- Statements.h - MLIR ML Statement Classes ------------*- C++ -*-===//
+//===- Statements.h - MLIR ML Statement Classes -----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -15,7 +15,7 @@
// limitations under the License.
// =============================================================================
//
-// This file defines the classes for MLFunction statements.
+// This file defines classes for special kinds of ML Function statements.
//
//===----------------------------------------------------------------------===//
@@ -23,56 +23,19 @@
#define MLIR_IR_STATEMENTS_H
#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/PointerUnion.h"
-
#include "mlir/IR/Operation.h"
-
-#include <vector>
+#include "mlir/IR/Statement.h"
+#include "mlir/IR/StmtBlock.h"
namespace mlir {
- class MLFunction;
- class NodeStmt;
- class ElseClause;
-
- typedef PointerUnion<MLFunction *, NodeStmt *> ParentType;
-
-/// Statement is a basic unit of execution within an ML function.
-/// Statements can be nested within each other, effectively forming a tree.
-class Statement {
-public:
- enum class Kind {
- Operation,
- For,
- If,
- Else
- };
-
- Kind getKind() const { return kind; }
-
- /// Returns the parent of this statement. The parent of a nested statement
- /// is the closest surrounding for or if statement. The parent of
- /// a top-level statement is the function that contains the statement.
- ParentType getParent() const { return parent; }
-
- /// Returns the function that this statement is part of.
- MLFunction *getFunction() const;
-
- void print(raw_ostream &os) const;
- void dump() const;
-
-protected:
- Statement(Kind kind, ParentType parent) : kind(kind), parent(parent) {}
-private:
- Kind kind;
- ParentType parent;
-};
/// Operation statements represent operations inside ML functions.
class OperationStmt : public Operation, public Statement {
public:
- explicit OperationStmt(ParentType parent, Identifier name,
- ArrayRef<NamedAttribute> attrs, MLIRContext *context)
- : Operation(name, attrs, context), Statement(Kind::Operation, parent) {}
+ explicit OperationStmt(Identifier name, ArrayRef<NamedAttribute> attrs,
+ MLIRContext *context)
+ : Operation(name, attrs, context), Statement(Kind::Operation) {}
+ ~OperationStmt() {}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Statement *stmt) {
@@ -80,25 +43,12 @@
}
};
-/// Node statement represents a statement that may contain other statements.
-class NodeStmt : public Statement {
-public:
- // FIXME: wrong representation and API, leaks memory etc
- std::vector<Statement*> children;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(const Statement *stmt) {
- return stmt->getKind() != Kind::Operation;
- }
-
-protected:
- NodeStmt(Kind kind, ParentType parent) : Statement(kind, parent) {}
-};
-
/// For statement represents an affine loop nest.
-class ForStmt : public NodeStmt {
+class ForStmt : public Statement, public StmtBlock {
public:
- explicit ForStmt(ParentType parent) : NodeStmt(Kind::For, parent) {}
+ explicit ForStmt() : Statement(Kind::For), StmtBlock(this) {}
+ //TODO: delete nested statements or assert that they are gone.
+ ~ForStmt() {}
// TODO: represent loop variable, bounds and step
@@ -108,39 +58,41 @@
}
};
-/// If statement restricts execution to a subset of the loop iteration space.
-class IfStmt : public NodeStmt {
+/// If clause represents statements contained within then or else clause
+/// of an if statement.
+class IfClause : public StmtBlock {
public:
- explicit IfStmt(ParentType parent) : NodeStmt(Kind::If, parent) {}
+ explicit IfClause(IfStmt *stmt);
- // TODO: Represent condition
+ //TODO: delete nested statements or assert that they are gone.
+ ~IfClause() {}
- // FIXME: most likely wrong representation since it's wrong everywhere else
- std::vector<ElseClause *> elseClauses;
+ IfStmt *getIf() const;
+};
+
+/// If statement restricts execution to a subset of the loop iteration space.
+class IfStmt : public Statement {
+public:
+ explicit IfStmt()
+ : Statement(Kind::If), thenClause(new IfClause(this)),
+ elseClause(nullptr) {}
+
+ ~IfStmt();
+
+ IfClause *getThenClause() const { return thenClause; }
+ IfClause *getElseClause() const { return elseClause; }
+ bool hasElseClause() const {return elseClause != nullptr;}
+ IfClause *createElseClause() { return (elseClause = new IfClause(this)); }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Statement *stmt) {
return stmt->getKind() == Kind::If;
}
-};
-
-/// Else clause reprsents else or else-if clause of an if statement
-class ElseClause : public NodeStmt {
-public:
- explicit ElseClause(IfStmt *ifStmt, int clauseNum);
-
- // TODO: Represent optional condition
-
- // Returns ordinal number of this clause in the list of clauses.
- int getClauseNumber() const { return clauseNum;}
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(const Statement *stmt) {
- return stmt->getKind() == Kind::Else;
- }
private:
- int clauseNum;
+ IfClause *thenClause;
+ IfClause *elseClause;
+ // TODO: Represent IntegerSet condition
};
-
} //end namespace mlir
+
#endif // MLIR_IR_STATEMENTS_H
diff --git a/include/mlir/IR/StmtBlock.h b/include/mlir/IR/StmtBlock.h
new file mode 100644
index 0000000..ed3f50c
--- /dev/null
+++ b/include/mlir/IR/StmtBlock.h
@@ -0,0 +1,101 @@
+//===- StmtBlock.h ----------------------------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines StmtBlock and *Stmt classes that extend Statement.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_STMTBLOCK_H
+#define MLIR_IR_STMTBLOCK_H
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/IR/Statement.h"
+
+namespace mlir {
+ class MLFunction;
+ class IfStmt;
+
+/// Statement block represents an ordered list of statements.
+class StmtBlock {
+public:
+ /// Returns the closest surrounding statement that contains this block or
+ /// nullptr if this is a top-level statement block.
+ Statement *getParent() const { return parent; }
+
+ /// Returns the function that this statement block is part of.
+ MLFunction *getFunction() const;
+
+ //===--------------------------------------------------------------------===//
+ // Statement list management
+ //===--------------------------------------------------------------------===//
+
+ /// This is the list of statements in the block.
+ typedef llvm::iplist<Statement> StmtListType;
+ StmtListType &getStatements() { return statements; }
+ const StmtListType &getStatements() const { return statements; }
+
+ // Iteration over the statements in the block.
+ using iterator = StmtListType::iterator;
+ using const_iterator = StmtListType::const_iterator;
+ using reverse_iterator = StmtListType::reverse_iterator;
+ using const_reverse_iterator = StmtListType::const_reverse_iterator;
+
+ iterator begin() { return statements.begin(); }
+ iterator end() { return statements.end(); }
+ const_iterator begin() const { return statements.begin(); }
+ const_iterator end() const { return statements.end(); }
+ reverse_iterator rbegin() { return statements.rbegin(); }
+ reverse_iterator rend() { return statements.rend(); }
+ const_reverse_iterator rbegin() const { return statements.rbegin(); }
+ const_reverse_iterator rend() const { return statements.rend(); }
+
+ bool empty() const { return statements.empty(); }
+ void push_back(Statement *stmt) { statements.push_back(stmt); }
+ void push_front(Statement *stmt) { statements.push_front(stmt); }
+
+ Statement &back() { return statements.back(); }
+ const Statement &back() const {
+ return const_cast<StmtBlock *>(this)->back();
+ }
+ Statement &front() { return statements.front(); }
+ const Statement &front() const {
+ return const_cast<StmtBlock*>(this)->front();
+ }
+
+ void print(raw_ostream &os) const;
+ void dump() const;
+
+ /// getSublistAccess() - Returns pointer to member of statement list
+ static StmtListType StmtBlock::*getSublistAccess(Statement*) {
+ return &StmtBlock::statements;
+ }
+
+protected:
+ Statement *parent;
+
+ StmtBlock(Statement *parent=nullptr) : parent(parent) {}
+private:
+ /// This is the list of statements in the block.
+ StmtListType statements;
+
+ StmtBlock(const StmtBlock&) = delete;
+ void operator=(const StmtBlock&) = delete;
+
+};
+
+} //end namespace mlir
+#endif // MLIR_IR_STMTBLOCK_H
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 27ed3e4..0103e6a 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Statements.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
@@ -212,18 +213,20 @@
const MLFunction *getFunction() const { return function; }
+ // Prints ML function
void print();
+ // Methods to print ML function statements
void print(const Statement *stmt);
void print(const OperationStmt *stmt);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
- void print(const ElseClause *stmt, bool isLast);
+ void print(const StmtBlock *block);
+
+ // Number of spaces used for indenting nested statements
+ const static unsigned indentWidth = 2;
private:
- // Print statements nested within this node statement.
- void printNestedStatements(const NodeStmt *stmt);
-
const MLFunction *function;
int numSpaces;
};
@@ -231,21 +234,26 @@
MLFunctionState::MLFunctionState(const MLFunction *function, raw_ostream &os)
: FunctionState(function->getContext(), os), function(function),
- numSpaces(2) {}
+ numSpaces(0) {}
void MLFunctionState::print() {
os << "mlfunc ";
// FIXME: should print argument names rather than just signature
printFunctionSignature(function, os);
os << " {\n";
- for (auto *stmt : function->stmtList)
- print(stmt);
+ print(function);
os << " return\n";
os << "}\n\n";
}
+void MLFunctionState::print(const StmtBlock *block) {
+ numSpaces += indentWidth;
+ for (auto &stmt : block->getStatements())
+ print(&stmt);
+ numSpaces -= indentWidth;
+}
+
void MLFunctionState::print(const Statement *stmt) {
- os.indent(numSpaces);
switch (stmt->getKind()) {
case Statement::Kind::Operation: // TODO
llvm_unreachable("Operation statement is not yet implemented");
@@ -253,45 +261,31 @@
return print(cast<ForStmt>(stmt));
case Statement::Kind::If:
return print(cast<IfStmt>(stmt));
- case Statement::Kind::Else:
- return print(cast<ElseClause>(stmt));
}
}
-void MLFunctionState::printNestedStatements(const NodeStmt *stmt) {
- os << "{\n";
- numSpaces += 2;
- for (auto * nestedStmt : stmt->children)
- print(nestedStmt);
- numSpaces -= 2;
- os.indent(numSpaces) << "}";
+void MLFunctionState::print(const OperationStmt *stmt) {
+ printOperation(stmt);
}
-void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
-
void MLFunctionState::print(const ForStmt *stmt) {
- os << "for ";
- printNestedStatements(stmt);
- os << "\n";
+ os.indent(numSpaces) << "for {\n";
+ print(static_cast<const StmtBlock *>(stmt));
+ os.indent(numSpaces) << "}\n";
}
void MLFunctionState::print(const IfStmt *stmt) {
- os << "if ";
- printNestedStatements(stmt);
-
- int numClauses = stmt->elseClauses.size();
- for (auto e : stmt->elseClauses)
- print(e, e->getClauseNumber() == numClauses - 1);
+ os.indent(numSpaces) << "if () {\n";
+ print(stmt->getThenClause());
+ os.indent(numSpaces) << "}";
+ if (stmt->hasElseClause()) {
+ os << " else {\n";
+ print(stmt->getElseClause());
+ os.indent(numSpaces) << "}";
+ }
os << "\n";
}
-void MLFunctionState::print(const ElseClause *stmt, bool isLast) {
- if (!isLast)
- os << " if";
- os << " else ";;
- printNestedStatements(stmt);
-}
-
//===----------------------------------------------------------------------===//
// print and dump methods
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 92c3981..453797f 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -48,5 +48,5 @@
//===----------------------------------------------------------------------===//
MLFunction::MLFunction(StringRef name, FunctionType *type)
- : Function(name, type, Kind::MLFunc) {
+ : Function(name, type, Kind::MLFunc), StmtBlock() {
}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
new file mode 100644
index 0000000..336cf7a
--- /dev/null
+++ b/lib/IR/Statement.cpp
@@ -0,0 +1,119 @@
+//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Statements.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Statement
+//===------------------------------------------------------------------===//
+
+// Statements are deleted through the destroy() member because we don't have
+// a virtual destructor.
+Statement::~Statement() {
+ assert(block == nullptr && "statement destroyed but still in a block");
+}
+
+/// Destroy this statement or one of its subclasses.
+void Statement::destroy(Statement *stmt) {
+ switch (stmt->getKind()) {
+ case Kind::Operation:
+ delete cast<OperationStmt>(stmt);
+ break;
+ case Kind::For:
+ delete cast<ForStmt>(stmt);
+ break;
+ case Kind::If:
+ delete cast<IfStmt>(stmt);
+ break;
+ }
+}
+
+MLFunction *Statement::getFunction() const {
+ return this->getBlock()->getFunction();
+}
+
+//===----------------------------------------------------------------------===//
+// ilist_traits for Statement
+//===----------------------------------------------------------------------===//
+
+StmtBlock *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
+ size_t Offset(
+ size_t(&((StmtBlock *)nullptr->*StmtBlock::getSublistAccess(nullptr))));
+ iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
+ return reinterpret_cast<StmtBlock *>(reinterpret_cast<char *>(Anchor) -
+ Offset);
+}
+
+/// This is a trait method invoked when a statement is added to a block. We
+/// keep the block pointer up to date.
+void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
+ assert(!stmt->getBlock() && "already in a statement block!");
+ stmt->block = getContainingBlock();
+}
+
+/// This is a trait method invoked when a statement is removed from a block.
+/// We keep the block pointer up to date.
+void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
+ Statement *stmt) {
+ assert(stmt->block && "not already in a statement block!");
+ stmt->block = nullptr;
+}
+
+/// This is a trait method invoked when a statement is moved from one block
+/// to another. We keep the block pointer up to date.
+void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
+ ilist_traits<Statement> &otherList, stmt_iterator first,
+ stmt_iterator last) {
+ // If we are transferring statements within the same block, the block
+ // pointer doesn't need to be updated.
+ StmtBlock *curParent = getContainingBlock();
+ if (curParent == otherList.getContainingBlock())
+ return;
+
+ // Update the 'block' member of each statement.
+ for (; first != last; ++first)
+ first->block = curParent;
+}
+
+/// Remove this statement from its StmtBlock and delete it.
+void Statement::eraseFromBlock() {
+ assert(getBlock() && "Statement has no block");
+ getBlock()->getStatements().erase(this);
+}
+
+//===----------------------------------------------------------------------===//
+// IfClause
+//===----------------------------------------------------------------------===//
+
+IfClause::IfClause(IfStmt *stmt) : StmtBlock(stmt) {
+ assert(stmt != nullptr && "If clause must have non-null parent");
+}
+
+IfStmt *IfClause::getIf() const { return static_cast<IfStmt *>(parent); }
+
+//===----------------------------------------------------------------------===//
+// IfStmt
+//===----------------------------------------------------------------------===//
+
+IfStmt::~IfStmt() {
+ // TODO: correctly delete StmtBlocks under then and else clauses
+ delete thenClause;
+ if (elseClause != nullptr)
+ delete elseClause;
+}
diff --git a/lib/IR/Statements.cpp b/lib/IR/Statements.cpp
deleted file mode 100644
index ab3f8fc..0000000
--- a/lib/IR/Statements.cpp
+++ /dev/null
@@ -1,40 +0,0 @@
-//===- Statements.cpp - MLIR Statement Instruction Classes ------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "mlir/IR/MLFunction.h"
-#include "mlir/IR/Statements.h"
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// Statement
-//===----------------------------------------------------------------------===//
-
-MLFunction *Statement::getFunction() const {
- ParentType p = parent;
- while (!p.is<MLFunction *>())
- p = p.get<NodeStmt *>()->getParent();
- return p.get<MLFunction *>();
-}
-
-//===----------------------------------------------------------------------===//
-// ElseClause
-//===----------------------------------------------------------------------===//
-
-ElseClause::ElseClause(IfStmt *ifStmt, int clauseNum)
- : NodeStmt(Kind::Else, ifStmt), clauseNum(clauseNum) {
- ifStmt->elseClauses.push_back(this);
-}
diff --git a/lib/IR/StmtBlock.cpp b/lib/IR/StmtBlock.cpp
new file mode 100644
index 0000000..83e0412
--- /dev/null
+++ b/lib/IR/StmtBlock.cpp
@@ -0,0 +1,32 @@
+//===- StmtBlock.cpp - MLIR Statement Instruction Classes -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/StmtBlock.h"
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Statement block
+//===----------------------------------------------------------------------===//
+
+MLFunction *StmtBlock::getFunction() const {
+ StmtBlock *block = const_cast<StmtBlock *>(this);
+
+ while (block->getParent() != nullptr)
+ block = block->getParent()->getBlock();
+ return static_cast<MLFunction *>(block);
+}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 5838f73..e1d184d 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -28,6 +28,7 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Statements.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
@@ -1381,10 +1382,13 @@
: Parser(state), function(function) {}
ParseResult parseFunctionBody();
- Statement *parseStatement(ParentType parent);
- ForStmt *parseForStmt(ParentType parent);
- IfStmt *parseIfStmt(ParentType parent);
- ParseResult parseNestedStatements(NodeStmt *parent);
+
+private:
+ Statement *parseStatement();
+ ForStmt *parseForStmt();
+ IfStmt *parseIfStmt();
+ ParseResult parseElseClause(IfClause *elseClause);
+ ParseResult parseStmtBlock(StmtBlock *block);
};
} // end anonymous namespace
@@ -1398,10 +1402,10 @@
// Parse the list of instructions.
while (!consumeIf(Token::kw_return)) {
- auto *stmt = parseStatement(function);
+ auto *stmt = parseStatement();
if (!stmt)
return ParseFailure;
- function->stmtList.push_back(stmt);
+ function->push_back(stmt);
}
// TODO: parse return statement operands
@@ -1420,17 +1424,17 @@
/// TODO: fix terminology in MLSpec document. ML functions
/// contain operation statements, not instructions.
///
-Statement *MLFunctionParser::parseStatement(ParentType parent) {
+Statement *MLFunctionParser::parseStatement() {
switch (getToken().getKind()) {
default:
//TODO: parse OperationStmt
return (emitError("expected statement"), nullptr);
case Token::kw_for:
- return parseForStmt(parent);
+ return parseForStmt();
case Token::kw_if:
- return parseIfStmt(parent);
+ return parseIfStmt();
}
}
@@ -1439,12 +1443,12 @@
/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
/// (`step` integer-literal)? `{` ml-stmt* `}`
///
-ForStmt *MLFunctionParser::parseForStmt(ParentType parent) {
+ForStmt *MLFunctionParser::parseForStmt() {
consumeToken(Token::kw_for);
//TODO: parse loop header
- ForStmt *stmt = new ForStmt(parent);
- if (parseNestedStatements(stmt)) {
+ ForStmt *stmt = new ForStmt();
+ if (parseStmtBlock(static_cast<StmtBlock *>(stmt))) {
delete stmt;
return nullptr;
}
@@ -1458,50 +1462,61 @@
/// ml-if-stmt ::= ml-if-head
/// | ml-if-head `else` `{` ml-stmt* `}`
///
-IfStmt *
-MLFunctionParser::parseIfStmt(PointerUnion<MLFunction *, NodeStmt *> parent) {
+IfStmt *MLFunctionParser::parseIfStmt() {
consumeToken(Token::kw_if);
+ if (!consumeIf(Token::l_paren))
+ return (emitError("expected ("), nullptr);
//TODO: parse condition
- IfStmt *stmt = new IfStmt(parent);
- if (parseNestedStatements(stmt)) {
- delete stmt;
+
+ if (!consumeIf(Token::r_paren))
+ return (emitError("expected )"), nullptr);
+
+ IfStmt *ifStmt = new IfStmt();
+ IfClause *thenClause = ifStmt->getThenClause();
+ if (parseStmtBlock(thenClause)) {
+ delete ifStmt;
return nullptr;
}
- int clauseNum = 0;
- while (consumeIf(Token::kw_else)) {
- if (consumeIf(Token::kw_if)) {
- //TODO: parse condition
- }
- ElseClause * clause = new ElseClause(stmt, clauseNum);
- ++clauseNum;
- if (parseNestedStatements(clause)) {
- delete clause;
+ if (consumeIf(Token::kw_else)) {
+ IfClause *elseClause = ifStmt->createElseClause();
+ if (parseElseClause(elseClause)) {
+ delete ifStmt;
return nullptr;
}
}
- return stmt;
+ return ifStmt;
+}
+
+ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) {
+ if (getToken().is(Token::kw_if)) {
+ IfStmt *nextIf = parseIfStmt();
+ if (!nextIf)
+ return ParseFailure;
+ elseClause->push_back(nextIf);
+ return ParseSuccess;
+ }
+
+ if (parseStmtBlock(elseClause))
+ return ParseFailure;
+
+ return ParseSuccess;
}
///
/// Parse `{` ml-stmt* `}`
///
-ParseResult MLFunctionParser::parseNestedStatements(NodeStmt *parent) {
+ParseResult MLFunctionParser::parseStmtBlock(StmtBlock *block) {
if (!consumeIf(Token::l_brace))
return emitError("expected '{' before statement list");
- if (consumeIf(Token::r_brace)) {
- // TODO: parse OperationStmt
- return ParseSuccess;
- }
-
while (!consumeIf(Token::r_brace)) {
- auto *stmt = parseStatement(parent);
+ auto *stmt = parseStatement();
if (!stmt)
return ParseFailure;
- parent->children.push_back(stmt);
+ block->push_back(stmt);
}
return ParseSuccess;
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index a01d7f2..71d4b56 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -70,14 +70,14 @@
; CHECK-LABEL: mlfunc @ifstmt() {
mlfunc @ifstmt() {
- for { ; CHECK for {
- if { ; CHECK if {
- } else if { ; CHECK } else if {
- } else { ; CHECK } else {
- } ; CHECK }
- } ; CHECK }
- return ; CHECK return
-} ; CHECK }
+ for { ; CHECK for {
+ if () { ; CHECK if () {
+ } else if () { ; CHECK } else if () {
+ } else { ; CHECK } else {
+ } ; CHECK }
+ } ; CHECK }
+ return ; CHECK return
+} ; CHECK }
; CHECK-LABEL: cfgfunc @attributes() {
cfgfunc @attributes() {