Implement support for branch instruction operands.

PiperOrigin-RevId: 205666777
diff --git a/include/mlir/IR/BasicBlock.h b/include/mlir/IR/BasicBlock.h
index be64e4f..d59a83b 100644
--- a/include/mlir/IR/BasicBlock.h
+++ b/include/mlir/IR/BasicBlock.h
@@ -59,9 +59,12 @@
   reverse_args_iterator args_rend() const { return getArguments().rend(); }
 
   bool args_empty() const { return arguments.empty(); }
+
+  /// Add one value to the operand list.
   BBArgument *addArgument(Type *type);
-  llvm::iterator_range<BBArgListType::iterator>
-  addArguments(ArrayRef<Type *> types);
+
+  /// Add one argument to the argument list for each type specified in the list.
+  llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type *> types);
 
   unsigned getNumArguments() const { return arguments.size(); }
   BBArgument *getArgument(unsigned i) { return arguments[i]; }
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 1a003b6..12b8aed 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -186,7 +186,7 @@
     return op;
   }
 
-  // Creates for statement. When step is not specified, it is set to 1. 
+  // Creates for statement. When step is not specified, it is set to 1.
   ForStmt *createFor(AffineConstantExpr *lowerBound,
                      AffineConstantExpr *upperBound,
                      AffineConstantExpr *step = nullptr);
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index 6f374bd..02b4ffc 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -260,7 +260,31 @@
     return dest;
   }
 
-  // TODO: need to take operands to specify BB arguments
+  unsigned getNumOperands() const { return operands.size(); }
+
+  // TODO: Add a getOperands() custom sequence that provides a value projection
+  // of the operand list.
+  CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
+  const CFGValue *getOperand(unsigned idx) const {
+    return getInstOperand(idx).get();
+  }
+
+  ArrayRef<InstOperand> getInstOperands() const { return operands; }
+  MutableArrayRef<InstOperand> getInstOperands() { return operands; }
+
+  InstOperand &getInstOperand(unsigned idx) { return operands[idx]; }
+  const InstOperand &getInstOperand(unsigned idx) const {
+    return operands[idx];
+  }
+
+  /// Add one value to the operand list.
+  void addOperand(CFGValue *value);
+
+  /// Add a list of values to the operand list.
+  void addOperands(ArrayRef<CFGValue *> values);
+
+  /// Erase a specific argument from the arg list.
+  // TODO: void eraseArgument(int Index);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Instruction *inst) {
@@ -272,6 +296,7 @@
       : TerminatorInst(Kind::Branch), dest(dest) {}
 
   BasicBlock *dest;
+  std::vector<InstOperand> operands;
 };
 
 
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index 8401252..b0b9c96 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -37,6 +37,8 @@
 
   // FIXME: wrong representation and API.
   // TODO(someone): This should switch to llvm::iplist<Function>.
+  // TODO(someone): we also need a symbol table for function names +
+  // autorenaming like LLVM does.
   std::vector<Function*> functionList;
 
   /// Perform (potentially expensive) checks of invariants, used to detect
diff --git a/include/mlir/IR/SSAOperand.h b/include/mlir/IR/SSAOperand.h
index c8ab7d5..7685c31 100644
--- a/include/mlir/IR/SSAOperand.h
+++ b/include/mlir/IR/SSAOperand.h
@@ -61,6 +61,17 @@
   /// of the SSA machinery.
   SSAOperand *getNextOperandUsingThisValue() { return nextUse; }
 
+  /// We support a move constructor so SSAOperands can be in vectors, but this
+  /// shouldn't be used by general clients.
+  SSAOperand(SSAOperand &&other) {
+    other.removeFromCurrent();
+    value = other.value;
+    other.value = nullptr;
+    nextUse = nullptr;
+    back = nullptr;
+    insertIntoCurrent();
+  }
+
 private:
   /// The value used as this operand.  This can be null when in a
   /// "dropAllUses" state.
