[mlir] Implement conditional branch

This looks heavyweight but most of the code is in the massive number of operand accessors!

We need to be able to iterate over all operands to the condbr (all live-outs) but also just
the true/just the false operands too.

PiperOrigin-RevId: 205897704
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index e6e4e4d..ec9ce18 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -146,6 +146,13 @@
     return insertTerminator(BranchInst::create(dest));
   }
 
+  CondBranchInst *createCondBranchInst(CFGValue *condition,
+                                       BasicBlock *trueDest,
+                                       BasicBlock *falseDest) {
+    return insertTerminator(
+        CondBranchInst::create(condition, trueDest, falseDest));
+  }
+
 private:
   template <typename T>
   T *insertTerminator(T *term) {
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index be743e6..cbd9578 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -31,26 +31,20 @@
 #include "llvm/Support/TrailingObjects.h"
 
 namespace mlir {
-  class OperationInst;
-  class BasicBlock;
-  class CFGFunction;
+class OperationInst;
+class BasicBlock;
+class CFGFunction;
 
 /// Instruction is the root of the operation and terminator instructions in the
 /// hierarchy.
 class Instruction {
 public:
-  enum class Kind {
-    Operation,
-    Branch,
-    Return
-  };
+  enum class Kind { Operation, Branch, CondBranch, Return };
 
   Kind getKind() const { return kind; }
 
   /// Return the BasicBlock containing this instruction.
-  BasicBlock *getBlock() const {
-    return block;
-  }
+  BasicBlock *getBlock() const { return block; }
 
   /// Return the CFGFunction containing this instruction.
   CFGFunction *getFunction() const;
@@ -122,6 +116,7 @@
   // every instruction by a word, is not necessary given the closed nature of
   // instruction kinds.
   ~Instruction();
+
 private:
   Kind kind;
   BasicBlock *block = nullptr;
@@ -290,7 +285,6 @@
 /// represent control flow and returns.
 class TerminatorInst : public Instruction {
 public:
-
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Instruction *inst) {
     return inst->getKind() != Kind::Operation;
@@ -312,9 +306,7 @@
   ~BranchInst() {}
 
   /// Return the block this branch jumps to.
-  BasicBlock *getDest() const {
-    return dest;
-  }
+  BasicBlock *getDest() const { return dest; }
 
   unsigned getNumOperands() const { return operands.size(); }
 
@@ -385,6 +377,202 @@
   std::vector<InstOperand> operands;
 };
 
+/// The 'cond_br' instruction is a conditional branch based on a boolean
+/// condition to one of two possible successors. It may pass arguments to each
+/// successor.
+class CondBranchInst : public TerminatorInst {
+public:
+  static CondBranchInst *create(CFGValue *condition, BasicBlock *trueDest,
+                                BasicBlock *falseDest) {
+    return new CondBranchInst(condition, trueDest, falseDest);
+  }
+  ~CondBranchInst() {}
+
+  /// Return the i1 condition.
+  CFGValue *getCondition() { return condition; }
+  const CFGValue *getCondition() const { return condition; }
+
+  /// Return the destination if the condition is true.
+  BasicBlock *getTrueDest() const { return trueDest; }
+
+  /// Return the destination if the condition is false.
+  BasicBlock *getFalseDest() const { return falseDest; }
+
+  // Support non-const operand iteration.
+  using operand_iterator = OperandIterator<CondBranchInst, CFGValue>;
+  // Support const operand iteration.
+  typedef OperandIterator<const CondBranchInst, const CFGValue>
+      const_operand_iterator;
+
+  //
+  // Accessors for the entire operand list. This includes operands to both true
+  // and false blocks.
+  //
+
+  CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); }
+  const CFGValue *getOperand(unsigned idx) const {
+    return getInstOperand(idx).get();
+  }
+  void setOperand(unsigned idx, CFGValue *value) {
+    return getInstOperand(idx).set(value);
+  }
+
+  operand_iterator operand_begin() { return operand_iterator(this, 0); }
+  operand_iterator operand_end() {
+    return operand_iterator(this, getNumOperands());
+  }
+  llvm::iterator_range<operand_iterator> getOperands() {
+    return {operand_begin(), operand_end()};
+  }
+
+  const_operand_iterator operand_begin() const {
+    return const_operand_iterator(this, 0);
+  }
+  const_operand_iterator operand_end() const {
+    return const_operand_iterator(this, getNumOperands());
+  }
+  llvm::iterator_range<const_operand_iterator> getOperands() const {
+    return {operand_begin(), operand_end()};
+  }
+
+  ArrayRef<InstOperand> getInstOperands() const { return operands; }
+  MutableArrayRef<InstOperand> getInstOperands() { return operands; }
+
+  InstOperand &getInstOperand(unsigned idx) { return operands[idx]; }
+  const InstOperand &getInstOperand(unsigned idx) const {
+    return operands[idx];
+  }
+  unsigned getNumOperands() const { return operands.size(); }
+
+  //
+  // Accessors for operands to the 'true' destination
+  //
+
+  CFGValue *getTrueOperand(unsigned idx) {
+    return getTrueInstOperand(idx).get();
+  }
+  const CFGValue *getTrueOperand(unsigned idx) const {
+    return getTrueInstOperand(idx).get();
+  }
+  void setTrueOperand(unsigned idx, CFGValue *value) {
+    return getTrueInstOperand(idx).set(value);
+  }
+
+  operand_iterator true_operand_begin() { return operand_iterator(this, 0); }
+  operand_iterator true_operand_end() {
+    return operand_iterator(this, getNumTrueOperands());
+  }
+  llvm::iterator_range<operand_iterator> getTrueOperands() {
+    return {true_operand_begin(), true_operand_end()};
+  }
+
+  const_operand_iterator true_operand_begin() const {
+    return const_operand_iterator(this, 0);
+  }
+  const_operand_iterator true_operand_end() const {
+    return const_operand_iterator(this, getNumTrueOperands());
+  }
+  llvm::iterator_range<const_operand_iterator> getTrueOperands() const {
+    return {true_operand_begin(), true_operand_end()};
+  }
+
+  ArrayRef<InstOperand> getTrueInstOperands() const {
+    return {&operands[0], &operands[0] + getNumTrueOperands()};
+  }
+  MutableArrayRef<InstOperand> getTrueInstOperands() {
+    return {&operands[0], &operands[0] + getNumTrueOperands()};
+  }
+
+  InstOperand &getTrueInstOperand(unsigned idx) { return operands[idx]; }
+  const InstOperand &getTrueInstOperand(unsigned idx) const {
+    return operands[idx];
+  }
+  unsigned getNumTrueOperands() const { return numTrueOperands; }
+
+  /// Add one value to the true operand list.
+  void addTrueOperand(CFGValue *value);
+
+  /// Add a list of values to the operand list.
+  void addTrueOperands(ArrayRef<CFGValue *> values);
+
+  //
+  // Accessors for operands to the 'false' destination
+  //
+
+  CFGValue *getFalseOperand(unsigned idx) {
+    return getFalseInstOperand(idx).get();
+  }
+  const CFGValue *getFalseOperand(unsigned idx) const {
+    return getFalseInstOperand(idx).get();
+  }
+  void setFalseOperand(unsigned idx, CFGValue *value) {
+    return getFalseInstOperand(idx).set(value);
+  }
+
+  operand_iterator false_operand_begin() {
+    return operand_iterator(this, getNumTrueOperands());
+  }
+  operand_iterator false_operand_end() {
+    return operand_iterator(this, getNumOperands());
+  }
+  llvm::iterator_range<operand_iterator> getFalseOperands() {
+    return {false_operand_begin(), false_operand_end()};
+  }
+
+  const_operand_iterator false_operand_begin() const {
+    return const_operand_iterator(this, getNumTrueOperands());
+  }
+  const_operand_iterator false_operand_end() const {
+    return const_operand_iterator(this, getNumOperands());
+  }
+  llvm::iterator_range<const_operand_iterator> getFalseOperands() const {
+    return {false_operand_begin(), false_operand_end()};
+  }
+
+  ArrayRef<InstOperand> getFalseInstOperands() const {
+    return {&operands[0] + getNumTrueOperands(),
+            &operands[0] + getNumOperands()};
+  }
+  MutableArrayRef<InstOperand> getFalseInstOperands() {
+    return {&operands[0] + getNumTrueOperands(),
+            &operands[0] + getNumOperands()};
+  }
+
+  InstOperand &getFalseInstOperand(unsigned idx) {
+    return operands[idx + getNumTrueOperands()];
+  }
+  const InstOperand &getFalseInstOperand(unsigned idx) const {
+    return operands[idx + getNumTrueOperands()];
+  }
+  unsigned getNumFalseOperands() const {
+    return operands.size() - numTrueOperands;
+  }
+
+  /// Add one value to the false operand list.
+  void addFalseOperand(CFGValue *value);
+
+  /// Add a list of values to the operand list.
+  void addFalseOperands(ArrayRef<CFGValue *> values);
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const Instruction *inst) {
+    return inst->getKind() == Kind::CondBranch;
+  }
+
+private:
+  explicit CondBranchInst(CFGValue *condition, BasicBlock *trueDest,
+                          BasicBlock *falseDest)
+      : TerminatorInst(Kind::CondBranch), condition(condition),
+        trueDest(trueDest), falseDest(falseDest), numTrueOperands(0) {}
+
+  CFGValue *condition;
+  BasicBlock *trueDest;
+  BasicBlock *falseDest;
+  // Operand list. The true operands are stored first, followed by the false
+  // operands.
+  std::vector<InstOperand> operands;
+  unsigned numTrueOperands;
+};
 
 /// The 'return' instruction represents the end of control flow within the
 /// current function, and can return zero or more results.  The result list is
