[mlir] Implement conditional branch
This looks heavyweight, but most of the code is in the large number of operand accessors:
we need to be able to iterate over all operands of the cond_br (all live-outs), and also over
just the true operands or just the false operands.
PiperOrigin-RevId: 205897704
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 4a21b31..1d61213 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -619,6 +619,7 @@
void print(const OperationInst *inst);
void print(const ReturnInst *inst);
void print(const BranchInst *inst);
+ void print(const CondBranchInst *inst);
unsigned getBBID(const BasicBlock *block) {
auto it = basicBlockIDs.find(block);
@@ -699,6 +700,8 @@
return print(cast<OperationInst>(inst));
case TerminatorInst::Kind::Branch:
return print(cast<BranchInst>(inst));
+ case TerminatorInst::Kind::CondBranch:
+ return print(cast<CondBranchInst>(inst));
case TerminatorInst::Kind::Return:
return print(cast<ReturnInst>(inst));
}
@@ -724,15 +727,51 @@
}
}
+/// Print a conditional branch in the form
+///   cond_br <cond>, bb<N>(<operands> : <types>), bb<M>(<operands> : <types>)
+/// where a parenthesized operand list is omitted when that successor has none.
+void CFGFunctionPrinter::print(const CondBranchInst *inst) {
+ os << " cond_br ";
+ printValueID(inst->getCondition());
+
+ os << ", bb" << getBBID(inst->getTrueDest());
+ if (inst->getNumTrueOperands() != 0) {
+ os << '(';
+ interleaveComma(inst->getTrueOperands(),
+ [&](const CFGValue *operand) { printValueID(operand); });
+ os << " : ";
+ interleaveComma(inst->getTrueOperands(), [&](const CFGValue *operand) {
+ ModulePrinter::print(operand->getType());
+ });
+ os << ")";
+ }
+
+ os << ", bb" << getBBID(inst->getFalseDest());
+ if (inst->getNumFalseOperands() != 0) {
+ os << '(';
+ interleaveComma(inst->getFalseOperands(),
+ [&](const CFGValue *operand) { printValueID(operand); });
+ os << " : ";
+ interleaveComma(inst->getFalseOperands(), [&](const CFGValue *operand) {
+ ModulePrinter::print(operand->getType());
+ });
+ os << ")";
+ }
+}
+
void CFGFunctionPrinter::print(const ReturnInst *inst) {
os << " return";
if (inst->getNumOperands() != 0)
os << ' ';
+ interleaveComma(inst->getOperands(),
+ [&](const CFGValue *operand) { printValueID(operand); });
+ // Only print the type separator when there are operands; otherwise an
+ // operand-less return would be printed as "return : ".
+ if (inst->getNumOperands() != 0)
+ os << " : ";
interleaveComma(inst->getOperands(), [&](const CFGValue *operand) {
- printValueID(operand);
- os << " : ";
ModulePrinter::print(operand->getType());
});
}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 67e87a4..8cbdabc 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -55,6 +55,9 @@
case Kind::Branch:
delete cast<BranchInst>(this);
break;
+ case Kind::CondBranch:
+ delete cast<CondBranchInst>(this);
+ break;
case Kind::Return:
cast<ReturnInst>(this)->destroy();
break;
@@ -76,6 +79,8 @@
return cast<OperationInst>(this)->getNumOperands();
case Kind::Branch:
return cast<BranchInst>(this)->getNumOperands();
+ case Kind::CondBranch:
+ return cast<CondBranchInst>(this)->getNumOperands();
case Kind::Return:
return cast<ReturnInst>(this)->getNumOperands();
}
@@ -87,6 +92,8 @@
return cast<OperationInst>(this)->getInstOperands();
case Kind::Branch:
return cast<BranchInst>(this)->getInstOperands();
+ case Kind::CondBranch:
+ return cast<CondBranchInst>(this)->getInstOperands();
case Kind::Return:
return cast<ReturnInst>(this)->getInstOperands();
}
@@ -125,7 +132,7 @@
unsigned numResults,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context)
- : Operation(name, /*isInstruction=*/ true, attributes, context),
+ : Operation(name, /*isInstruction=*/true, attributes, context),
Instruction(Kind::Operation), numOperands(numOperands),
numResults(numResults) {}
@@ -144,30 +151,30 @@
size_t(&((BasicBlock *)nullptr->*BasicBlock::getSublistAccess(nullptr))));
iplist<OperationInst> *Anchor(static_cast<iplist<OperationInst> *>(this));
return reinterpret_cast<BasicBlock *>(reinterpret_cast<char *>(Anchor) -
- Offset);
+ Offset);
}
/// This is a trait method invoked when an instruction is added to a block. We
/// keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::OperationInst>::
-addNodeToList(OperationInst *inst) {
+void llvm::ilist_traits<::mlir::OperationInst>::addNodeToList(
+ OperationInst *inst) {
+ // An instruction may be linked into at most one block at a time.
assert(!inst->getBlock() && "already in a basic block!");
inst->block = getContainingBlock();
}
/// This is a trait method invoked when an instruction is removed from a block.
/// We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::OperationInst>::
-removeNodeFromList(OperationInst *inst) {
+void llvm::ilist_traits<::mlir::OperationInst>::removeNodeFromList(
+ OperationInst *inst) {
assert(inst->block && "not already in a basic block!");
+ // Clear the back-pointer so the instruction no longer refers to any block.
inst->block = nullptr;
}
/// This is a trait method invoked when an instruction is moved from one block
/// to another. We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::OperationInst>::
-transferNodesFromList(ilist_traits<OperationInst> &otherList,
- instr_iterator first, instr_iterator last) {
+void llvm::ilist_traits<::mlir::OperationInst>::transferNodesFromList(
+ ilist_traits<OperationInst> &otherList, instr_iterator first,
+ instr_iterator last) {
// If we are transferring instructions within the same basic block, the block
// pointer doesn't need to be updated.
BasicBlock *curParent = getContainingBlock();
@@ -241,3 +248,30 @@
for (auto *value : values)
addOperand(value);
}
+
+/// Append \p value to the operands forwarded to the true destination. True and
+/// false operands share the single `operands` list with true operands first,
+/// so every true operand must be added before any false operand.
+void CondBranchInst::addTrueOperand(CFGValue *value) {
+ assert(getNumFalseOperands() == 0 &&
+ "Must insert all true operands before false operands!");
+ operands.emplace_back(InstOperand(this, value));
+ ++numTrueOperands;
+}
+
+/// Append each of \p values, in order, to the true operand list.
+void CondBranchInst::addTrueOperands(ArrayRef<CFGValue *> values) {
+ // Reserve up front so the loop below does not reallocate repeatedly.
+ operands.reserve(operands.size() + values.size());
+ for (auto *value : values)
+ addTrueOperand(value);
+}
+
+/// Append \p value to the operands forwarded to the false destination. Only
+/// the true operands are counted explicitly, so no counter is updated here.
+void CondBranchInst::addFalseOperand(CFGValue *value) {
+ operands.emplace_back(InstOperand(this, value));
+}
+
+/// Append each of \p values, in order, to the false operand list.
+void CondBranchInst::addFalseOperands(ArrayRef<CFGValue *> values) {
+ operands.reserve(operands.size() + values.size());
+ for (auto *value : values)
+ addFalseOperand(value);
+}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index cb622a2..96f7f73 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -101,6 +101,17 @@
bool verifyTerminator(const TerminatorInst &term);
bool verifyReturn(const ReturnInst &inst);
bool verifyBranch(const BranchInst &inst);
+ bool verifyCondBranch(const CondBranchInst &inst);
+
+ // Given a list of "operands" and "arguments" that are the same length, verify
+ // that the types of operands pointwise match argument types. The iterator
+ // types must expose the "getType()" function when dereferenced twice; that
+ // is, the iterator's value_type must be equivalent to SSAValue*.
+ template <typename OperandIteratorTy, typename ArgumentIteratorTy>
+ bool verifyOperandsMatchArguments(OperandIteratorTy opBegin,
+ OperandIteratorTy opEnd,
+ ArgumentIteratorTy argBegin,
+ const Instruction &instContext);
};
} // end anonymous namespace
@@ -167,6 +178,9 @@
if (auto *br = dyn_cast<BranchInst>(&term))
return verifyBranch(*br);
+ if (auto *br = dyn_cast<CondBranchInst>(&term))
+ return verifyCondBranch(*br);
+
return false;
}
@@ -207,6 +221,55 @@
return false;
}
+/// Pointwise-compare operand types in [opBegin, opEnd) against the argument
+/// types starting at argBegin. There is no argEnd: callers must already have
+/// checked that there are at least as many arguments as operands. Reports the
+/// first mismatch via failure() and returns its result; returns false if all
+/// types match.
+template <typename OperandIteratorTy, typename ArgumentIteratorTy>
+bool CFGFuncVerifier::verifyOperandsMatchArguments(
+ OperandIteratorTy opBegin, OperandIteratorTy opEnd,
+ ArgumentIteratorTy argBegin, const Instruction &instContext) {
+ OperandIteratorTy opIt = opBegin;
+ ArgumentIteratorTy argIt = argBegin;
+ for (; opIt != opEnd; ++opIt, ++argIt) {
+ // The reported index is relative to opBegin, not to the instruction's
+ // full operand list.
+ if ((*opIt)->getType() != (*argIt)->getType())
+ return failure("type of operand " + Twine(std::distance(opBegin, opIt)) +
+ " doesn't match argument type",
+ instContext);
+ }
+ return false;
+}
+
+/// Verify a conditional branch. Each successor must receive exactly as many
+/// operands as it has block arguments, with pointwise-matching types, and the
+/// condition must have type i1. Returns true if an error was reported.
+bool CFGFuncVerifier::verifyCondBranch(const CondBranchInst &inst) {
+ // Verify that the number of operands lines up with the number of BB arguments
+ // in the true successor.
+ auto trueDest = inst.getTrueDest();
+ if (inst.getNumTrueOperands() != trueDest->getNumArguments())
+ return failure("branch has " + Twine(inst.getNumTrueOperands()) +
+ " true operands, but true target block has " +
+ Twine(trueDest->getNumArguments()),
+ inst);
+
+ if (verifyOperandsMatchArguments(inst.true_operand_begin(),
+ inst.true_operand_end(),
+ trueDest->args_begin(), inst))
+ return true;
+
+ // And the false successor.
+ auto falseDest = inst.getFalseDest();
+ if (inst.getNumFalseOperands() != falseDest->getNumArguments())
+ return failure("branch has " + Twine(inst.getNumFalseOperands()) +
+ " false operands, but false target block has " +
+ Twine(falseDest->getNumArguments()),
+ inst);
+
+ if (verifyOperandsMatchArguments(inst.false_operand_begin(),
+ inst.false_operand_end(),
+ falseDest->args_begin(), inst))
+ return true;
+
+ // Finally, the condition itself must be a 1-bit integer.
+ if (inst.getCondition()->getType() != Type::getInteger(1, fn.getContext()))
+ return failure("type of condition is not boolean (i1)", inst);
+
+ return false;
+}
+
bool CFGFuncVerifier::verifyOperation(const OperationInst &inst) {
if (inst.getFunction() != &fn)
return failure("operation in the wrong function", inst);