@@ -116,6 +127,11 @@
   /// Return which operand this is in the operand list of the User.
   // TODO:  unsigned getOperandNumber() const;
 
+  /// We support a move constructor so SSAOperands can be in vectors, but this
+  /// shouldn't be used by general clients.
+  SSAOperandImpl(SSAOperandImpl &&other)
+      : SSAOperand(std::move(other)), owner(other.owner) {}
+
 private:
   /// The owner of this operand.
   SSAOwnerTy *const owner;
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index afb4e73..64b5a2c 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -706,9 +706,23 @@
 void CFGFunctionPrinter::print(const OperationInst *inst) {
   printOperation(inst);
 }
+
 void CFGFunctionPrinter::print(const BranchInst *inst) {
   os << "  br bb" << getBBID(inst->getDest());
+
+  if (inst->getNumOperands() != 0) {
+    os << '(';
+    // TODO: Use getOperands() when we have it.
+    interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
+      printValueID(operand.get());
+    });
+    os << ") : ";
+    interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) {
+      ModulePrinter::print(operand.get()->getType());
+    });
+  }
 }
+
 void CFGFunctionPrinter::print(const ReturnInst *inst) {
   os << "  return";
 
diff --git a/lib/IR/BasicBlock.cpp b/lib/IR/BasicBlock.cpp
index 1aad430..7cb6440 100644
--- a/lib/IR/BasicBlock.cpp
+++ b/lib/IR/BasicBlock.cpp
@@ -35,6 +35,31 @@
   getFunction()->getBlocks().erase(this);
 }
 
+//===----------------------------------------------------------------------===//
+// Argument list management.
+//===----------------------------------------------------------------------===//
+
+BBArgument *BasicBlock::addArgument(Type *type) {
+  auto *arg = new BBArgument(type, this);
+  arguments.push_back(arg);
+  return arg;
+}
+
+/// Add one argument to the argument list for each type specified in the list.
+auto BasicBlock::addArguments(ArrayRef<Type *> types)
+    -> llvm::iterator_range<args_iterator> {
+  arguments.reserve(arguments.size() + types.size());
+  auto initialSize = arguments.size();
+  for (auto *type : types) {
+    addArgument(type);
+  }
+  return {arguments.data() + initialSize, arguments.data() + arguments.size()};
+}
+
+//===----------------------------------------------------------------------===//
+// Terminator management
+//===----------------------------------------------------------------------===//
+
 void BasicBlock::setTerminator(TerminatorInst *inst) {
   // If we already had a terminator, abandon it.
   if (terminator)
@@ -46,6 +71,10 @@
     inst->block = this;
 }
 
+//===----------------------------------------------------------------------===//
+// ilist_traits for BasicBlock
+//===----------------------------------------------------------------------===//
+
 mlir::CFGFunction *
 llvm::ilist_traits<::mlir::BasicBlock>::getContainingFunction() {
   size_t Offset(
@@ -86,17 +115,3 @@
   for (; first != last; ++first)
     first->function = curParent;
 }
-
-BBArgument *BasicBlock::addArgument(Type *type) {
-  arguments.push_back(new BBArgument(type, this));
-  return arguments.back();
-}
-
-llvm::iterator_range<BasicBlock::BBArgListType::iterator>
-BasicBlock::addArguments(ArrayRef<Type *> types) {
-  auto initial_size = arguments.size();
-  for (auto *type : types) {
-    addArgument(type);
-  }
-  return {arguments.data() + initial_size, arguments.data() + arguments.size()};
-}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 428523e..9374933 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -207,3 +207,15 @@
   for (auto &operand : getInstOperands())
     operand.~InstOperand();
 }
+
+/// Add one value to the operand list.
+void BranchInst::addOperand(CFGValue *value) {
+  operands.emplace_back(InstOperand(this, value));
+}
+
+/// Add a list of values to the operand list.
+void BranchInst::addOperands(ArrayRef<CFGValue *> values) {
+  operands.reserve(operands.size() + values.size());
+  for (auto *value : values)
+    addOperand(value);
+}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index d1bb2ac..aaec3ce 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -100,6 +100,7 @@
   bool verifyOperation(const OperationInst &inst);
   bool verifyTerminator(const TerminatorInst &term);
   bool verifyReturn(const ReturnInst &inst);