@@ -471,7 +659,6 @@
 
 } // end namespace mlir
 
-
 //===----------------------------------------------------------------------===//
 // ilist_traits for OperationInst
 //===----------------------------------------------------------------------===//
@@ -489,10 +676,11 @@
   void removeNodeFromList(OperationInst *inst);
   void transferNodesFromList(ilist_traits<OperationInst> &otherList,
                              instr_iterator first, instr_iterator last);
+
 private:
   mlir::BasicBlock *getContainingBlock();
 };
 
 } // end namespace llvm
 
-#endif  // MLIR_IR_INSTRUCTIONS_H
+#endif // MLIR_IR_INSTRUCTIONS_H
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 4a21b31..1d61213 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -619,6 +619,7 @@
   void print(const OperationInst *inst);
   void print(const ReturnInst *inst);
   void print(const BranchInst *inst);
+  void print(const CondBranchInst *inst);
 
   unsigned getBBID(const BasicBlock *block) {
     auto it = basicBlockIDs.find(block);
@@ -699,6 +700,8 @@
     return print(cast<OperationInst>(inst));
   case TerminatorInst::Kind::Branch:
     return print(cast<BranchInst>(inst));
+  case TerminatorInst::Kind::CondBranch:
+    return print(cast<CondBranchInst>(inst));
   case TerminatorInst::Kind::Return:
     return print(cast<ReturnInst>(inst));
   }
