Refactor the parser a bit to split the pieces that need their own local
state into specialized parser subclasses. This is important because a
monolithic parser grows very large very quickly, and ours is already
getting big.
Doing this requires moving mutable parser state out of Parser, either into
its own ParserState class or into transient subclasses like
CFGFunctionParser. This works better than having classes like
CFGFunctionParserState that get passed around everywhere, because the
parser methods can live on the new classes themselves.
This patch only converts CFG functions and ML functions; I'll follow up with
affine maps (unless someone else wants to take that on).
PiperOrigin-RevId: 203871695
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 62f4b99..42d7935 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -35,8 +35,8 @@
using llvm::SMLoc;
namespace {
-class CFGFunctionParserState;
class AffineMapParserState;
+} // end anonymous namespace
/// Simple enum to make code read better in cases that would otherwise return a
/// bool value. Failure is "true" in a boolean context.
@@ -65,20 +65,29 @@
Mod
};
-/// Main parser implementation.
-class Parser {
+namespace {
+class Parser;
+
+/// This class holds all of the state maintained globally by the parser,
+/// such as the current lexer position. The Parser base class provides
+/// methods to access it.
+class ParserState {
public:
- Parser(llvm::SourceMgr &sourceMgr, MLIRContext *context,
- SMDiagnosticHandlerTy errorReporter)
- : builder(context), lex(sourceMgr, errorReporter),
+ ParserState(llvm::SourceMgr &sourceMgr, MLIRContext *context,
+ SMDiagnosticHandlerTy errorReporter)
+ : context(context), lex(sourceMgr, errorReporter),
curToken(lex.lexToken()), errorReporter(std::move(errorReporter)) {
module.reset(new Module(context));
}
- Module *parseModule();
private:
- // State.
- Builder builder;
+ ParserState(const ParserState &) = delete;
+ void operator=(const ParserState &) = delete;
+
+ friend class Parser;
+
+ // The context we're parsing into.
+ MLIRContext *context;
// The lexer for the source file we're parsing.
Lexer lex;
@@ -94,35 +103,52 @@
// A map from affine map identifier to AffineMap.
llvm::StringMap<AffineMap*> affineMapDefinitions;
+};
+} // end anonymous namespace
-private:
+namespace {
+
+/// This class implements support for parsing global entities like types and
+/// shared entities like SSA names. It is intended to be subclassed by
+/// specialized subparsers that include state, e.g. a local symbol table.
+class Parser {
+public:
+ Parser(ParserState &state) : state(state), builder(state.context) {}
+ Module *parseModule();
+
// Helper methods.
+ MLIRContext *getContext() const { return state.context; }
+ Module *getModule() { return state.module.get(); }
+
+ /// Return the current token the parser is inspecting.
+ const Token &getToken() const { return state.curToken; }
+ StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }
/// Emit an error and return failure.
ParseResult emitError(const Twine &message) {
- return emitError(curToken.getLoc(), message);
+ return emitError(state.curToken.getLoc(), message);
}
ParseResult emitError(SMLoc loc, const Twine &message);
/// Advance the current lexer onto the next token.
void consumeToken() {
- assert(curToken.isNot(Token::eof, Token::error) &&
+ assert(state.curToken.isNot(Token::eof, Token::error) &&
"shouldn't advance past EOF or errors");
- curToken = lex.lexToken();
+ state.curToken = state.lex.lexToken();
}
/// Advance the current lexer onto the next token, asserting what the expected
/// current token is. This is preferred to the above method because it leads
/// to more self-documenting code with better checking.
void consumeToken(Token::Kind kind) {
- assert(curToken.is(kind) && "consumed an unexpected token");
+ assert(state.curToken.is(kind) && "consumed an unexpected token");
consumeToken();
}
/// If the current token has the specified kind, consume it and return true.
/// If not, return false.
bool consumeIf(Token::Kind kind) {
- if (curToken.isNot(kind))
+ if (state.curToken.isNot(kind))
return false;
consumeToken(kind);
return true;
@@ -193,16 +219,13 @@
ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type);
ParseResult parseExtFunc();
ParseResult parseCFGFunc();
- ParseResult parseBasicBlock(CFGFunctionParserState &functionState);
- Statement *parseStatement(ParentType parent);
-
- OperationInst *parseCFGOperation(CFGFunctionParserState &functionState);
- TerminatorInst *parseTerminator(CFGFunctionParserState &functionState);
-
ParseResult parseMLFunc();
- ForStmt *parseForStmt(ParentType parent);
- IfStmt *parseIfStmt(ParentType parent);
- ParseResult parseNestedStatements(NodeStmt *parent);
+
+private:
+ // The Parser is subclassed and reinstantiated. Do not add additional
+ // non-trivial state here; add it to the ParserState class instead.
+ ParserState &state;
+ Builder builder;
};
} // end anonymous namespace
@@ -213,11 +236,11 @@
ParseResult Parser::emitError(SMLoc loc, const Twine &message) {
// If we hit a parse error in response to a lexer error, then the lexer
// already reported the error.
- if (curToken.is(Token::error))
+ if (getToken().is(Token::error))
return ParseFailure;
- errorReporter(
- lex.getSourceMgr().GetMessage(loc, SourceMgr::DK_Error, message));
+ auto &sourceMgr = state.lex.getSourceMgr();
+ state.errorReporter(sourceMgr.GetMessage(loc, SourceMgr::DK_Error, message));
return ParseFailure;
}
@@ -232,7 +255,7 @@
const std::function<ParseResult()> &parseElement,
bool allowEmptyList) {
// Handle the empty case.
- if (curToken.is(rightToken)) {
+ if (getToken().is(rightToken)) {
if (!allowEmptyList)
return emitError("expected list element");
consumeToken(rightToken);
@@ -268,7 +291,7 @@
/// primitive-type ::= `affineint`
///
Type *Parser::parsePrimitiveType() {
- switch (curToken.getKind()) {
+ switch (getToken().getKind()) {
default:
return (emitError("expected type"), nullptr);
case Token::kw_bf16:
@@ -287,7 +310,7 @@
consumeToken(Token::kw_affineint);
return builder.getAffineIntType();
case Token::inttype: {
- auto width = curToken.getIntTypeBitwidth();
+ auto width = getToken().getIntTypeBitwidth();
if (!width.hasValue())
return (emitError("invalid integer width"), nullptr);
consumeToken(Token::inttype);
@@ -301,7 +324,7 @@
/// element-type ::= primitive-type | vector-type
///
Type *Parser::parseElementType() {
- if (curToken.is(Token::kw_vector))
+ if (getToken().is(Token::kw_vector))
return parseVectorType();
return parsePrimitiveType();
@@ -318,13 +341,13 @@
if (!consumeIf(Token::less))
return (emitError("expected '<' in vector type"), nullptr);
- if (curToken.isNot(Token::integer))
+ if (getToken().isNot(Token::integer))
return (emitError("expected dimension size in vector type"), nullptr);
SmallVector<unsigned, 4> dimensions;
- while (curToken.is(Token::integer)) {
+ while (getToken().is(Token::integer)) {
// Make sure this integer value is in bound and valid.
- auto dimension = curToken.getUnsignedIntegerValue();
+ auto dimension = getToken().getUnsignedIntegerValue();
if (!dimension.hasValue())
return (emitError("invalid dimension in vector type"), nullptr);
dimensions.push_back(dimension.getValue());
@@ -332,13 +355,13 @@
consumeToken(Token::integer);
// Make sure we have an 'x' or something like 'xbf32'.
- if (curToken.isNot(Token::bare_identifier) ||
- curToken.getSpelling()[0] != 'x')
+ if (getToken().isNot(Token::bare_identifier) ||
+ getTokenSpelling()[0] != 'x')
return (emitError("expected 'x' in vector dimension list"), nullptr);
// If we had a prefix of 'x', lex the next token immediately after the 'x'.
- if (curToken.getSpelling().size() != 1)
- lex.resetPointer(curToken.getSpelling().data()+1);
+ if (getTokenSpelling().size() != 1)
+ state.lex.resetPointer(getTokenSpelling().data() + 1);
// Consume the 'x'.
consumeToken(Token::bare_identifier);
@@ -362,12 +385,12 @@
/// dimension ::= `?` | integer-literal
///
ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
- while (curToken.isAny(Token::integer, Token::question)) {
+ while (getToken().isAny(Token::integer, Token::question)) {
if (consumeIf(Token::question)) {
dimensions.push_back(-1);
} else {
// Make sure this integer value is in bound and valid.
- auto dimension = curToken.getUnsignedIntegerValue();
+ auto dimension = getToken().getUnsignedIntegerValue();
if (!dimension.hasValue() || (int)dimension.getValue() < 0)
return emitError("invalid dimension");
dimensions.push_back((int)dimension.getValue());
@@ -375,13 +398,13 @@
}
// Make sure we have an 'x' or something like 'xbf32'.
- if (curToken.isNot(Token::bare_identifier) ||
- curToken.getSpelling()[0] != 'x')
+ if (getToken().isNot(Token::bare_identifier) ||
+ getTokenSpelling()[0] != 'x')
return emitError("expected 'x' in dimension list");
// If we had a prefix of 'x', lex the next token immediately after the 'x'.
- if (curToken.getSpelling().size() != 1)
- lex.resetPointer(curToken.getSpelling().data()+1);
+ if (getTokenSpelling().size() != 1)
+ state.lex.resetPointer(getTokenSpelling().data() + 1);
// Consume the 'x'.
consumeToken(Token::bare_identifier);
@@ -463,7 +486,7 @@
/// function-type ::= type-list-parens `->` type-list
///
Type *Parser::parseFunctionType() {
- assert(curToken.is(Token::l_paren));
+ assert(getToken().is(Token::l_paren));
SmallVector<Type*, 4> arguments;
if (parseTypeList(arguments))
@@ -489,7 +512,7 @@
/// element-type ::= primitive-type | vector-type
///
Type *Parser::parseType() {
- switch (curToken.getKind()) {
+ switch (getToken().getKind()) {
case Token::kw_memref: return parseMemRefType();
case Token::kw_tensor: return parseTensorType();
case Token::kw_vector: return parseVectorType();
@@ -560,7 +583,7 @@
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
///
Attribute *Parser::parseAttribute() {
- switch (curToken.getKind()) {
+ switch (getToken().getKind()) {
case Token::kw_true:
consumeToken(Token::kw_true);
return BoolAttr::get(true, builder.getContext());
@@ -569,7 +592,7 @@
return BoolAttr::get(false, builder.getContext());
case Token::integer: {
- auto val = curToken.getUInt64IntegerValue();
+ auto val = getToken().getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)val.getValue() < 0)
return (emitError("integer too large for attribute"), nullptr);
consumeToken(Token::integer);
@@ -578,8 +601,8 @@
case Token::minus: {
consumeToken(Token::minus);
- if (curToken.is(Token::integer)) {
- auto val = curToken.getUInt64IntegerValue();
+ if (getToken().is(Token::integer)) {
+ auto val = getToken().getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
return (emitError("integer too large for attribute"), nullptr);
consumeToken(Token::integer);
@@ -591,7 +614,7 @@
}
case Token::string: {
- auto val = curToken.getStringValue();
+ auto val = getToken().getStringValue();
consumeToken(Token::string);
return StringAttr::get(val, builder.getContext());
}
@@ -627,10 +650,10 @@
auto parseElt = [&]() -> ParseResult {
// We allow keywords as attribute names.
- if (curToken.isNot(Token::bare_identifier, Token::inttype) &&
- !curToken.isKeyword())
+ if (getToken().isNot(Token::bare_identifier, Token::inttype) &&
+ !getToken().isKeyword())
return emitError("expected attribute name");
- auto nameId = Identifier::get(curToken.getSpelling(), builder.getContext());
+ auto nameId = Identifier::get(getTokenSpelling(), builder.getContext());
consumeToken();
if (!consumeIf(Token::colon))
@@ -658,12 +681,12 @@
/// affine-map-def ::= affine-map-id `=` affine-map-inline
///
ParseResult Parser::parseAffineMapDef() {
- assert(curToken.is(Token::hash_identifier));
+ assert(getToken().is(Token::hash_identifier));
- StringRef affineMapId = curToken.getSpelling().drop_front();
+ StringRef affineMapId = getTokenSpelling().drop_front();
// Check for redefinitions.
- auto *&entry = affineMapDefinitions[affineMapId];
+ auto *&entry = state.affineMapDefinitions[affineMapId];
if (entry)
return emitError("redefinition of affine map id '" + affineMapId + "'");
@@ -677,7 +700,7 @@
if (!entry)
return ParseFailure;
- module->affineMapList.push_back(entry);
+ getModule()->affineMapList.push_back(entry);
return ParseSuccess;
}
@@ -736,7 +759,7 @@
/// Consume this token if it is a lower precedence affine op (there are only two
/// precedence levels).
AffineLowPrecOp Parser::consumeIfLowPrecOp() {
- switch (curToken.getKind()) {
+ switch (getToken().getKind()) {
case Token::plus:
consumeToken(Token::plus);
return AffineLowPrecOp::Add;
@@ -751,7 +774,7 @@
/// Consume this token if it is a higher precedence affine op (there are only
/// two precedence levels)
AffineHighPrecOp Parser::consumeIfHighPrecOp() {
- switch (curToken.getKind()) {
+ switch (getToken().getKind()) {
case Token::star:
consumeToken(Token::star);
return Mul;
@@ -811,7 +834,7 @@
AffineExpr *Parser::parseParentheticalExpr(const AffineMapParserState &state) {
if (!consumeIf(Token::l_paren))
return (emitError("expected '('"), nullptr);
- if (curToken.is(Token::r_paren))
+ if (getToken().is(Token::r_paren))
return (emitError("no expression inside parentheses"), nullptr);
auto *expr = parseAffineExpr(state);
if (!expr)
@@ -847,10 +870,10 @@
///
/// affine-expr ::= bare-id
AffineExpr *Parser::parseBareIdExpr(const AffineMapParserState &state) {
- if (curToken.isNot(Token::bare_identifier))
+ if (getToken().isNot(Token::bare_identifier))
return (emitError("expected bare identifier"), nullptr);
- StringRef sRef = curToken.getSpelling();
+ StringRef sRef = getTokenSpelling();
const auto &dims = state.getDims();
const auto &symbols = state.getSymbols();
if (dims.count(sRef)) {
@@ -871,10 +894,10 @@
// No need to handle negative numbers separately here. They are naturally
// handled via the unary negation operator, although (FIXME) MININT_64 still
// not correctly handled.
- if (curToken.isNot(Token::integer))
+ if (getToken().isNot(Token::integer))
return (emitError("expected integer"), nullptr);
- auto val = curToken.getUInt64IntegerValue();
+ auto val = getToken().getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)val.getValue() < 0) {
return (emitError("constant too large for affineint"), nullptr);
}
@@ -893,7 +916,7 @@
// are valid operands that will be parsed by this function.
AffineExpr *Parser::parseAffineOperandExpr(AffineExpr *lhs,
const AffineMapParserState &state) {
- switch (curToken.getKind()) {
+ switch (getToken().getKind()) {
case Token::bare_identifier:
return parseBareIdExpr(state);
case Token::integer:
@@ -1009,9 +1032,9 @@
/// identifier. 'dim': whether it's the dim list or symbol list that is being
/// parsed.
ParseResult Parser::parseDimOrSymbolId(AffineMapParserState &state, bool dim) {
- if (curToken.isNot(Token::bare_identifier))
+ if (getToken().isNot(Token::bare_identifier))
return emitError("expected bare identifier");
- auto sRef = curToken.getSpelling();
+ auto sRef = getTokenSpelling();
consumeToken(Token::bare_identifier);
if (state.getDims().count(sRef) == 1)
return emitError("dimensional identifier name reused");
@@ -1061,7 +1084,7 @@
return nullptr;
// Symbols are optional.
- if (curToken.is(Token::l_bracket)) {
+ if (getToken().is(Token::l_bracket)) {
if (parseSymbolIdList(state))
return nullptr;
}
@@ -1101,8 +1124,8 @@
/// ssa-use ::= ssa-id | ssa-constant
///
ParseResult Parser::parseSSAUse() {
- if (curToken.is(Token::percent_identifier)) {
- StringRef name = curToken.getSpelling().drop_front();
+ if (getToken().is(Token::percent_identifier)) {
+ StringRef name = getTokenSpelling().drop_front();
consumeToken(Token::percent_identifier);
// TODO: Return this use.
(void)name;
@@ -1163,13 +1186,13 @@
///
ParseResult Parser::parseFunctionSignature(StringRef &name,
FunctionType *&type) {
- if (curToken.isNot(Token::at_identifier))
+ if (getToken().isNot(Token::at_identifier))
return emitError("expected a function identifier like '@foo'");
- name = curToken.getSpelling().drop_front();
+ name = getTokenSpelling().drop_front();
consumeToken(Token::at_identifier);
- if (curToken.isNot(Token::l_paren))
+ if (getToken().isNot(Token::l_paren))
return emitError("expected '(' in function signature");
SmallVector<Type*, 4> arguments;
@@ -1199,23 +1222,28 @@
return ParseFailure;
// Okay, the external function definition was parsed correctly.
- module->functionList.push_back(new ExtFunction(name, type));
+ getModule()->functionList.push_back(new ExtFunction(name, type));
return ParseSuccess;
}
+//===----------------------------------------------------------------------===//
+// CFG Functions
+//===----------------------------------------------------------------------===//
namespace {
-/// This class represents the transient parser state for the internals of a
-/// function as we are parsing it, e.g. the names for basic blocks. It handles
-/// forward references.
-class CFGFunctionParserState {
+/// This is a specialized parser for CFGFunctions, maintaining the state that
+/// is transient to their bodies (e.g. the basic blocks referenced by name).
+class CFGFunctionParser : public Parser {
public:
CFGFunction *function;
llvm::StringMap<std::pair<BasicBlock*, SMLoc>> blocksByName;
+
+ /// This builder intentionally shadows the builder in the base class, with a
+ /// more specific builder type.
CFGFuncBuilder builder;
- CFGFunctionParserState(CFGFunction *function)
- : function(function), builder(function) {}
+ CFGFunctionParser(ParserState &state, CFGFunction *function)
+ : Parser(state), function(function), builder(function) {}
/// Get the basic block with the specified name, creating it if it doesn't
/// already exist. The location specified is the point of use, which allows
@@ -1228,6 +1256,11 @@
}
return blockAndLoc.first;
}
+
+ ParseResult parseFunctionBody();
+ ParseResult parseBasicBlock();
+ OperationInst *parseCFGOperation();
+ TerminatorInst *parseTerminator();
};
} // end anonymous namespace
@@ -1244,26 +1277,29 @@
if (parseFunctionSignature(name, type))
return ParseFailure;
- if (!consumeIf(Token::l_brace))
- return emitError("expected '{' in CFG function");
-
// Okay, the CFG function signature was parsed correctly, create the function.
auto function = new CFGFunction(name, type);
- // Make sure we have at least one block.
- if (curToken.is(Token::r_brace))
- return emitError("CFG functions must have at least one basic block");
+ CFGFunctionParser cfgFuncParser(state, function);
+ return cfgFuncParser.parseFunctionBody();
+}
- CFGFunctionParserState functionState(function);
+ParseResult CFGFunctionParser::parseFunctionBody() {
+ if (!consumeIf(Token::l_brace))
+ return emitError("expected '{' in CFG function");
+
+ // Make sure we have at least one block.
+ if (getToken().is(Token::r_brace))
+ return emitError("CFG functions must have at least one basic block");
// Parse the list of blocks.
while (!consumeIf(Token::r_brace))
- if (parseBasicBlock(functionState))
+ if (parseBasicBlock())
return ParseFailure;
// Verify that all referenced blocks were defined. Iteration over a
// StringMap isn't determinstic, but this is good enough for our purposes.
- for (auto &elt : functionState.blocksByName) {
+ for (auto &elt : blocksByName) {
auto *bb = elt.second.first;
if (!bb->getFunction())
return emitError(elt.second.second,
@@ -1271,7 +1307,7 @@
elt.first() + "'");
}
- module->functionList.push_back(function);
+ getModule()->functionList.push_back(function);
return ParseSuccess;
}
@@ -1282,13 +1318,13 @@
/// bb-id ::= bare-id
/// bb-arg-list ::= `(` ssa-id-and-type-list? `)`
///
-ParseResult Parser::parseBasicBlock(CFGFunctionParserState &functionState) {
- SMLoc nameLoc = curToken.getLoc();
- auto name = curToken.getSpelling();
+ParseResult CFGFunctionParser::parseBasicBlock() {
+ SMLoc nameLoc = getToken().getLoc();
+ auto name = getTokenSpelling();
if (!consumeIf(Token::bare_identifier))
return emitError("expected basic block name");
- auto block = functionState.getBlockNamed(name, nameLoc);
+ auto *block = getBlockNamed(name, nameLoc);
// If this block has already been parsed, then this is a redefinition with the
// same block name.
@@ -1296,7 +1332,7 @@
return emitError(nameLoc, "redefinition of block '" + name.str() + "'");
// Add the block to the function.
- functionState.function->push_back(block);
+ function->push_back(block);
// If an argument list is present, parse it.
if (consumeIf(Token::l_paren)) {
@@ -1310,12 +1346,12 @@
return emitError("expected ':' after basic block name");
// Set the insertion point to the block we want to insert new operations into.
- functionState.builder.setInsertionPoint(block);
+ builder.setInsertionPoint(block);
// Parse the list of operations that make up the body of the block.
- while (curToken.isNot(Token::kw_return, Token::kw_br)) {
- auto loc = curToken.getLoc();
- auto *inst = parseCFGOperation(functionState);
+ while (getToken().isNot(Token::kw_return, Token::kw_br)) {
+ auto loc = getToken().getLoc();
+ auto *inst = parseCFGOperation();
if (!inst)
return ParseFailure;
@@ -1327,14 +1363,13 @@
return emitError(loc, error);
}
- auto *term = parseTerminator(functionState);
+ auto *term = parseTerminator();
if (!term)
return ParseFailure;
return ParseSuccess;
}
-
/// Parse the CFG operation.
///
/// TODO(clattner): This is a change from the MLIR spec as written, it is an
@@ -1344,21 +1379,19 @@
/// (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict?
/// `:` function-type
///
-OperationInst *Parser::
-parseCFGOperation(CFGFunctionParserState &functionState) {
-
+OperationInst *CFGFunctionParser::parseCFGOperation() {
StringRef resultID;
- if (curToken.is(Token::percent_identifier)) {
- resultID = curToken.getSpelling().drop_front();
+ if (getToken().is(Token::percent_identifier)) {
+ resultID = getTokenSpelling().drop_front();
consumeToken();
if (!consumeIf(Token::equal))
return (emitError("expected '=' after SSA name"), nullptr);
}
- if (curToken.isNot(Token::string))
+ if (getToken().isNot(Token::string))
return (emitError("expected operation name in quotes"), nullptr);
- auto name = curToken.getStringValue();
+ auto name = getToken().getStringValue();
if (name.empty())
return (emitError("empty operation name is invalid"), nullptr);
@@ -1371,17 +1404,16 @@
parseOptionalSSAUseList(Token::r_paren);
SmallVector<NamedAttribute, 4> attributes;
- if (curToken.is(Token::l_brace)) {
+ if (getToken().is(Token::l_brace)) {
if (parseAttributeDict(attributes))
return nullptr;
}
// TODO: Don't drop result name and operand names on the floor.
auto nameId = Identifier::get(name, builder.getContext());
- return functionState.builder.createOperation(nameId, attributes);
+ return builder.createOperation(nameId, attributes);
}
-
/// Parse the terminator instruction for a basic block.
///
/// terminator-stmt ::= `br` bb-id branch-use-list?
@@ -1390,27 +1422,51 @@
/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
/// terminator-stmt ::= `return` ssa-use-and-type-list?
///
-TerminatorInst *Parser::parseTerminator(CFGFunctionParserState &functionState) {
- switch (curToken.getKind()) {
+TerminatorInst *CFGFunctionParser::parseTerminator() {
+ switch (getToken().getKind()) {
default:
return (emitError("expected terminator at end of basic block"), nullptr);
case Token::kw_return:
consumeToken(Token::kw_return);
- return functionState.builder.createReturnInst();
+ return builder.createReturnInst();
case Token::kw_br: {
consumeToken(Token::kw_br);
- auto destBB = functionState.getBlockNamed(curToken.getSpelling(),
- curToken.getLoc());
+ auto destBB = getBlockNamed(getTokenSpelling(), getToken().getLoc());
if (!consumeIf(Token::bare_identifier))
return (emitError("expected basic block name"), nullptr);
- return functionState.builder.createBranchInst(destBB);
+ return builder.createBranchInst(destBB);
}
// TODO: cond_br.
}
}
+//===----------------------------------------------------------------------===//
+// ML Functions
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Refined parser for MLFunction bodies.
+class MLFunctionParser : public Parser {
+public:
+ MLFunction *function;
+
+ /// TODO: Add an MLFuncBuilder member named 'builder' that intentionally
+ /// shadows the builder in the base class with a more specific builder type:
+ // MLFuncBuilder builder;
+
+ MLFunctionParser(ParserState &state, MLFunction *function)
+ : Parser(state), function(function) {}
+
+ ParseResult parseFunctionBody();
+ Statement *parseStatement(ParentType parent);
+ ForStmt *parseForStmt(ParentType parent);
+ IfStmt *parseIfStmt(ParentType parent);
+ ParseResult parseNestedStatements(NodeStmt *parent);
+};
+} // end anonymous namespace
+
/// ML function declarations.
///
/// ml-func ::= `mlfunc` ml-func-signature `{` ml-stmt* ml-return-stmt `}`
@@ -1426,14 +1482,19 @@
if (parseFunctionSignature(name, type))
return ParseFailure;
- if (!consumeIf(Token::l_brace))
- return emitError("expected '{' in ML function");
-
// Okay, the ML function signature was parsed correctly, create the function.
auto function = new MLFunction(name, type);
+ MLFunctionParser mlFuncParser(state, function);
+ return mlFuncParser.parseFunctionBody();
+}
+
+ParseResult MLFunctionParser::parseFunctionBody() {
+ if (!consumeIf(Token::l_brace))
+ return emitError("expected '{' in ML function");
+
// Make sure we have at least one statement.
- if (curToken.is(Token::r_brace))
+ if (getToken().is(Token::r_brace))
return emitError("ML function must end with return statement");
// Parse the list of instructions.
@@ -1448,19 +1509,20 @@
if (!consumeIf(Token::r_brace))
emitError("expected '}' in ML function");
- module->functionList.push_back(function);
+ getModule()->functionList.push_back(function);
return ParseSuccess;
}
/// Statement.
///
-/// ml-stmt ::= instruction | ml-for-stmt | ml-if-stmt
+/// ml-stmt ::= instruction | ml-for-stmt | ml-if-stmt
+///
/// TODO: fix terminology in MLSpec document. ML functions
/// contain operation statements, not instructions.
///
-Statement * Parser::parseStatement(ParentType parent) {
- switch (curToken.getKind()) {
+Statement *MLFunctionParser::parseStatement(ParentType parent) {
+ switch (getToken().getKind()) {
default:
//TODO: parse OperationStmt
return (emitError("expected statement"), nullptr);
@@ -1475,10 +1537,10 @@
/// For statement.
///
-/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
-/// (`step` integer-literal)? `{` ml-stmt* `}`
+/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
+/// (`step` integer-literal)? `{` ml-stmt* `}`
///
-ForStmt * Parser::parseForStmt(ParentType parent) {
+ForStmt *MLFunctionParser::parseForStmt(ParentType parent) {
consumeToken(Token::kw_for);
//TODO: parse loop header
@@ -1492,12 +1554,13 @@
/// If statement.
///
-/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
-/// | ml-if-head `else` `if` ml-if-cond `{` ml-stmt* `}`
-/// ml-if-stmt ::= ml-if-head
-/// | ml-if-head `else` `{` ml-stmt* `}`
+/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
+/// | ml-if-head `else` `if` ml-if-cond `{` ml-stmt* `}`
+/// ml-if-stmt ::= ml-if-head
+/// | ml-if-head `else` `{` ml-stmt* `}`
///
-IfStmt * Parser::parseIfStmt(PointerUnion<MLFunction *, NodeStmt *> parent) {
+IfStmt *
+MLFunctionParser::parseIfStmt(PointerUnion<MLFunction *, NodeStmt *> parent) {
consumeToken(Token::kw_if);
//TODO: parse condition
@@ -1526,7 +1589,7 @@
///
/// Parse `{` ml-stmt* `}`
///
-ParseResult Parser::parseNestedStatements(NodeStmt *parent) {
+ParseResult MLFunctionParser::parseNestedStatements(NodeStmt *parent) {
if (!consumeIf(Token::l_brace))
return emitError("expected '{' before statement list");
@@ -1552,14 +1615,14 @@
/// This is the top-level module parser.
Module *Parser::parseModule() {
while (1) {
- switch (curToken.getKind()) {
+ switch (getToken().getKind()) {
default:
emitError("expected a top level entity");
return nullptr;
// If we got to the end of the file, then we're done.
case Token::eof:
- return module.release();
+ return state.module.release();
// If we got an error token, then the lexer already emitted an error, just
// stop. Someday we could introduce error recovery if there was demand for
@@ -1599,10 +1662,10 @@
/// MLIR module if it was valid. If not, it emits diagnostics and returns null.
Module *mlir::parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context,
SMDiagnosticHandlerTy errorReporter) {
- auto *result =
- Parser(sourceMgr, context,
- errorReporter ? std::move(errorReporter) : defaultErrorReporter)
- .parseModule();
+ ParserState state(sourceMgr, context,
+ errorReporter ? std::move(errorReporter)
+ : defaultErrorReporter);
+ auto *result = Parser(state).parseModule();
// Make sure the parse module has no other structural problems detected by the
// verifier.