Add the unconditional branch instruction, improve diagnostics for block
references.

PiperOrigin-RevId: 201872745
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index 95e3459..662ebfc 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -34,6 +34,7 @@
 class TerminatorInst {
 public:
   enum class Kind {
+    Branch,
     Return
   };
 
@@ -56,11 +57,36 @@
   BasicBlock *block;
 };
 
+/// The 'br' instruction is an unconditional from one basic block to another,
+/// and may pass basic block arguments to the successor.
+class BranchInst : public TerminatorInst {
+public:
+  explicit BranchInst(BasicBlock *dest, BasicBlock *parent);
+
+  /// Return the block this branch jumps to.
+  BasicBlock *getDest() const {
+    return dest;
+  }
+
+  // TODO: need to take BB arguments.
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const TerminatorInst *inst) {
+    return inst->getKind() == Kind::Branch;
+  }
+private:
+  BasicBlock *dest;
+};
+
+
+/// The 'return' instruction represents the end of control flow within the
+/// current function, and can return zero or more results.  The result list is
+/// required to align with the result list of the containing function's type.
 class ReturnInst : public TerminatorInst {
 public:
-  explicit ReturnInst(BasicBlock *block);
-  // TODO: Flesh this out.
+  explicit ReturnInst(BasicBlock *parent);
 
+  // TODO: Needs to take an operand list.
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const TerminatorInst *inst) {
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index f857963..08fe838 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -120,6 +120,9 @@
 
 void CFGFunctionState::print(const TerminatorInst *inst) {
   switch (inst->getKind()) {
+  case TerminatorInst::Kind::Branch:
+    os << "  br bb" << getBBID(cast<BranchInst>(inst)->getDest()) << "\n";
+    break;
   case TerminatorInst::Kind::Return:
     os << "  return\n";
     break;
diff --git a/lib/IR/BasicBlock.cpp b/lib/IR/BasicBlock.cpp
index ad8f71e..4cfe162 100644
--- a/lib/IR/BasicBlock.cpp
+++ b/lib/IR/BasicBlock.cpp
@@ -16,7 +16,9 @@
 // =============================================================================
 
 #include "mlir/IR/BasicBlock.h"
+#include "mlir/IR/CFGFunction.h"
 using namespace mlir;
 
 BasicBlock::BasicBlock(CFGFunction *function) : function(function) {
+  function->blockList.push_back(this);
 }
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index c123b86..c32e878 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -23,6 +23,10 @@
   return getBlock()->getFunction();
 }
 
-ReturnInst::ReturnInst(BasicBlock *block) : TerminatorInst(Kind::Return, block){
+ReturnInst::ReturnInst(BasicBlock *parent)
+  : TerminatorInst(Kind::Return, parent) {
 }
 
+BranchInst::BranchInst(BasicBlock *dest, BasicBlock *parent)
+  : TerminatorInst(Kind::Branch, parent), dest(dest) {
+}
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 5d224c7..b192f71 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -126,18 +126,19 @@
 
 /// Lex a bare identifier or keyword that starts with a letter.
 ///
-///   bare-id ::= letter (letter|digit)*
+///   bare-id ::= letter (letter|digit|[_])*
 ///
 Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) {
-  // Match the rest of the identifier regex: [0-9a-zA-Z]*
-  while (isalpha(*curPtr) || isdigit(*curPtr))
+  // Match the rest of the identifier regex: [0-9a-zA-Z_]*
+  while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_')
     ++curPtr;
 
   // Check to see if this identifier is a keyword.
   StringRef spelling(tokStart, curPtr-tokStart);
 
   Token::TokenKind kind = llvm::StringSwitch<Token::TokenKind>(spelling)
-    .Case("bf16",    Token::kw_bf16)
+    .Case("bf16", Token::kw_bf16)
+    .Case("br", Token::kw_br)
     .Case("cfgfunc", Token::kw_cfgfunc)
     .Case("extfunc", Token::kw_extfunc)
     .Case("f16", Token::kw_f16)
@@ -168,7 +169,7 @@
   if (!isalpha(*curPtr++))
     return emitError(curPtr-1, "expected letter in @ identifier");
 
-  while (isalpha(*curPtr) || isdigit(*curPtr))
+  while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_')
     ++curPtr;
   return formToken(Token::at_identifier, tokStart);
 }
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index d5cc707..828a8d5 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -119,6 +119,9 @@
   ParseResult parseExtFunc();
   ParseResult parseCFGFunc();
   ParseResult parseBasicBlock(CFGFunctionParserState &functionState);
+  TerminatorInst *parseTerminator(BasicBlock *currentBB,
+                                  CFGFunctionParserState &functionState);
+
 };
 } // end anonymous namespace
 