@@ -724,15 +727,45 @@
   }
 }
 
+void CFGFunctionPrinter::print(const CondBranchInst *inst) {
+  os << "  cond_br ";
+  printValueID(inst->getCondition());
+
+  os << ", bb" << getBBID(inst->getTrueDest());
+  if (inst->getNumTrueOperands() != 0) {
+    os << '(';
+    interleaveComma(inst->getTrueOperands(),
+                    [&](const CFGValue *operand) { printValueID(operand); });
+    os << " : ";
+    interleaveComma(inst->getTrueOperands(), [&](const CFGValue *operand) {
+      ModulePrinter::print(operand->getType());
+    });
+    os << ")";
+  }
+
+  os << ", bb" << getBBID(inst->getFalseDest());
+  if (inst->getNumFalseOperands() != 0) {
+    os << '(';
+    interleaveComma(inst->getFalseOperands(),
+                    [&](const CFGValue *operand) { printValueID(operand); });
+    os << " : ";
+    interleaveComma(inst->getFalseOperands(), [&](const CFGValue *operand) {
+      ModulePrinter::print(operand->getType());
+    });
+    os << ")";
+  }
+}
+
 void CFGFunctionPrinter::print(const ReturnInst *inst) {
   os << "  return";
 
   if (inst->getNumOperands() != 0)
     os << ' ';
 
+  interleaveComma(inst->getOperands(),
+                  [&](const CFGValue *operand) { printValueID(operand); });
+  os << " : ";
   interleaveComma(inst->getOperands(), [&](const CFGValue *operand) {
-    printValueID(operand);
-    os << " : ";
     ModulePrinter::print(operand->getType());
   });
 }
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 67e87a4..8cbdabc 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -55,6 +55,9 @@
   case Kind::Branch:
     delete cast<BranchInst>(this);
     break;
