Implement support for branch instruction operands.

PiperOrigin-RevId: 205666777
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index afb4e73..64b5a2c 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -706,9 +706,23 @@
 void CFGFunctionPrinter::print(const OperationInst *inst) {
   printOperation(inst);
 }
+
 void CFGFunctionPrinter::print(const BranchInst *inst) {
   os << "  br bb" << getBBID(inst->getDest());
+
+  if (inst->getNumOperands() != 0) { // Operand-less branches stay bare.
+    os << '(';
+    // Operand IDs first. TODO: Use getOperands() when we have it.
+    interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
+      printValueID(operand.get());
+    });
+    os << ") : "; // Then the operand types, in the same order.
+    interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
+      ModulePrinter::print(operand.get()->getType());
+    });
+  }
 }
+
 void CFGFunctionPrinter::print(const ReturnInst *inst) {
   os << "  return";
 
diff --git a/lib/IR/BasicBlock.cpp b/lib/IR/BasicBlock.cpp
index 1aad430..7cb6440 100644
--- a/lib/IR/BasicBlock.cpp
+++ b/lib/IR/BasicBlock.cpp
@@ -35,6 +35,31 @@
   getFunction()->getBlocks().erase(this);
 }
 
+//===----------------------------------------------------------------------===//
+// Argument list management.
+//===----------------------------------------------------------------------===//
+
+BBArgument *BasicBlock::addArgument(Type *type) {
+  auto *arg = new BBArgument(type, this);
+  arguments.push_back(arg); // Appended after any existing arguments.
+  return arg;
+}
+
+/// Append one argument per type in `types`; returns the newly added range.
+auto BasicBlock::addArguments(ArrayRef<Type *> types)
+    -> llvm::iterator_range<args_iterator> {
+  arguments.reserve(arguments.size() + types.size());
+  auto initialSize = arguments.size(); // Index of the first new argument.
+  for (auto *type : types) {
+    addArgument(type);
+  }
+  return {arguments.data() + initialSize, arguments.data() + arguments.size()};
+}
+
+//===----------------------------------------------------------------------===//
+// Terminator management.
+//===----------------------------------------------------------------------===//
+
 void BasicBlock::setTerminator(TerminatorInst *inst) {
   // If we already had a terminator, abandon it.
   if (terminator)
@@ -46,6 +71,10 @@
     inst->block = this;
 }
 
+//===----------------------------------------------------------------------===//
+// ilist_traits for BasicBlock
+//===----------------------------------------------------------------------===//
+
 mlir::CFGFunction *
 llvm::ilist_traits<::mlir::BasicBlock>::getContainingFunction() {
   size_t Offset(
@@ -86,17 +115,3 @@
   for (; first != last; ++first)
     first->function = curParent;
 }
-
-BBArgument *BasicBlock::addArgument(Type *type) {
-  arguments.push_back(new BBArgument(type, this));
-  return arguments.back();
-}
-
-llvm::iterator_range<BasicBlock::BBArgListType::iterator>
-BasicBlock::addArguments(ArrayRef<Type *> types) {
-  auto initial_size = arguments.size();
-  for (auto *type : types) {
-    addArgument(type);
-  }
-  return {arguments.data() + initial_size, arguments.data() + arguments.size()};
-}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 428523e..9374933 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -207,3 +207,15 @@
   for (auto &operand : getInstOperands())
     operand.~InstOperand();
 }
+
+/// Append `value` to the end of this branch's operand list.
+void BranchInst::addOperand(CFGValue *value) {
+  operands.emplace_back(InstOperand(this, value));
+}
+
+/// Append every entry of `values`, in order, to this branch's operand list.
+void BranchInst::addOperands(ArrayRef<CFGValue *> values) {
+  operands.reserve(operands.size() + values.size()); // Reserve up front.
+  for (auto *value : values)
+    addOperand(value);
+}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index d1bb2ac..aaec3ce 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -100,6 +100,7 @@
   bool verifyOperation(const OperationInst &inst);
   bool verifyTerminator(const TerminatorInst &term);
   bool verifyReturn(const ReturnInst &inst);
+  bool verifyBranch(const BranchInst &inst);
 };
 } // end anonymous namespace
 
@@ -163,6 +164,9 @@
   if (auto *ret = dyn_cast<ReturnInst>(&term))
     return verifyReturn(*ret);
 
+  if (auto *br = dyn_cast<BranchInst>(&term))
+    return verifyBranch(*br);
+
   return false;
 }
 