@@ -513,23 +516,22 @@
 /// forward references.
 class CFGFunctionParserState {
 public:
+  CFGFunction *function;
+  llvm::StringMap<std::pair<BasicBlock*, SMLoc>> blocksByName;
+
   CFGFunctionParserState(CFGFunction *function) : function(function) {}
 
   /// Get the basic block with the specified name, creating it if it doesn't
-  /// already exist.
-  BasicBlock *getBlockNamed(StringRef name) {
-    auto *&block = blocksByName[name];
-    if (!block) {
-      block = new BasicBlock(function);
-      // TODO: Should be automatic when we have the right function
-      // representation.
-      function->blockList.push_back(block);
+  /// already exist.  The location specified is the point of use, which allows
+  /// us to diagnose references to blocks that are not defined precisely.
+  BasicBlock *getBlockNamed(StringRef name, SMLoc loc) {
+    auto &blockAndLoc = blocksByName[name];
+    if (!blockAndLoc.first) {
+      blockAndLoc.first = new BasicBlock(function);
+      blockAndLoc.second = loc;
     }
-    return block;
+    return blockAndLoc.first;
   }
-private:
-  CFGFunction *function;
-  llvm::StringMap<BasicBlock*> blocksByName;
 };
 } // end anonymous namespace
 
@@ -563,6 +565,16 @@
     if (parseBasicBlock(functionState))
       return ParseFailure;
 
+  // Verify that all referenced blocks were defined.  Iteration over a
+  // StringMap isn't determinstic, but this is good enough for our purposes.
+  for (auto &elt : functionState.blocksByName) {
+    auto *bb = elt.second.first;
+    if (!bb->getTerminator())
+      return emitError(elt.second.second,
+                       "reference to an undefined basic block '" +
+                       elt.first() + "'");
+  }
+
   module->functionList.push_back(function);
   return ParseSuccess;
 }
@@ -579,13 +591,25 @@
   auto name = curToken.getSpelling();
   if (!consumeIf(Token::bare_identifier))
     return emitError("expected basic block name");
-  auto block = functionState.getBlockNamed(name);
+
+  auto block = functionState.getBlockNamed(name, nameLoc);
 
   // If this block has already been parsed, then this is a redefinition with the
   // same block name.
   if (block->getTerminator())
-    return emitError(nameLoc, "redefinition of block named '" +
-                     name.str() + "'");
+    return emitError(nameLoc, "redefinition of block '" + name.str() + "'");
+
+  // References to blocks can occur in any order, but we need to reassemble the
+  // function in the order that occurs in the source file.  Do this by moving
+  // each block to the end of the list as it is defined.
+  // FIXME: This is inefficient for large functions given that blockList is a
+  // vector.  blockList will eventually be an ilist, which will make this fast.
+  auto &blockList = functionState.function->blockList;
+  if (blockList.back() != block) {
+    auto it = std::find(blockList.begin(), blockList.end(), block);
+    assert(it != blockList.end() && "Block has to be in the blockList");
+    std::swap(*it, blockList.back());
+  }
 
   // TODO: parse bb argument list.
 
@@ -602,15 +626,46 @@
   // TODO: parse instruction list.
 
   // TODO: Generalize this once instruction list parsing is built out.