+  bool verifyBranch(const BranchInst &inst);
 };
 } // end anonymous namespace
 
@@ -163,6 +164,9 @@
   if (auto *ret = dyn_cast<ReturnInst>(&term))
     return verifyReturn(*ret);
 
+  if (auto *br = dyn_cast<BranchInst>(&term))
+    return verifyBranch(*br);
+
   return false;
 }
 
@@ -175,6 +179,31 @@
                        Twine(results.size()),
                    inst);
 
+  for (unsigned i = 0, e = results.size(); i != e; ++i)
+    if (inst.getOperand(i)->getType() != results[i])
+      return failure("type of return operand " + Twine(i) +
+                         " doesn't match result function result type",
+                     inst);
+
+  return false;
+}
+
+bool CFGFuncVerifier::verifyBranch(const BranchInst &inst) {
+  // Verify that the number of operands lines up with the number of BB arguments
+  // in the successor.
+  auto dest = inst.getDest();
+  if (inst.getNumOperands() != dest->getNumArguments())
+    return failure("branch has " + Twine(inst.getNumOperands()) +
+                       " operands, but target block has " +
+                       Twine(dest->getNumArguments()),
+                   inst);
+
+  for (unsigned i = 0, e = inst.getNumOperands(); i != e; ++i)
+    if (inst.getOperand(i)->getType() != dest->getArgument(i)->getType())
+      return failure("type of branch operand " + Twine(i) +
+                         " doesn't match target bb argument type",
+                     inst);
+
   return false;
 }
 
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index bc26bda..678b5fc 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -162,6 +162,7 @@
   Type *parseMemRefType();
   Type *parseFunctionType();
   Type *parseType();
+  ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
   ParseResult parseTypeList(SmallVectorImpl<Type*> &elements);
 
   // Attribute parsing.
@@ -516,12 +517,27 @@
   }
 }
 
+/// Parse a list of types without an enclosing parenthesis.  The list must have
+/// at least one member.
+///
+///   type-list-no-parens ::=  type (`,` type)*
+///
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
+  auto parseElt = [&]() -> ParseResult {
+    auto elt = parseType();
+    elements.push_back(elt);
+    return elt ? ParseSuccess : ParseFailure;
+  };
+
+  return parseCommaSeparatedList(parseElt);
+}
+
 /// Parse a "type list", which is a singular type, or a parenthesized list of
 /// types.
 ///
 ///   type-list ::= type-list-parens | type
 ///   type-list-parens ::= `(` `)`
-///                      | `(` type (`,` type)* `)`
+///                      | `(` type-list-no-parens `)`
 ///
 ParseResult Parser::parseTypeList(SmallVectorImpl<Type*> &elements) {
   auto parseElt = [&]() -> ParseResult {
@@ -1706,7 +1722,7 @@
 /// Parse the terminator instruction for a basic block.
 ///
 ///   terminator-stmt ::= `br` bb-id branch-use-list?
-///   branch-use-list ::= `(` ssa-use-and-type-list? `)`
+///   branch-use-list ::= `(` ssa-use-list `)` ':' type-list-no-parens
 ///   terminator-stmt ::=
 ///     `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
 ///   terminator-stmt ::= `return` ssa-use-and-type-list?
@@ -1730,7 +1746,40 @@
     auto destBB = getBlockNamed(getTokenSpelling(), getToken().getLoc());
     if (!consumeIf(Token::bare_identifier))
       return (emitError("expected basic block name"), nullptr);
-    return builder.createBranchInst(destBB);
+    auto branch = builder.createBranchInst(destBB);
+
+    // Parse the use list.
+    if (!consumeIf(Token::l_paren))
+      return branch;
+
+    SmallVector<SSAUseInfo, 4> valueIDs;
+    if (parseOptionalSSAUseList(valueIDs))
+      return nullptr;
+    if (!consumeIf(Token::r_paren))
+      return (emitError("expected ')' in branch argument list"), nullptr);
+    if (!consumeIf(Token::colon))
+      return (emitError("expected ':' in branch argument list"), nullptr);
+
+    auto typeLoc = getToken().getLoc();
+    SmallVector<Type *, 4> types;
+    if (parseTypeListNoParens(types))
+      return nullptr;
+
+    if (types.size() != valueIDs.size())
+      return (emitError(typeLoc, "expected " + Twine(valueIDs.size()) +
+                                     " types to match operand list"),
+              nullptr);
+
+    SmallVector<CFGValue *, 4> values;
+    values.reserve(valueIDs.size());
+    for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) {
+      if (auto *value = resolveSSAUse(valueIDs[i], types[i]))
+        values.push_back(cast<CFGValue>(value));
+      else
+        return nullptr;
+    }
+    branch->addOperands(values);
+    return branch;
   }
     // TODO: cond_br.
   }
