Add the unconditional branch instruction, improve diagnostics for block
references.
PiperOrigin-RevId: 201872745
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index 95e3459..662ebfc 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -34,6 +34,7 @@
class TerminatorInst {
public:
enum class Kind {
+ Branch,
Return
};
@@ -56,11 +57,36 @@
BasicBlock *block;
};
+/// The 'br' instruction is an unconditional from one basic block to another,
+/// and may pass basic block arguments to the successor.
+class BranchInst : public TerminatorInst {
+public:
+ explicit BranchInst(BasicBlock *dest, BasicBlock *parent);
+
+ /// Return the block this branch jumps to.
+ BasicBlock *getDest() const {
+ return dest;
+ }
+
+ // TODO: need to take BB arguments.
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(const TerminatorInst *inst) {
+ return inst->getKind() == Kind::Branch;
+ }
+private:
+ BasicBlock *dest;
+};
+
+
+/// The 'return' instruction represents the end of control flow within the
+/// current function, and can return zero or more results. The result list is
+/// required to align with the result list of the containing function's type.
class ReturnInst : public TerminatorInst {
public:
- explicit ReturnInst(BasicBlock *block);
- // TODO: Flesh this out.
+ explicit ReturnInst(BasicBlock *parent);
+ // TODO: Needs to take an operand list.
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const TerminatorInst *inst) {
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index f857963..08fe838 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -120,6 +120,9 @@
void CFGFunctionState::print(const TerminatorInst *inst) {
switch (inst->getKind()) {
+ case TerminatorInst::Kind::Branch:
+ os << " br bb" << getBBID(cast<BranchInst>(inst)->getDest()) << "\n";
+ break;
case TerminatorInst::Kind::Return:
os << " return\n";
break;
diff --git a/lib/IR/BasicBlock.cpp b/lib/IR/BasicBlock.cpp
index ad8f71e..4cfe162 100644
--- a/lib/IR/BasicBlock.cpp
+++ b/lib/IR/BasicBlock.cpp
@@ -16,7 +16,9 @@
// =============================================================================
#include "mlir/IR/BasicBlock.h"
+#include "mlir/IR/CFGFunction.h"
using namespace mlir;
BasicBlock::BasicBlock(CFGFunction *function) : function(function) {
+ function->blockList.push_back(this);
}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index c123b86..c32e878 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -23,6 +23,10 @@
return getBlock()->getFunction();
}
-ReturnInst::ReturnInst(BasicBlock *block) : TerminatorInst(Kind::Return, block){
+ReturnInst::ReturnInst(BasicBlock *parent)
+ : TerminatorInst(Kind::Return, parent) {
}
+BranchInst::BranchInst(BasicBlock *dest, BasicBlock *parent)
+ : TerminatorInst(Kind::Branch, parent), dest(dest) {
+}
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 5d224c7..b192f71 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -126,18 +126,19 @@
/// Lex a bare identifier or keyword that starts with a letter.
///
-/// bare-id ::= letter (letter|digit)*
+/// bare-id ::= letter (letter|digit|[_])*
///
Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
- // Match the rest of the identifier regex: [0-9a-zA-Z]*
- while (isalpha(*curPtr) || isdigit(*curPtr))
+ // Match the rest of the identifier regex: [0-9a-zA-Z_]*
+ while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_')
++curPtr;
// Check to see if this identifier is a keyword.
StringRef spelling(tokStart, curPtr-tokStart);
Token::TokenKind kind = llvm::StringSwitch<Token::TokenKind>(spelling)
- .Case("bf16", Token::kw_bf16)
+ .Case("bf16", Token::kw_bf16)
+ .Case("br", Token::kw_br)
.Case("cfgfunc", Token::kw_cfgfunc)
.Case("extfunc", Token::kw_extfunc)
.Case("f16", Token::kw_f16)
@@ -168,7 +169,7 @@
if (!isalpha(*curPtr++))
return emitError(curPtr-1, "expected letter in @ identifier");
- while (isalpha(*curPtr) || isdigit(*curPtr))
+ while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_')
++curPtr;
return formToken(Token::at_identifier, tokStart);
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index d5cc707..828a8d5 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -119,6 +119,9 @@
ParseResult parseExtFunc();
ParseResult parseCFGFunc();
ParseResult parseBasicBlock(CFGFunctionParserState &functionState);
+ TerminatorInst *parseTerminator(BasicBlock *currentBB,
+ CFGFunctionParserState &functionState);
+
};
} // end anonymous namespace
@@ -513,23 +516,22 @@
/// forward references.
class CFGFunctionParserState {
public:
+ CFGFunction *function;
+ llvm::StringMap<std::pair<BasicBlock*, SMLoc>> blocksByName;
+
CFGFunctionParserState(CFGFunction *function) : function(function) {}
/// Get the basic block with the specified name, creating it if it doesn't
- /// already exist.
- BasicBlock *getBlockNamed(StringRef name) {
- auto *&block = blocksByName[name];
- if (!block) {
- block = new BasicBlock(function);
- // TODO: Should be automatic when we have the right function
- // representation.
- function->blockList.push_back(block);
+ /// already exist. The location specified is the point of use, which allows
+ /// us to diagnose references to blocks that are not defined precisely.
+ BasicBlock *getBlockNamed(StringRef name, SMLoc loc) {
+ auto &blockAndLoc = blocksByName[name];
+ if (!blockAndLoc.first) {
+ blockAndLoc.first = new BasicBlock(function);
+ blockAndLoc.second = loc;
}
- return block;
+ return blockAndLoc.first;
}
-private:
- CFGFunction *function;
- llvm::StringMap<BasicBlock*> blocksByName;
};
} // end anonymous namespace
@@ -563,6 +565,16 @@
if (parseBasicBlock(functionState))
return ParseFailure;
+ // Verify that all referenced blocks were defined. Iteration over a
+ // StringMap isn't determinstic, but this is good enough for our purposes.
+ for (auto &elt : functionState.blocksByName) {
+ auto *bb = elt.second.first;
+ if (!bb->getTerminator())
+ return emitError(elt.second.second,
+ "reference to an undefined basic block '" +
+ elt.first() + "'");
+ }
+
module->functionList.push_back(function);
return ParseSuccess;
}
@@ -579,13 +591,25 @@
auto name = curToken.getSpelling();
if (!consumeIf(Token::bare_identifier))
return emitError("expected basic block name");
- auto block = functionState.getBlockNamed(name);
+
+ auto block = functionState.getBlockNamed(name, nameLoc);
// If this block has already been parsed, then this is a redefinition with the
// same block name.
if (block->getTerminator())
- return emitError(nameLoc, "redefinition of block named '" +
- name.str() + "'");
+ return emitError(nameLoc, "redefinition of block '" + name.str() + "'");
+
+ // References to blocks can occur in any order, but we need to reassemble the
+ // function in the order that occurs in the source file. Do this by moving
+ // each block to the end of the list as it is defined.
+ // FIXME: This is inefficient for large functions given that blockList is a
+ // vector. blockList will eventually be an ilist, which will make this fast.
+ auto &blockList = functionState.function->blockList;
+ if (blockList.back() != block) {
+ auto it = std::find(blockList.begin(), blockList.end(), block);
+ assert(it != blockList.end() && "Block has to be in the blockList");
+ std::swap(*it, blockList.back());
+ }
// TODO: parse bb argument list.
@@ -602,15 +626,46 @@
// TODO: parse instruction list.
// TODO: Generalize this once instruction list parsing is built out.
- if (!consumeIf(Token::kw_return))
- return emitError("expected 'return' at end of basic block");
- block->setTerminator(new ReturnInst(block));
+ auto *termInst = parseTerminator(block, functionState);
+ if (!termInst)
+ return ParseFailure;
+ block->setTerminator(termInst);
return ParseSuccess;
}
+/// Parse the terminator instruction for a basic block.
+///
+/// terminator-stmt ::= `br` bb-id branch-use-list?
+/// branch-use-list ::= `(` ssa-use-and-type-list? `)`
+/// terminator-stmt ::=
+/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
+/// terminator-stmt ::= `return` ssa-use-and-type-list?
+///
+TerminatorInst *Parser::parseTerminator(BasicBlock *currentBB,
+ CFGFunctionParserState &functionState) {
+ switch (curToken.getKind()) {
+ default:
+ return (emitError("expected terminator at end of basic block"), nullptr);
+
+ case Token::kw_return:
+ consumeToken(Token::kw_return);
+ return new ReturnInst(currentBB);
+
+ case Token::kw_br: {
+ consumeToken(Token::kw_br);
+ auto destBB = functionState.getBlockNamed(curToken.getSpelling(),
+ curToken.getLoc());
+ if (!consumeIf(Token::bare_identifier))
+ return (emitError("expected basic block name"), nullptr);
+ return new BranchInst(destBB, currentBB);
+ }
+ }
+}
+
+
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
//===----------------------------------------------------------------------===//
diff --git a/lib/Parser/Token.h b/lib/Parser/Token.h
index c8f856c..dde0722 100644
--- a/lib/Parser/Token.h
+++ b/lib/Parser/Token.h
@@ -51,6 +51,7 @@
// Keywords.
kw_bf16,
+ kw_br,
kw_cfgfunc,
kw_extfunc,
kw_f16,
diff --git a/test/IR/parser-errors.mlir b/test/IR/parser-errors.mlir
index 288222b..e572fa2 100644
--- a/test/IR/parser-errors.mlir
+++ b/test/IR/parser-errors.mlir
@@ -6,18 +6,39 @@
; Check different error cases.
; TODO(jpienaar): This is checking the errors by simplify verifying the output.
; -----
+
; CHECK: expected type
; CHECK-NEXT: illegaltype
extfunc @illegaltype(i42)
+
; -----
; CHECK: expected type
; CHECK-NEXT: nestedtensor
extfunc @nestedtensor(tensor<tensor<i8>>) -> ()
+
; -----
; CHECK: expected '{' in CFG function
cfgfunc @foo()
cfgfunc @bar()
+
; -----
; CHECK: expected a function identifier like
; CHECK-NEXT: missingsigil
extfunc missingsigil() -> (i1, int, f32)
+
+
+; -----
+
+cfgfunc @bad_branch() {
+bb42:
+ br missing ; CHECK: error: reference to an undefined basic block 'missing'
+}
+
+; -----
+
+cfgfunc @block_redef() {
+bb42:
+ return
+bb42: ; CHECK: error: redefinition of block 'bb42'
+ return
+}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index fe46abf..e307a3a 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -40,7 +40,11 @@
; CHECK-LABEL: cfgfunc @multiblock() -> i32 {
cfgfunc @multiblock() -> i32 {
bb0: ; CHECK: bb0:
- return ; CHECK: return
-bb4: ; CHECK: bb1:
- return ; CHECK: return
+ return ; CHECK: return
+bb1: ; CHECK: bb1:
+ br bb4 ; CHECK: br bb3
+bb2: ; CHECK: bb2:
+ br bb2 ; CHECK: br bb2
+bb4: ; CHECK: bb3:
+ return ; CHECK: return
} ; CHECK: }