-  if (!consumeIf(Token::kw_return))
-    return emitError("expected 'return' at end of basic block");
 
-  block->setTerminator(new ReturnInst(block));
+  auto *termInst = parseTerminator(block, functionState);
+  if (!termInst)
+    return ParseFailure;
+  block->setTerminator(termInst);
 
   return ParseSuccess;
 }
 
 
+/// Parse the terminator instruction for a basic block.
+///
+///   terminator-stmt ::= `br` bb-id branch-use-list?
+///   branch-use-list ::= `(` ssa-use-and-type-list? `)`
+///   terminator-stmt ::=
+///     `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
+///   terminator-stmt ::= `return` ssa-use-and-type-list?
+///
+TerminatorInst *Parser::parseTerminator(BasicBlock *currentBB,
+                                        CFGFunctionParserState &functionState) {
+  switch (curToken.getKind()) {
+  default:
+    return (emitError("expected terminator at end of basic block"), nullptr);
+
+  case Token::kw_return:
+    consumeToken(Token::kw_return);
+    return new ReturnInst(currentBB);
+
+  case Token::kw_br: {
+    consumeToken(Token::kw_br);
+    auto destBB = functionState.getBlockNamed(curToken.getSpelling(),
+                                              curToken.getLoc());
+    if (!consumeIf(Token::bare_identifier))
+      return (emitError("expected basic block name"), nullptr);
+    return new BranchInst(destBB, currentBB);
+  }
+  }
+}
+
+
 //===----------------------------------------------------------------------===//
 // Top-level entity parsing.
 //===----------------------------------------------------------------------===//
diff --git a/lib/Parser/Token.h b/lib/Parser/Token.h
index c8f856c..dde0722 100644
--- a/lib/Parser/Token.h
+++ b/lib/Parser/Token.h
@@ -51,6 +51,7 @@
 
     // Keywords.
     kw_bf16,
+    kw_br,
     kw_cfgfunc,
     kw_extfunc,
     kw_f16,
diff --git a/test/IR/parser-errors.mlir b/test/IR/parser-errors.mlir
index 288222b..e572fa2 100644
--- a/test/IR/parser-errors.mlir
+++ b/test/IR/parser-errors.mlir
@@ -6,18 +6,39 @@
 ; Check different error cases.
 ; TODO(jpienaar): This is checking the errors by simplify verifying the output.
 ; -----
+
 ; CHECK: expected type
 ; CHECK-NEXT: illegaltype
 extfunc @illegaltype(i42)
+
 ; -----
 ; CHECK: expected type
 ; CHECK-NEXT: nestedtensor
 extfunc @nestedtensor(tensor<tensor<i8>>) -> ()
+
 ; -----
 ; CHECK: expected '{' in CFG function
 cfgfunc @foo()
 cfgfunc @bar()
+
 ; -----
 ; CHECK: expected a function identifier like
 ; CHECK-NEXT: missingsigil
 extfunc missingsigil() -> (i1, int, f32)
+
+
+; -----
+
+cfgfunc @bad_branch() {
+bb42:
+  br missing  ; CHECK: error: reference to an undefined basic block 'missing'
+}
+
+; -----
+
+cfgfunc @block_redef() {
+bb42:
+  return
+bb42:        ; CHECK: error: redefinition of block 'bb42'
+  return
+}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index fe46abf..e307a3a 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -40,7 +40,11 @@
 ; CHECK-LABEL: cfgfunc @multiblock() -> i32 {
 cfgfunc @multiblock() -> i32 {
 bb0:         ; CHECK: bb0:
-  return     ; CHECK: return
-bb4:         ; CHECK: bb1:
-  return     ; CHECK: return
+  return     ; CHECK:   return
+bb1:         ; CHECK: bb1:
+  br bb4     ; CHECK:   br bb3
+bb2:         ; CHECK: bb2:
+  br bb2     ; CHECK:   br bb2
+bb4:         ; CHECK: bb3:
+  return     ; CHECK:   return
 }            ; CHECK: }