+  case Kind::CondBranch:
+    delete cast<CondBranchInst>(this);
+    break;
   case Kind::Return:
     cast<ReturnInst>(this)->destroy();
     break;
@@ -76,6 +79,8 @@
     return cast<OperationInst>(this)->getNumOperands();
   case Kind::Branch:
     return cast<BranchInst>(this)->getNumOperands();
+  case Kind::CondBranch:
+    return cast<CondBranchInst>(this)->getNumOperands();
   case Kind::Return:
     return cast<ReturnInst>(this)->getNumOperands();
   }
@@ -87,6 +92,8 @@
     return cast<OperationInst>(this)->getInstOperands();
   case Kind::Branch:
     return cast<BranchInst>(this)->getInstOperands();
+  case Kind::CondBranch:
+    return cast<CondBranchInst>(this)->getInstOperands();
   case Kind::Return:
     return cast<ReturnInst>(this)->getInstOperands();
   }
@@ -125,7 +132,7 @@
                              unsigned numResults,
                              ArrayRef<NamedAttribute> attributes,
                              MLIRContext *context)
-    : Operation(name, /*isInstruction=*/ true, attributes, context),
+    : Operation(name, /*isInstruction=*/true, attributes, context),
       Instruction(Kind::Operation), numOperands(numOperands),
       numResults(numResults) {}
 
@@ -144,30 +151,30 @@
       size_t(&((BasicBlock *)nullptr->*BasicBlock::getSublistAccess(nullptr))));
   iplist<OperationInst> *Anchor(static_cast<iplist<OperationInst> *>(this));
   return reinterpret_cast<BasicBlock *>(reinterpret_cast<char *>(Anchor) -
-                                           Offset);
+                                        Offset);
 }
 
 /// This is a trait method invoked when an instruction is added to a block.  We
 /// keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::OperationInst>::
-addNodeToList(OperationInst *inst) {
+void llvm::ilist_traits<::mlir::OperationInst>::addNodeToList(
+    OperationInst *inst) {
   assert(!inst->getBlock() && "already in a basic block!");
   inst->block = getContainingBlock();
 }
 
 /// This is a trait method invoked when an instruction is removed from a block.
 /// We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::OperationInst>::
-removeNodeFromList(OperationInst *inst) {
+void llvm::ilist_traits<::mlir::OperationInst>::removeNodeFromList(
+    OperationInst *inst) {
   assert(inst->block && "not already in a basic block!");
   inst->block = nullptr;
 }
 
 /// This is a trait method invoked when an instruction is moved from one block
 /// to another.  We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::OperationInst>::
-transferNodesFromList(ilist_traits<OperationInst> &otherList,
-                      instr_iterator first, instr_iterator last) {
+void llvm::ilist_traits<::mlir::OperationInst>::transferNodesFromList(
+    ilist_traits<OperationInst> &otherList, instr_iterator first,
+    instr_iterator last) {
   // If we are transferring instructions within the same basic block, the block
   // pointer doesn't need to be updated.
   BasicBlock *curParent = getContainingBlock();
@@ -241,3 +248,30 @@
   for (auto *value : values)
     addOperand(value);
 }
+
+/// Add one value to the true operand list.
+void CondBranchInst::addTrueOperand(CFGValue *value) {
+  assert(getNumFalseOperands() == 0 &&
+         "Must insert all true operands before false operands!");
+  operands.emplace_back(InstOperand(this, value));
+  ++numTrueOperands;
+}
+
+/// Add a list of values to the true operand list.
+void CondBranchInst::addTrueOperands(ArrayRef<CFGValue *> values) {
+  operands.reserve(operands.size() + values.size());
+  for (auto *value : values)
+    addTrueOperand(value);
+}
+
+/// Add one value to the false operand list.
+void CondBranchInst::addFalseOperand(CFGValue *value) {
+  operands.emplace_back(InstOperand(this, value));
+}
+
+/// Add a list of values to the false operand list.
+void CondBranchInst::addFalseOperands(ArrayRef<CFGValue *> values) {
+  operands.reserve(operands.size() + values.size());
+  for (auto *value : values)
+    addFalseOperand(value);
+}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index cb622a2..96f7f73 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -101,6 +101,17 @@
   bool verifyTerminator(const TerminatorInst &term);
   bool verifyReturn(const ReturnInst &inst);
   bool verifyBranch(const BranchInst &inst);
+  bool verifyCondBranch(const CondBranchInst &inst);
+
+  // Given a list of "operands" and "arguments" that are the same length, verify
+  // that the types of operands pointwise match argument types. The iterator
+  // types must expose the "getType()" function when dereferenced twice; that
+  // is, the iterator's value_type must be equivalent to SSAValue*.
+  template <typename OperandIteratorTy, typename ArgumentIteratorTy>
+  bool verifyOperandsMatchArguments(OperandIteratorTy opBegin,
+                                    OperandIteratorTy opEnd,
+                                    ArgumentIteratorTy argBegin,
+                                    const Instruction &instContext);
 };
 } // end anonymous namespace
 
@@ -167,6 +178,9 @@
   if (auto *br = dyn_cast<BranchInst>(&term))
     return verifyBranch(*br);
 
+  if (auto *br = dyn_cast<CondBranchInst>(&term))
+    return verifyCondBranch(*br);
+
   return false;
 }
 
