Refactor the implementation of the Statement class hierarchy to use statement blocks.
Use an LLVM doubly linked list with parent links to store the statements within a block.
PiperOrigin-RevId: 204515541
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 5838f73..e1d184d 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -28,6 +28,7 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Statements.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
@@ -1381,10 +1382,13 @@
: Parser(state), function(function) {}
ParseResult parseFunctionBody();
- Statement *parseStatement(ParentType parent);
- ForStmt *parseForStmt(ParentType parent);
- IfStmt *parseIfStmt(ParentType parent);
- ParseResult parseNestedStatements(NodeStmt *parent);
+
+private:
+ Statement *parseStatement();
+ ForStmt *parseForStmt();
+ IfStmt *parseIfStmt();
+ ParseResult parseElseClause(IfClause *elseClause);
+ ParseResult parseStmtBlock(StmtBlock *block);
};
} // end anonymous namespace
@@ -1398,10 +1402,10 @@
// Parse the list of instructions.
while (!consumeIf(Token::kw_return)) {
- auto *stmt = parseStatement(function);
+ auto *stmt = parseStatement();
if (!stmt)
return ParseFailure;
- function->stmtList.push_back(stmt);
+ function->push_back(stmt);
}
// TODO: parse return statement operands
@@ -1420,17 +1424,17 @@
/// TODO: fix terminology in MLSpec document. ML functions
/// contain operation statements, not instructions.
///
-Statement *MLFunctionParser::parseStatement(ParentType parent) {
+Statement *MLFunctionParser::parseStatement() {
switch (getToken().getKind()) {
default:
//TODO: parse OperationStmt
return (emitError("expected statement"), nullptr);
case Token::kw_for:
- return parseForStmt(parent);
+ return parseForStmt();
case Token::kw_if:
- return parseIfStmt(parent);
+ return parseIfStmt();
}
}
@@ -1439,12 +1443,12 @@
/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
/// (`step` integer-literal)? `{` ml-stmt* `}`
///
-ForStmt *MLFunctionParser::parseForStmt(ParentType parent) {
+ForStmt *MLFunctionParser::parseForStmt() {
consumeToken(Token::kw_for);
//TODO: parse loop header
- ForStmt *stmt = new ForStmt(parent);
- if (parseNestedStatements(stmt)) {
+ ForStmt *stmt = new ForStmt();
+ if (parseStmtBlock(static_cast<StmtBlock *>(stmt))) {
delete stmt;
return nullptr;
}
@@ -1458,50 +1462,61 @@
/// ml-if-stmt ::= ml-if-head
/// | ml-if-head `else` `{` ml-stmt* `}`
///
-IfStmt *
-MLFunctionParser::parseIfStmt(PointerUnion<MLFunction *, NodeStmt *> parent) {
+IfStmt *MLFunctionParser::parseIfStmt() {
consumeToken(Token::kw_if);
+ if (!consumeIf(Token::l_paren))
+ return (emitError("expected ("), nullptr);
//TODO: parse condition
- IfStmt *stmt = new IfStmt(parent);
- if (parseNestedStatements(stmt)) {
- delete stmt;
+
+ if (!consumeIf(Token::r_paren))
+ return (emitError("expected )"), nullptr);
+
+ IfStmt *ifStmt = new IfStmt();
+ IfClause *thenClause = ifStmt->getThenClause();
+ if (parseStmtBlock(thenClause)) {
+ delete ifStmt;
return nullptr;
}
- int clauseNum = 0;
- while (consumeIf(Token::kw_else)) {
- if (consumeIf(Token::kw_if)) {
- //TODO: parse condition
- }
- ElseClause * clause = new ElseClause(stmt, clauseNum);
- ++clauseNum;
- if (parseNestedStatements(clause)) {
- delete clause;
+ if (consumeIf(Token::kw_else)) {
+ IfClause *elseClause = ifStmt->createElseClause();
+ if (parseElseClause(elseClause)) {
+ delete ifStmt;
return nullptr;
}
}
- return stmt;
+ return ifStmt;
+}
+
+ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) {
+ if (getToken().is(Token::kw_if)) {
+ IfStmt *nextIf = parseIfStmt();
+ if (!nextIf)
+ return ParseFailure;
+ elseClause->push_back(nextIf);
+ return ParseSuccess;
+ }
+
+ if (parseStmtBlock(elseClause))
+ return ParseFailure;
+
+ return ParseSuccess;
}
///
/// Parse `{` ml-stmt* `}`
///
-ParseResult MLFunctionParser::parseNestedStatements(NodeStmt *parent) {
+ParseResult MLFunctionParser::parseStmtBlock(StmtBlock *block) {
if (!consumeIf(Token::l_brace))
return emitError("expected '{' before statement list");
- if (consumeIf(Token::r_brace)) {
- // TODO: parse OperationStmt
- return ParseSuccess;
- }
-
while (!consumeIf(Token::r_brace)) {
- auto *stmt = parseStatement(parent);
+ auto *stmt = parseStatement();
if (!stmt)
return ParseFailure;
- parent->children.push_back(stmt);
+ block->push_back(stmt);
}
return ParseSuccess;