@@ -175,6 +179,31 @@
                        Twine(results.size()),
                    inst);
 
+  for (unsigned i = 0, e = results.size(); i != e; ++i)
+    if (inst.getOperand(i)->getType() != results[i])
+      return failure("type of return operand " + Twine(i) +
+                         " doesn't match result function result type",
+                     inst);
+
+  return false;
+}
+
+bool CFGFuncVerifier::verifyBranch(const BranchInst &inst) {
+  // A branch must pass exactly one operand per argument of its successor, and
+  // each operand's type must match that argument's; failure() reports errors.
+  auto dest = inst.getDest();
+  if (inst.getNumOperands() != dest->getNumArguments())
+    return failure("branch has " + Twine(inst.getNumOperands()) +
+                       " operands, but target block has " +
+                       Twine(dest->getNumArguments()),
+                   inst);
+
+  for (unsigned i = 0, e = inst.getNumOperands(); i != e; ++i)
+    if (inst.getOperand(i)->getType() != dest->getArgument(i)->getType())
+      return failure("type of branch operand " + Twine(i) +
+                         " doesn't match target bb argument type",
+                     inst);
+
   return false;
 }
 
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index bc26bda..678b5fc 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -162,6 +162,7 @@
   Type *parseMemRefType();
   Type *parseFunctionType();
   Type *parseType();
+  ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
   ParseResult parseTypeList(SmallVectorImpl<Type*> &elements);
 
   // Attribute parsing.
@@ -516,12 +517,27 @@
   }
 }
 
+/// Parse a comma-separated list of types with no enclosing parentheses.  The
+/// list must have at least one member.
+///
+///   type-list-no-parens ::=  type (`,` type)*
+///
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
+  auto parseElt = [&]() -> ParseResult {
+    auto elt = parseType();
+    elements.push_back(elt); // Pushed even if null, i.e. on a failed parse.
+    return elt ? ParseSuccess : ParseFailure;
+  };
+
+  return parseCommaSeparatedList(parseElt);
+}
+
 /// Parse a "type list", which is a singular type, or a parenthesized list of
 /// types.
 ///
 ///   type-list ::= type-list-parens | type
 ///   type-list-parens ::= `(` `)`
-///                      | `(` type (`,` type)* `)`
+///                      | `(` type-list-no-parens `)`
 ///
 ParseResult Parser::parseTypeList(SmallVectorImpl<Type*> &elements) {
   auto parseElt = [&]() -> ParseResult {
@@ -1706,7 +1722,7 @@
 /// Parse the terminator instruction for a basic block.
 ///
 ///   terminator-stmt ::= `br` bb-id branch-use-list?
-///   branch-use-list ::= `(` ssa-use-and-type-list? `)`
+///   branch-use-list ::= `(` ssa-use-list `)` `:` type-list-no-parens
 ///   terminator-stmt ::=
 ///     `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
 ///   terminator-stmt ::= `return` ssa-use-and-type-list?
@@ -1730,7 +1746,40 @@
     auto destBB = getBlockNamed(getTokenSpelling(), getToken().getLoc());
     if (!consumeIf(Token::bare_identifier))
       return (emitError("expected basic block name"), nullptr);
-    return builder.createBranchInst(destBB);
+    auto branch = builder.createBranchInst(destBB);
+
+    // Parse the use list.
+    if (!consumeIf(Token::l_paren))
+      return branch;
+
+    SmallVector<SSAUseInfo, 4> valueIDs;
+    if (parseOptionalSSAUseList(valueIDs))
+      return nullptr;
+    if (!consumeIf(Token::r_paren))
+      return (emitError("expected ')' in branch argument list"), nullptr);
+    if (!consumeIf(Token::colon))
+      return (emitError("expected ':' in branch argument list"), nullptr);
+
+    auto typeLoc = getToken().getLoc();
+    SmallVector<Type *, 4> types;
+    if (parseTypeListNoParens(types))
+      return nullptr;
+
+    if (types.size() != valueIDs.size())
+      return (emitError(typeLoc, "expected " + Twine(valueIDs.size()) +
+                                     " types to match operand list"),
+              nullptr);
+
+    SmallVector<CFGValue *, 4> values;
+    values.reserve(valueIDs.size());
+    for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) {
+      if (auto *value = resolveSSAUse(valueIDs[i], types[i]))
+        values.push_back(cast<CFGValue>(value));
+      else
+        return nullptr;
+    }
+    branch->addOperands(values);
+    return branch;
   }
     // TODO: cond_br.
   }