@@ -207,6 +221,55 @@
   return false;
 }
 
+template <typename OperandIteratorTy, typename ArgumentIteratorTy>
+bool CFGFuncVerifier::verifyOperandsMatchArguments(
+    OperandIteratorTy opBegin, OperandIteratorTy opEnd,
+    ArgumentIteratorTy argBegin, const Instruction &instContext) {
+  OperandIteratorTy opIt = opBegin;
+  ArgumentIteratorTy argIt = argBegin;
+  for (; opIt != opEnd; ++opIt, ++argIt) {
+    if ((*opIt)->getType() != (*argIt)->getType())
+      return failure("type of operand " + Twine(std::distance(opBegin, opIt)) +
+                         " doesn't match argument type",
+                     instContext);
+  }
+  return false;
+}
+
+bool CFGFuncVerifier::verifyCondBranch(const CondBranchInst &inst) {
+  // Verify that the number of operands lines up with the number of BB arguments
+  // in the true successor.
+  auto trueDest = inst.getTrueDest();
+  if (inst.getNumTrueOperands() != trueDest->getNumArguments())
+    return failure("branch has " + Twine(inst.getNumTrueOperands()) +
+                       " true operands, but true target block has " +
+                       Twine(trueDest->getNumArguments()),
+                   inst);
+
+  if (verifyOperandsMatchArguments(inst.true_operand_begin(),
+                                   inst.true_operand_end(),
+                                   trueDest->args_begin(), inst))
+    return true;
+
+  // And the false successor.
+  auto falseDest = inst.getFalseDest();
+  if (inst.getNumFalseOperands() != falseDest->getNumArguments())
+    return failure("branch has " + Twine(inst.getNumFalseOperands()) +
+                       " false operands, but false target block has " +
+                       Twine(falseDest->getNumArguments()),
+                   inst);
+
+  if (verifyOperandsMatchArguments(inst.false_operand_begin(),
+                                   inst.false_operand_end(),
+                                   falseDest->args_begin(), inst))
+    return true;
+
+  if (inst.getCondition()->getType() != Type::getInteger(1, fn.getContext()))
+    return failure("type of condition is not boolean (i1)", inst);
+
+  return false;
+}
+
 bool CFGFuncVerifier::verifyOperation(const OperationInst &inst) {
   if (inst.getFunction() != &fn)
     return failure("operation in the wrong function", inst);
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index daf20db..3c655d1 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1632,6 +1632,8 @@
   ParseResult
   parseOptionalBasicBlockArgList(SmallVectorImpl<BBArgument *> &results,
                                  BasicBlock *owner);
+  ParseResult parseBranchBlockAndUseList(BasicBlock *&block,
+                                         SmallVectorImpl<CFGValue *> &values);
 
   ParseResult parseBasicBlock();
   OperationInst *parseCFGOperation();
@@ -1738,7 +1740,7 @@
   };
 
   // Parse the list of operations that make up the body of the block.