diff --git a/test/IR/parser-errors.mlir b/test/IR/parser-errors.mlir
index c9086fe..fe7a49c 100644
--- a/test/IR/parser-errors.mlir
+++ b/test/IR/parser-errors.mlir
@@ -243,3 +243,16 @@
 bb42(%0: f32):
   return
 }
+
+// -----
+
+cfgfunc @br_mismatch() {  // expected-error {{branch has 2 operands, but target block has 1}}
+bb0:       // CHECK: bb0:
+  // CHECK: %0 = "foo"() : () -> (i1, i17)
+  %0 = "foo"() : () -> (i1, i17)
+  br bb1(%0#1, %0#0) : i17, i1
+
+bb1(%x: i17):
+  return
+}
+
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index a4d45c6..2f95957 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -66,16 +66,16 @@
 // CHECK: extfunc @functions((memref<1x?x4x?x?xaffineint, (d0, d1, d2, d3, d4) [s0] -> (d0, d1, d2, d3, d4), 0>, memref<i8, (d0) -> (d0), 0>) -> (), () -> ())
 extfunc @functions((memref<1x?x4x?x?xaffineint, #map0, 0>, memref<i8, #map1, 0>) -> (), ()->())
 
-// CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) {
-cfgfunc @simpleCFG(i32, f32) {
+// CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) -> i1 {
+cfgfunc @simpleCFG(i32, f32) -> i1 {
 // CHECK: bb0(%0: i32, %1: f32):
 bb42 (%0: i32, %f: f32):
   // CHECK: %2 = "foo"() : () -> i64
   %1 = "foo"() : ()->i64
   // CHECK: "bar"(%2) : (i64) -> (i1, i1, i1)
   %2 = "bar"(%1) : (i64) -> (i1,i1,i1)
-  // CHECK: return
-  return
+  // CHECK: return %3#1
+  return %2#1 : i1
 // CHECK: }
 }
 
@@ -208,4 +208,17 @@
   // CHECK: %2 = "bar"(%0#0, %0#1) : (i1, i17) -> (i11, f32)
   %2 = "bar"(%0#0, %0#1) : (i1, i17) -> (i11, f32)
   br bb1
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: cfgfunc @bbargs() -> (i16, i8) {
+cfgfunc @bbargs() -> (i16, i8) {
+bb0:       // CHECK: bb0:
+  // CHECK: %0 = "foo"() : () -> (i1, i17)
+  %0 = "foo"() : () -> (i1, i17)
+  br bb1(%0#1, %0#0) : i17, i1
+
+bb1(%x: i17, %y: i1):       // CHECK: bb1(%1: i17, %2: i1):
+  // CHECK: %3 = "baz"(%1, %2, %0#1) : (i17, i1, i17) -> (i16, i8)
+  %1 = "baz"(%x, %y, %0#1) : (i17, i1, i17) -> (i16, i8)
+  return %1#0 : i16, %1#1 : i8
+}