[mlir] Implement conditional branch
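This looks heavyweight, but most of the code is in the large number of operand accessors:
we need to be able to iterate over all operands of the cond_br (all live-outs), and also over
just the true operands or just the false operands. The `return` printer is also changed to print all operand IDs followed by all of their types, matching the cond_br operand-list syntax.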
PiperOrigin-RevId: 205897704
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 4a21b31..1d61213 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -619,6 +619,7 @@
void print(const OperationInst *inst);
void print(const ReturnInst *inst);
void print(const BranchInst *inst);
+ void print(const CondBranchInst *inst);
unsigned getBBID(const BasicBlock *block) {
auto it = basicBlockIDs.find(block);
@@ -699,6 +700,8 @@
return print(cast<OperationInst>(inst));
case TerminatorInst::Kind::Branch:
return print(cast<BranchInst>(inst));
+ case TerminatorInst::Kind::CondBranch:
+ return print(cast<CondBranchInst>(inst));
case TerminatorInst::Kind::Return:
return print(cast<ReturnInst>(inst));
}
@@ -724,15 +727,45 @@
}
}
+void CFGFunctionPrinter::print(const CondBranchInst *inst) {
+  // Per-operand printers shared by the true and false successor lists.
+  auto printOperandID = [&](const CFGValue *value) { printValueID(value); };
+  auto printOperandType = [&](const CFGValue *value) {
+    ModulePrinter::print(value->getType());
+  };
+
+  os << " cond_br ";
+  printValueID(inst->getCondition());
+
+  // True successor: ", bbN", then "(ids : types)" only if it takes operands.
+  os << ", bb" << getBBID(inst->getTrueDest());
+  if (inst->getNumTrueOperands() != 0) {
+    os << '(';
+    interleaveComma(inst->getTrueOperands(), printOperandID);
+    os << " : ";
+    interleaveComma(inst->getTrueOperands(), printOperandType);
+    os << ")";
+  }
+
+  // False successor, printed with the same layout.
+  os << ", bb" << getBBID(inst->getFalseDest());
+  if (inst->getNumFalseOperands() != 0) {
+    os << '(';
+    interleaveComma(inst->getFalseOperands(), printOperandID);
+    os << " : ";
+    interleaveComma(inst->getFalseOperands(), printOperandType);
+    os << ")";
+  }
+}
+
void CFGFunctionPrinter::print(const ReturnInst *inst) {
os << " return";
-  if (inst->getNumOperands() != 0)
-    os << ' ';
+  // A bare `return` has no operand or type list. Bail out early so the
+  // " : " separator below is never emitted with nothing after it.
+  if (inst->getNumOperands() == 0)
+    return;
+
+  // Print every operand ID first, then every operand type, separated by a
+  // single " : " (the same layout the cond_br printer uses for its operands).
+  os << ' ';
+  interleaveComma(inst->getOperands(),
+                  [&](const CFGValue *operand) { printValueID(operand); });
+  os << " : ";
interleaveComma(inst->getOperands(), [&](const CFGValue *operand) {
-    printValueID(operand);
-    os << " : ";
ModulePrinter::print(operand->getType());
});
}