-  while (getToken().isNot(Token::kw_return, Token::kw_br)) {
+  while (getToken().isNot(Token::kw_return, Token::kw_br, Token::kw_cond_br)) {
     if (parseOperation(createOpFunc))
       return ParseFailure;
   }
@@ -1749,6 +1751,20 @@
   return ParseSuccess;
 }
 
+ParseResult CFGFunctionParser::parseBranchBlockAndUseList(
+    BasicBlock *&block, SmallVectorImpl<CFGValue *> &values) {
+  block = getBlockNamed(getTokenSpelling(), getToken().getLoc());
+  if (parseToken(Token::bare_identifier, "expected basic block name"))
+    return ParseFailure;
+
+  if (!consumeIf(Token::l_paren))
+    return ParseSuccess;
+  if (parseOptionalSSAUseAndTypeList(values, /*isParenthesized*/ false) ||
+      parseToken(Token::r_paren, "expected ')' to close argument list"))
+    return ParseFailure;
+  return ParseSuccess;
+}
+
 /// Parse the terminator instruction for a basic block.
 ///
 ///   terminator-stmt ::= `br` bb-id branch-use-list?
@@ -1774,19 +1790,45 @@
 
   case Token::kw_br: {
     consumeToken(Token::kw_br);
-    auto destBB = getBlockNamed(getTokenSpelling(), getToken().getLoc());
-    if (parseToken(Token::bare_identifier, "expected basic block name"))
+    BasicBlock *destBB;
+    SmallVector<CFGValue *, 4> values;
+    if (parseBranchBlockAndUseList(destBB, values))
       return nullptr;
-
     auto branch = builder.createBranchInst(destBB);
-
-    SmallVector<CFGValue *, 8> operands;
-    if (parseOptionalSSAUseAndTypeList(operands, /*isParenthesized*/ true))
-      return nullptr;
-    branch->addOperands(operands);
+    branch->addOperands(values);
     return branch;
   }
-    // TODO: cond_br.
+
+  case Token::kw_cond_br: {
+    consumeToken(Token::kw_cond_br);
+    SSAUseInfo ssaUse;
+    if (parseSSAUse(ssaUse))
+      return nullptr;
+    auto *cond = resolveSSAUse(ssaUse, builder.getIntegerType(1));
+    if (!cond)
+      return (emitError("expected type was boolean (i1)"), nullptr);
+    if (parseToken(Token::comma, "expected ',' in conditional branch"))
+      return nullptr;
+
+    BasicBlock *trueBlock;
+    SmallVector<CFGValue *, 4> trueOperands;
+    if (parseBranchBlockAndUseList(trueBlock, trueOperands))
+      return nullptr;
+
+    if (parseToken(Token::comma, "expected ',' in conditional branch"))
+      return nullptr;
+
+    BasicBlock *falseBlock;
+    SmallVector<CFGValue *, 4> falseOperands;
+    if (parseBranchBlockAndUseList(falseBlock, falseOperands))
+      return nullptr;
+
+    auto branch = builder.createCondBranchInst(cast<CFGValue>(cond), trueBlock,
+                                               falseBlock);
+    branch->addTrueOperands(trueOperands);
+    branch->addFalseOperands(falseOperands);
+    return branch;
+  }
   }
 }
 
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index de6758c..b9ef9b05 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -90,6 +90,7 @@
 TOK_KEYWORD(br)
 TOK_KEYWORD(ceildiv)
 TOK_KEYWORD(cfgfunc)
