[mlir] Implement conditional branch

This looks heavyweight, but most of the code is in the large number of operand accessors.

We need to be able to iterate over all operands of the cond_br (all live-outs), but also over
just the true operands or just the false operands.

PiperOrigin-RevId: 205897704
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index daf20db..3c655d1 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1632,6 +1632,8 @@
   ParseResult
   parseOptionalBasicBlockArgList(SmallVectorImpl<BBArgument *> &results,
                                  BasicBlock *owner);
+  ParseResult parseBranchBlockAndUseList(BasicBlock *&block,
+                                         SmallVectorImpl<CFGValue *> &values);
 
   ParseResult parseBasicBlock();
   OperationInst *parseCFGOperation();
@@ -1738,7 +1740,7 @@
   };
 
   // Parse the list of operations that make up the body of the block.
-  while (getToken().isNot(Token::kw_return, Token::kw_br)) {
+  while (getToken().isNot(Token::kw_return, Token::kw_br, Token::kw_cond_br)) {
     if (parseOperation(createOpFunc))
       return ParseFailure;
   }
@@ -1749,6 +1751,20 @@
   return ParseSuccess;
 }
 
+/// Parse `bb-id branch-use-list?`; bb-id is validated before it is looked up.
+ParseResult CFGFunctionParser::parseBranchBlockAndUseList(
+    BasicBlock *&block, SmallVectorImpl<CFGValue *> &values) {
+  if (getToken().isNot(Token::bare_identifier))
+    return emitError("expected basic block name");
+  block = getBlockNamed(getTokenSpelling(), getToken().getLoc());
+  consumeToken(Token::bare_identifier);
+  if (consumeIf(Token::l_paren) &&
+      (parseOptionalSSAUseAndTypeList(values, /*isParenthesized*/ false) ||
+       parseToken(Token::r_paren, "expected ')' to close argument list")))
+    return ParseFailure;
+  return ParseSuccess;
+}
+
 /// Parse the terminator instruction for a basic block.
 ///
 ///   terminator-stmt ::= `br` bb-id branch-use-list?
@@ -1774,19 +1790,45 @@
 
   case Token::kw_br: {
     consumeToken(Token::kw_br);
-    auto destBB = getBlockNamed(getTokenSpelling(), getToken().getLoc());
-    if (parseToken(Token::bare_identifier, "expected basic block name"))
+    BasicBlock *destBB;
+    SmallVector<CFGValue *, 4> values;
+    if (parseBranchBlockAndUseList(destBB, values))
       return nullptr;
-
     auto branch = builder.createBranchInst(destBB);
-
-    SmallVector<CFGValue *, 8> operands;
-    if (parseOptionalSSAUseAndTypeList(operands, /*isParenthesized*/ true))
-      return nullptr;
-    branch->addOperands(operands);
+    branch->addOperands(values);
     return branch;
   }
-    // TODO: cond_br.
+
+  case Token::kw_cond_br: {
+    consumeToken(Token::kw_cond_br);
+    SSAUseInfo ssaUse;
+    if (parseSSAUse(ssaUse))
+      return nullptr;
+    auto *cond = resolveSSAUse(ssaUse, builder.getIntegerType(1));
+    if (!cond)
+      return (emitError("expected type was boolean (i1)"), nullptr);
+    if (parseToken(Token::comma, "expected ',' in conditional branch"))
+      return nullptr;
+
+    BasicBlock *trueBlock;
+    SmallVector<CFGValue *, 4> trueOperands;
+    if (parseBranchBlockAndUseList(trueBlock, trueOperands))
+      return nullptr;
+
+    if (parseToken(Token::comma, "expected ',' in conditional branch"))
+      return nullptr;
+
+    BasicBlock *falseBlock;
+    SmallVector<CFGValue *, 4> falseOperands;
+    if (parseBranchBlockAndUseList(falseBlock, falseOperands))
+      return nullptr;
+
+    auto branch = builder.createCondBranchInst(cast<CFGValue>(cond), trueBlock,
+                                               falseBlock);
+    branch->addTrueOperands(trueOperands);
+    branch->addFalseOperands(falseOperands);
+    return branch;
+  }
   }
 }
 
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index de6758c..b9ef9b05 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -90,6 +90,7 @@
 TOK_KEYWORD(br)
 TOK_KEYWORD(ceildiv)
 TOK_KEYWORD(cfgfunc)
+TOK_KEYWORD(cond_br)
 TOK_KEYWORD(else)
 TOK_KEYWORD(extfunc)
 TOK_KEYWORD(f16)