[mlir] Implement conditional branch
This looks heavyweight, but most of the code is in the large number of operand accessors!
We need to be able to iterate over all operands of the condbr (all live-outs), but also over
just the true or just the false operands.
PiperOrigin-RevId: 205897704
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index cb622a2..96f7f73 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -101,6 +101,17 @@
bool verifyTerminator(const TerminatorInst &term);
bool verifyReturn(const ReturnInst &inst);
bool verifyBranch(const BranchInst &inst);
+ // Verify a conditional branch: each successor must receive exactly as many
+ // operands as it has block arguments, with pointwise-matching types, and the
+ // condition must have type i1. Returns true on failure (diagnostics are
+ // emitted through failure()), false if the branch is well formed.
+ bool verifyCondBranch(const CondBranchInst &inst);
+
+ // Given a range of operands [opBegin, opEnd) and a sequence of block
+ // arguments starting at argBegin with the same length, verify that the type
+ // of each operand matches the type of the argument at the same position.
+ // The caller is responsible for checking that the lengths agree: only the
+ // operand range is bounded, so argBegin is never bounds-checked.
+ // Both iterators must dereference to a pointer-like value supporting
+ // "->getType()"; that is, their value_type should be equivalent to
+ // SSAValue*. On the first mismatch, reports an error against instContext
+ // and returns true; otherwise returns false.
+ template <typename OperandIteratorTy, typename ArgumentIteratorTy>
+ bool verifyOperandsMatchArguments(OperandIteratorTy opBegin,
+ OperandIteratorTy opEnd,
+ ArgumentIteratorTy argBegin,
+ const Instruction &instContext);
};
} // end anonymous namespace
@@ -167,6 +178,9 @@
if (auto *br = dyn_cast<BranchInst>(&term))
return verifyBranch(*br);
+ if (auto *br = dyn_cast<CondBranchInst>(&term))
+ return verifyCondBranch(*br);
+
return false;
}
@@ -207,6 +221,55 @@
return false;
}
+template <typename OperandIteratorTy, typename ArgumentIteratorTy>
+bool CFGFuncVerifier::verifyOperandsMatchArguments(
+    OperandIteratorTy opBegin, OperandIteratorTy opEnd,
+    ArgumentIteratorTy argBegin, const Instruction &instContext) {
+  // Walk the operands and the block arguments in lockstep. Callers pass
+  // equal-length sequences, so only the operand range carries an end bound.
+  // The position is tracked explicitly for use in the diagnostic.
+  unsigned position = 0;
+  ArgumentIteratorTy arg = argBegin;
+  for (OperandIteratorTy operand = opBegin; operand != opEnd;
+       ++operand, ++arg, ++position) {
+    if ((*operand)->getType() == (*arg)->getType())
+      continue;
+    // Report only the first mismatch; the caller treats `true` as failure.
+    return failure("type of operand " + Twine(position) +
+                       " doesn't match argument type",
+                   instContext);
+  }
+  return false;
+}
+
+bool CFGFuncVerifier::verifyCondBranch(const CondBranchInst &inst) {
+  // A conditional branch forwards a separate operand list to each of its two
+  // successors. Each list must line up with its destination's block
+  // arguments: first in count, then pointwise in type. The "true" successor
+  // is checked before the "false" one, and the condition is checked last.
+  auto thenBlock = inst.getTrueDest();
+  auto numThenOperands = inst.getNumTrueOperands();
+  auto numThenArgs = thenBlock->getNumArguments();
+  if (numThenOperands != numThenArgs)
+    return failure("branch has " + Twine(numThenOperands) +
+                       " true operands, but true target block has " +
+                       Twine(numThenArgs),
+                   inst);
+  if (verifyOperandsMatchArguments(inst.true_operand_begin(),
+                                   inst.true_operand_end(),
+                                   thenBlock->args_begin(), inst))
+    return true;
+
+  // The same two checks, applied to the "false" successor.
+  auto elseBlock = inst.getFalseDest();
+  auto numElseOperands = inst.getNumFalseOperands();
+  auto numElseArgs = elseBlock->getNumArguments();
+  if (numElseOperands != numElseArgs)
+    return failure("branch has " + Twine(numElseOperands) +
+                       " false operands, but false target block has " +
+                       Twine(numElseArgs),
+                   inst);
+  if (verifyOperandsMatchArguments(inst.false_operand_begin(),
+                                   inst.false_operand_end(),
+                                   elseBlock->args_begin(), inst))
+    return true;
+
+  // Finally, the branch condition itself must be a 1-bit integer.
+  if (inst.getCondition()->getType() != Type::getInteger(1, fn.getContext()))
+    return failure("type of condition is not boolean (i1)", inst);
+
+  return false;
+}
+
bool CFGFuncVerifier::verifyOperation(const OperationInst &inst) {
if (inst.getFunction() != &fn)
return failure("operation in the wrong function", inst);