+TOK_KEYWORD(cond_br)
 TOK_KEYWORD(else)
 TOK_KEYWORD(extfunc)
 TOK_KEYWORD(f16)
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index f9582e8..c82f0ec 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -241,7 +241,7 @@
 cfgfunc @br_mismatch() {  // expected-error {{branch has 2 operands, but target block has 1}}
 bb0:
   %0 = "foo"() : () -> (i1, i17)
-  br bb1(%0#1, %0#0) : i17, i1
+  br bb1(%0#1, %0#0 : i17, i1)
 
 bb1(%x: i17):
   return
@@ -251,4 +251,31 @@
 
 // Test no nested vector.
 extfunc @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>)
-// expected-error@-1 {{expected type}}
\ No newline at end of file
+// expected-error@-1 {{expected type}}
+
+// -----
+
+cfgfunc @condbr_notbool() {
+bb0:
+  %a = "foo"() : () -> i32 // expected-error {{prior use here}}
+  cond_br %a, bb0, bb0 // expected-error {{use of value '%a' expects different type than prior uses}}
+// expected-error@-1 {{expected type was boolean (i1)}}
+}
+
+// -----
+
+cfgfunc @condbr_badtype() {
+bb0:
+  %c = "foo"() : () -> i1
+  %a = "foo"() : () -> i32
+  cond_br %c, bb0(%a, %a : i32, bb0) // expected-error {{expected type}}
+}
+
+// -----
+
+cfgfunc @condbr_a_bb_is_not_a_type() {
+bb0:
+  %c = "foo"() : () -> i1
+  %a = "foo"() : () -> i32
+  cond_br %c, bb0(%a, %a : i32, i32), i32 // expected-error {{expected basic block name}}
+}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index d461e74..eb57a1b 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -170,7 +170,7 @@
   // CHECK: %1 = "baz"(%2#1, %2#0, %0#1) : (f32, i11, i17) -> (i16, i8)
   %1 = "baz"(%2#1, %2#0, %0#1) : (f32, i11, i17) -> (i16, i8)
 
-  // CHECK: return %1#0 : i16, %1#1 : i8
+  // CHECK: return %1#0, %1#1 : i16, i8
   return %1#0, %1#1 : i16, i8
 
 bb2:       // CHECK: bb2:
@@ -184,10 +184,44 @@
 bb0:       // CHECK: bb0:
   // CHECK: %0 = "foo"() : () -> (i1, i17)
   %0 = "foo"() : () -> (i1, i17)
-  br bb1(%0#1, %0#0) : i17, i1
+  br bb1(%0#1, %0#0 : i17, i1)
 
 bb1(%x: i17, %y: i1):       // CHECK: bb1(%1: i17, %2: i1):
   // CHECK: %3 = "baz"(%1, %2, %0#1) : (i17, i1, i17) -> (i16, i8)
   %1 = "baz"(%x, %y, %0#1) : (i17, i1, i17) -> (i16, i8)
   return %1#0, %1#1 : i16, i8
 }
+
+// CHECK-LABEL: cfgfunc @condbr_simple
+cfgfunc @condbr_simple() -> (i32) {
+bb0:
+  %cond = "foo"() : () -> i1
+  %a = "bar"() : () -> i32
+  %b = "bar"() : () -> i64
+  // CHECK: cond_br %0, bb1(%1 : i32), bb2(%2 : i64)
+  cond_br %cond, bb1(%a : i32), bb2(%b : i64)
+
+bb1(%x : i32):
+  return %x : i32
+
+bb2(%y : i64):
+  %z = "foo"() : () -> i32
+  return %z : i32
+}
+
+// CHECK-LABEL: cfgfunc @condbr_moarargs
+cfgfunc @condbr_moarargs() -> (i32) {
+bb0:
+  %cond = "foo"() : () -> i1
+  %a = "bar"() : () -> i32
+  %b = "bar"() : () -> i64
+  // CHECK: cond_br %0, bb1(%1, %2 : i32, i64), bb2(%2, %1, %1 : i64, i32, i32)
+  cond_br %cond, bb1(%a, %b : i32, i64), bb2(%b, %a, %a : i64, i32, i32)
+
+bb1(%x : i32, %y : i64):
+  return %x : i32
+
+bb2(%x2 : i64, %y2 : i32, %z2 : i32):
+  %z = "foo"() : () -> i32
+  return %z : i32
+}